refactor: refine pkg net utils (#2720)

* refactor: refine pkg net utils

* fix: x

Co-authored-by: blizard863 <760076784@qq.com>
This commit is contained in:
Blizard 2021-12-28 21:14:57 +08:00 committed by GitHub
parent 0fb6aeef58
commit ea568e8a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 142 additions and 60 deletions

View File

@ -234,8 +234,11 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
} }
} }
address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)) conn, err = frpNet.DialWithOptions(net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)),
conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte) frpNet.WithProxyURL(ctl.clientCfg.HTTPProxy),
frpNet.WithProtocol(ctl.clientCfg.Protocol),
frpNet.WithTLSConfig(tlsConfig),
frpNet.WithDisableCustomTLSHeadByte(ctl.clientCfg.DisableCustomTLSFirstByte))
if err != nil { if err != nil {
xl.Warn("start new connection to server error: %v", err) xl.Warn("start new connection to server error: %v", err)

View File

@ -790,7 +790,7 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
return return
} }
localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort)) localConn, err := frpNet.DialWithOptions(net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)))
if err != nil { if err != nil {
workConn.Close() workConn.Close()
xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err) xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)

View File

@ -228,8 +228,12 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
} }
} }
address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)) conn, err = frpNet.DialWithOptions(net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)),
conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte) frpNet.WithProxyURL(svr.cfg.HTTPProxy),
frpNet.WithProtocol(svr.cfg.Protocol),
frpNet.WithTLSConfig(tlsConfig),
frpNet.WithDisableCustomTLSHeadByte(svr.cfg.DisableCustomTLSFirstByte))
if err != nil { if err != nil {
return return
} }

View File

@ -16,15 +16,16 @@ package net
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/url"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
"golang.org/x/net/websocket"
gnet "github.com/fatedier/golib/net" gnet "github.com/fatedier/golib/net"
kcp "github.com/fatedier/kcp-go" kcp "github.com/fatedier/kcp-go"
@ -194,50 +195,61 @@ func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
case "tcp": case "tcp":
return net.Dial("tcp", addr) return net.Dial("tcp", addr)
case "kcp": case "kcp":
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) return DialKCPServer(addr)
if errRet != nil { case "websocket":
err = errRet return DialWebsocketServer(addr)
return
}
kcpConn.SetStreamMode(true)
kcpConn.SetWriteDelay(true)
kcpConn.SetNoDelay(1, 20, 2, 1)
kcpConn.SetWindowSize(128, 512)
kcpConn.SetMtu(1350)
kcpConn.SetACKNoDelay(false)
kcpConn.SetReadBuffer(4194304)
kcpConn.SetWriteBuffer(4194304)
c = kcpConn
return
default: default:
return nil, fmt.Errorf("unsupport protocol: %s", protocol) return nil, fmt.Errorf("unsupport protocol: %s", protocol)
} }
} }
func DialKCPServer(addr string) (c net.Conn, err error) {
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
if errRet != nil {
err = errRet
return
}
kcpConn.SetStreamMode(true)
kcpConn.SetWriteDelay(true)
kcpConn.SetNoDelay(1, 20, 2, 1)
kcpConn.SetWindowSize(128, 512)
kcpConn.SetMtu(1350)
kcpConn.SetACKNoDelay(false)
kcpConn.SetReadBuffer(4194304)
kcpConn.SetWriteBuffer(4194304)
c = kcpConn
return
}
func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) { func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
switch protocol { switch protocol {
case "tcp": case "tcp":
return gnet.DialTcpByProxy(proxyURL, addr) return gnet.DialTcpByProxy(proxyURL, addr)
case "kcp":
// http proxy is not supported for kcp
return ConnectServer(protocol, addr)
case "websocket":
return ConnectWebsocketServer(addr)
default: default:
return nil, fmt.Errorf("unsupport protocol: %s", protocol) return nil, fmt.Errorf("unsupport protocol: %s when connecting by proxy", protocol)
} }
} }
func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) { // addr: domain:port
c, err = ConnectServerByProxy(proxyURL, protocol, addr) func DialWebsocketServer(addr string) (net.Conn, error) {
addr = "ws://" + addr + FrpWebsocketPath
uri, err := url.Parse(addr)
if err != nil { if err != nil {
return return nil, err
} }
if tlsConfig == nil { origin := "http://" + uri.Host
return cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, err
}
cfg.Dialer = &net.Dialer{
Timeout: 10 * time.Second,
} }
c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte) conn, err := websocket.DialConfig(cfg)
return if err != nil {
return nil, err
}
return conn, nil
} }

89
pkg/util/net/dial.go Normal file
View File

@ -0,0 +1,89 @@
package net
import (
"crypto/tls"
"net"
)
type dialOptions struct {
proxyURL string
protocol string
tlsConfig *tls.Config
disableCustomTLSHeadByte bool
}
type DialOption interface {
apply(*dialOptions)
}
type EmptyDialOption struct{}
func (EmptyDialOption) apply(*dialOptions) {}
type funcDialOption struct {
f func(*dialOptions)
}
func (fdo *funcDialOption) apply(do *dialOptions) {
fdo.f(do)
}
func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
return &funcDialOption{
f: f,
}
}
func DefaultDialOptions() dialOptions {
return dialOptions{
protocol: "tcp",
}
}
func WithProxyURL(proxyURL string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.proxyURL = proxyURL
})
}
func WithTLSConfig(tlsConfig *tls.Config) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.tlsConfig = tlsConfig
})
}
func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.disableCustomTLSHeadByte = disableCustomTLSHeadByte
})
}
func WithProtocol(protocol string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.protocol = protocol
})
}
func DialWithOptions(addr string, opts ...DialOption) (c net.Conn, err error) {
op := DefaultDialOptions()
for _, opt := range opts {
opt.apply(&op)
}
if op.proxyURL == "" {
c, err = ConnectServer(op.protocol, addr)
} else {
c, err = ConnectServerByProxy(op.proxyURL, op.protocol, addr)
}
if err != nil {
return nil, err
}
if op.tlsConfig == nil {
return
}
c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte)
return
}

View File

@ -5,8 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url"
"time"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -77,27 +75,3 @@ func (p *WebsocketListener) Close() error {
func (p *WebsocketListener) Addr() net.Addr { func (p *WebsocketListener) Addr() net.Addr {
return p.ln.Addr() return p.ln.Addr()
} }
// addr: domain:port
func ConnectWebsocketServer(addr string) (net.Conn, error) {
addr = "ws://" + addr + FrpWebsocketPath
uri, err := url.Parse(addr)
if err != nil {
return nil, err
}
origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, err
}
cfg.Dialer = &net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := websocket.DialConfig(cfg)
if err != nil {
return nil, err
}
return conn, nil
}