mirror of
https://gitee.com/IrisVega/frp.git
synced 2024-11-01 22:31:29 +08:00
refactor: refine pkg net utils (#2720)
* refactor: refine pkg net utils * fix: x Co-authored-by: blizard863 <760076784@qq.com>
This commit is contained in:
parent
0fb6aeef58
commit
ea568e8a4f
@ -234,8 +234,11 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))
|
conn, err = frpNet.DialWithOptions(net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)),
|
||||||
conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte)
|
frpNet.WithProxyURL(ctl.clientCfg.HTTPProxy),
|
||||||
|
frpNet.WithProtocol(ctl.clientCfg.Protocol),
|
||||||
|
frpNet.WithTLSConfig(tlsConfig),
|
||||||
|
frpNet.WithDisableCustomTLSHeadByte(ctl.clientCfg.DisableCustomTLSFirstByte))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warn("start new connection to server error: %v", err)
|
xl.Warn("start new connection to server error: %v", err)
|
||||||
|
@ -790,7 +790,7 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort))
|
localConn, err := frpNet.DialWithOptions(net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
workConn.Close()
|
workConn.Close()
|
||||||
xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)
|
xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)
|
||||||
|
@ -228,8 +228,12 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))
|
conn, err = frpNet.DialWithOptions(net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)),
|
||||||
conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte)
|
frpNet.WithProxyURL(svr.cfg.HTTPProxy),
|
||||||
|
frpNet.WithProtocol(svr.cfg.Protocol),
|
||||||
|
frpNet.WithTLSConfig(tlsConfig),
|
||||||
|
frpNet.WithDisableCustomTLSHeadByte(svr.cfg.DisableCustomTLSFirstByte))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -16,15 +16,16 @@ package net
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
|
"golang.org/x/net/websocket"
|
||||||
|
|
||||||
gnet "github.com/fatedier/golib/net"
|
gnet "github.com/fatedier/golib/net"
|
||||||
kcp "github.com/fatedier/kcp-go"
|
kcp "github.com/fatedier/kcp-go"
|
||||||
@ -194,6 +195,15 @@ func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
|
|||||||
case "tcp":
|
case "tcp":
|
||||||
return net.Dial("tcp", addr)
|
return net.Dial("tcp", addr)
|
||||||
case "kcp":
|
case "kcp":
|
||||||
|
return DialKCPServer(addr)
|
||||||
|
case "websocket":
|
||||||
|
return DialWebsocketServer(addr)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialKCPServer(addr string) (c net.Conn, err error) {
|
||||||
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
|
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
|
||||||
if errRet != nil {
|
if errRet != nil {
|
||||||
err = errRet
|
err = errRet
|
||||||
@ -209,35 +219,37 @@ func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
|
|||||||
kcpConn.SetWriteBuffer(4194304)
|
kcpConn.SetWriteBuffer(4194304)
|
||||||
c = kcpConn
|
c = kcpConn
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
|
func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return gnet.DialTcpByProxy(proxyURL, addr)
|
return gnet.DialTcpByProxy(proxyURL, addr)
|
||||||
case "kcp":
|
|
||||||
// http proxy is not supported for kcp
|
|
||||||
return ConnectServer(protocol, addr)
|
|
||||||
case "websocket":
|
|
||||||
return ConnectWebsocketServer(addr)
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
|
return nil, fmt.Errorf("unsupport protocol: %s when connecting by proxy", protocol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) {
|
// addr: domain:port
|
||||||
c, err = ConnectServerByProxy(proxyURL, protocol, addr)
|
func DialWebsocketServer(addr string) (net.Conn, error) {
|
||||||
|
addr = "ws://" + addr + FrpWebsocketPath
|
||||||
|
uri, err := url.Parse(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if tlsConfig == nil {
|
origin := "http://" + uri.Host
|
||||||
return
|
cfg, err := websocket.NewConfig(addr, origin)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg.Dialer = &net.Dialer{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte)
|
conn, err := websocket.DialConfig(cfg)
|
||||||
return
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
89
pkg/util/net/dial.go
Normal file
89
pkg/util/net/dial.go
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dialOptions struct {
|
||||||
|
proxyURL string
|
||||||
|
protocol string
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
disableCustomTLSHeadByte bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type DialOption interface {
|
||||||
|
apply(*dialOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmptyDialOption struct{}
|
||||||
|
|
||||||
|
func (EmptyDialOption) apply(*dialOptions) {}
|
||||||
|
|
||||||
|
type funcDialOption struct {
|
||||||
|
f func(*dialOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fdo *funcDialOption) apply(do *dialOptions) {
|
||||||
|
fdo.f(do)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
|
||||||
|
return &funcDialOption{
|
||||||
|
f: f,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultDialOptions() dialOptions {
|
||||||
|
return dialOptions{
|
||||||
|
protocol: "tcp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithProxyURL(proxyURL string) DialOption {
|
||||||
|
return newFuncDialOption(func(do *dialOptions) {
|
||||||
|
do.proxyURL = proxyURL
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTLSConfig(tlsConfig *tls.Config) DialOption {
|
||||||
|
return newFuncDialOption(func(do *dialOptions) {
|
||||||
|
do.tlsConfig = tlsConfig
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption {
|
||||||
|
return newFuncDialOption(func(do *dialOptions) {
|
||||||
|
do.disableCustomTLSHeadByte = disableCustomTLSHeadByte
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithProtocol(protocol string) DialOption {
|
||||||
|
return newFuncDialOption(func(do *dialOptions) {
|
||||||
|
do.protocol = protocol
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialWithOptions(addr string, opts ...DialOption) (c net.Conn, err error) {
|
||||||
|
op := DefaultDialOptions()
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt.apply(&op)
|
||||||
|
}
|
||||||
|
|
||||||
|
if op.proxyURL == "" {
|
||||||
|
c, err = ConnectServer(op.protocol, addr)
|
||||||
|
} else {
|
||||||
|
c, err = ConnectServerByProxy(op.proxyURL, op.protocol, addr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if op.tlsConfig == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte)
|
||||||
|
return
|
||||||
|
}
|
@ -5,8 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
@ -77,27 +75,3 @@ func (p *WebsocketListener) Close() error {
|
|||||||
func (p *WebsocketListener) Addr() net.Addr {
|
func (p *WebsocketListener) Addr() net.Addr {
|
||||||
return p.ln.Addr()
|
return p.ln.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// addr: domain:port
|
|
||||||
func ConnectWebsocketServer(addr string) (net.Conn, error) {
|
|
||||||
addr = "ws://" + addr + FrpWebsocketPath
|
|
||||||
uri, err := url.Parse(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
origin := "http://" + uri.Host
|
|
||||||
cfg, err := websocket.NewConfig(addr, origin)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cfg.Dialer = &net.Dialer{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := websocket.DialConfig(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user