frp/pkg/util/net/dial.go

51 lines
1.1 KiB
Go
Raw Normal View History

package net
import (
2022-01-20 20:03:07 +08:00
"context"
"net"
2022-01-20 20:03:07 +08:00
"net/url"
2022-01-20 20:03:07 +08:00
libdial "github.com/fatedier/golib/net/dial"
"golang.org/x/net/websocket"
)
2022-01-20 20:03:07 +08:00
// DialHookCustomTLSHeadByte returns an after-dial hook that, when enableTLS is
// true and disableCustomTLSHeadByte is false, writes the single byte
// FRPTLSHeadByte to the freshly dialed connection before passing it on.
//
// In every other case the hook is a no-op that returns ctx and c unchanged.
// If the write fails, the hook returns a nil context, a nil connection and the
// write error; c itself is not closed here.
func DialHookCustomTLSHeadByte(enableTLS bool, disableCustomTLSHeadByte bool) libdial.AfterHookFunc {
	// The decision depends only on the hook's configuration, not on the dial,
	// so evaluate it once when the hook is built.
	sendHeadByte := enableTLS && !disableCustomTLSHeadByte

	return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
		if !sendHeadByte {
			return ctx, c, nil
		}
		if _, err := c.Write([]byte{byte(FRPTLSHeadByte)}); err != nil {
			return nil, nil, err
		}
		return ctx, c, nil
	}
}
func DialHookWebsocket(protocol string, host string) libdial.AfterHookFunc {
2022-01-20 20:03:07 +08:00
return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
if protocol != "wss" {
protocol = "ws"
}
if host == "" {
host = addr
}
addr = protocol + "://" + host + FrpWebsocketPath
2022-01-20 20:03:07 +08:00
uri, err := url.Parse(addr)
if err != nil {
return nil, nil, err
}
origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, nil, err
}
conn, err := websocket.NewClient(cfg, c)
if err != nil {
return nil, nil, err
}
return ctx, conn, nil
}
}