optimize some code (#3801)

This commit is contained in:
fatedier 2023-11-27 15:47:49 +08:00 committed by GitHub
parent d5b41f1e14
commit 69ae2b0b69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
52 changed files with 880 additions and 600 deletions

View File

@ -1,6 +1,9 @@
### Features ### Features
* New command line parameter `--strict_config` is added to enable strict configuration validation mode. It will throw an error for non-existent fields instead of ignoring them. * New command line parameter `--strict_config` is added to enable strict configuration validation mode. It will throw an error for non-existent fields instead of ignoring them. In future versions, we may set the default value of this parameter to true.
* Support `SSH reverse tunneling`. With this feature, you can expose your local service without running frpc, only using SSH. The SSH reverse tunnel agent has many functional limitations compared to the frpc agent. The currently supported proxy types are tcp, http, https, tcpmux, and stcp.
* The frpc tcpmux command line parameters have been updated to support configuring `http_user` and `http_pwd`.
* The frpc stcp/sudp/xtcp command line parameters have been updated to support configuring `allow_users`.
### Fixes ### Fixes

View File

@ -1,85 +0,0 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"net"
"net/http"
"net/http/pprof"
"time"
"github.com/gorilla/mux"
"github.com/fatedier/frp/assets"
utilnet "github.com/fatedier/frp/pkg/util/net"
)
var (
httpServerReadTimeout = 60 * time.Second
httpServerWriteTimeout = 60 * time.Second
)
func (svr *Service) RunAdminServer(address string) (err error) {
// url router
router := mux.NewRouter()
router.HandleFunc("/healthz", svr.healthz)
// debug
if svr.cfg.WebServer.PprofEnable {
router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
router.HandleFunc("/debug/pprof/profile", pprof.Profile)
router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
router.HandleFunc("/debug/pprof/trace", pprof.Trace)
router.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}
subRouter := router.NewRoute().Subrouter()
user, passwd := svr.cfg.WebServer.User, svr.cfg.WebServer.Password
subRouter.Use(utilnet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware)
// api, see admin_api.go
subRouter.HandleFunc("/api/reload", svr.apiReload).Methods("GET")
subRouter.HandleFunc("/api/stop", svr.apiStop).Methods("POST")
subRouter.HandleFunc("/api/status", svr.apiStatus).Methods("GET")
subRouter.HandleFunc("/api/config", svr.apiGetConfig).Methods("GET")
subRouter.HandleFunc("/api/config", svr.apiPutConfig).Methods("PUT")
// view
subRouter.Handle("/favicon.ico", http.FileServer(assets.FileSystem)).Methods("GET")
subRouter.PathPrefix("/static/").Handler(utilnet.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)))).Methods("GET")
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})
server := &http.Server{
Addr: address,
Handler: router,
ReadTimeout: httpServerReadTimeout,
WriteTimeout: httpServerWriteTimeout,
}
if address == "" {
address = ":http"
}
ln, err := net.Listen("tcp", address)
if err != nil {
return err
}
go func() {
_ = server.Serve(ln)
}()
return
}

View File

@ -31,7 +31,9 @@ import (
"github.com/fatedier/frp/client/proxy" "github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/config/v1/validation" "github.com/fatedier/frp/pkg/config/v1/validation"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
) )
type GeneralResponse struct { type GeneralResponse struct {
@ -39,6 +41,29 @@ type GeneralResponse struct {
Msg string Msg string
} }
func (svr *Service) registerRouteHandlers(helper *httppkg.RouterRegisterHelper) {
helper.Router.HandleFunc("/healthz", svr.healthz)
subRouter := helper.Router.NewRoute().Subrouter()
subRouter.Use(helper.AuthMiddleware.Middleware)
// api, see admin_api.go
subRouter.HandleFunc("/api/reload", svr.apiReload).Methods("GET")
subRouter.HandleFunc("/api/stop", svr.apiStop).Methods("POST")
subRouter.HandleFunc("/api/status", svr.apiStatus).Methods("GET")
subRouter.HandleFunc("/api/config", svr.apiGetConfig).Methods("GET")
subRouter.HandleFunc("/api/config", svr.apiPutConfig).Methods("PUT")
// view
subRouter.Handle("/favicon.ico", http.FileServer(helper.AssetsFS)).Methods("GET")
subRouter.PathPrefix("/static/").Handler(
netpkg.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(helper.AssetsFS))),
).Methods("GET")
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})
}
// /healthz // /healthz
func (svr *Service) healthz(w http.ResponseWriter, _ *http.Request) { func (svr *Service) healthz(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200) w.WriteHeader(200)
@ -62,21 +87,21 @@ func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request) {
} }
}() }()
cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.cfgFile, strictConfigMode) cliCfg, proxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.configFilePath, strictConfigMode)
if err != nil { if err != nil {
res.Code = 400 res.Code = 400
res.Msg = err.Error() res.Msg = err.Error()
log.Warn("reload frpc proxy config error: %s", res.Msg) log.Warn("reload frpc proxy config error: %s", res.Msg)
return return
} }
if _, err := validation.ValidateAllClientConfig(cliCfg, pxyCfgs, visitorCfgs); err != nil { if _, err := validation.ValidateAllClientConfig(cliCfg, proxyCfgs, visitorCfgs); err != nil {
res.Code = 400 res.Code = 400
res.Msg = err.Error() res.Msg = err.Error()
log.Warn("reload frpc proxy config error: %s", res.Msg) log.Warn("reload frpc proxy config error: %s", res.Msg)
return return
} }
if err := svr.ReloadConf(pxyCfgs, visitorCfgs); err != nil { if err := svr.UpdateAllConfigurer(proxyCfgs, visitorCfgs); err != nil {
res.Code = 500 res.Code = 500
res.Msg = err.Error() res.Msg = err.Error()
log.Warn("reload frpc proxy config error: %s", res.Msg) log.Warn("reload frpc proxy config error: %s", res.Msg)
@ -158,7 +183,7 @@ func (svr *Service) apiStatus(w http.ResponseWriter, _ *http.Request) {
ps := ctl.pm.GetAllProxyStatus() ps := ctl.pm.GetAllProxyStatus()
for _, status := range ps { for _, status := range ps {
res[status.Type] = append(res[status.Type], NewProxyStatusResp(status, svr.cfg.ServerAddr)) res[status.Type] = append(res[status.Type], NewProxyStatusResp(status, svr.common.ServerAddr))
} }
for _, arrs := range res { for _, arrs := range res {
@ -184,14 +209,14 @@ func (svr *Service) apiGetConfig(w http.ResponseWriter, _ *http.Request) {
} }
}() }()
if svr.cfgFile == "" { if svr.configFilePath == "" {
res.Code = 400 res.Code = 400
res.Msg = "frpc has no config file path" res.Msg = "frpc has no config file path"
log.Warn("%s", res.Msg) log.Warn("%s", res.Msg)
return return
} }
content, err := os.ReadFile(svr.cfgFile) content, err := os.ReadFile(svr.configFilePath)
if err != nil { if err != nil {
res.Code = 400 res.Code = 400
res.Msg = err.Error() res.Msg = err.Error()
@ -230,7 +255,7 @@ func (svr *Service) apiPutConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := os.WriteFile(svr.cfgFile, body, 0o644); err != nil { if err := os.WriteFile(svr.configFilePath, body, 0o644); err != nil {
res.Code = 500 res.Code = 500
res.Msg = fmt.Sprintf("write content to frpc config file error: %v", err) res.Msg = fmt.Sprintf("write content to frpc config file error: %v", err)
log.Warn("%s", res.Msg) log.Warn("%s", res.Msg)

View File

@ -21,6 +21,7 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
libdial "github.com/fatedier/golib/net/dial" libdial "github.com/fatedier/golib/net/dial"
@ -30,7 +31,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -48,6 +49,7 @@ type defaultConnectorImpl struct {
muxSession *fmux.Session muxSession *fmux.Session
quicConn quic.Connection quicConn quic.Connection
closeOnce sync.Once
} }
func NewConnector(ctx context.Context, cfg *v1.ClientCommonConfig) Connector { func NewConnector(ctx context.Context, cfg *v1.ClientCommonConfig) Connector {
@ -130,7 +132,7 @@ func (c *defaultConnectorImpl) Connect() (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return utilnet.QuicStreamToNetConn(stream, c.quicConn), nil return netpkg.QuicStreamToNetConn(stream, c.quicConn), nil
} else if c.muxSession != nil { } else if c.muxSession != nil {
stream, err := c.muxSession.OpenStream() stream, err := c.muxSession.OpenStream()
if err != nil { if err != nil {
@ -177,19 +179,19 @@ func (c *defaultConnectorImpl) realConnect() (net.Conn, error) {
switch protocol { switch protocol {
case "websocket": case "websocket":
protocol = "tcp" protocol = "tcp"
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")})) dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: netpkg.DialHookWebsocket(protocol, "")}))
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)), Hook: netpkg.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
})) }))
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
case "wss": case "wss":
protocol = "tcp" protocol = "tcp"
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig)) dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
// Make sure that if it is wss, the websocket hook is executed after the tls hook. // Make sure that if it is wss, the websocket hook is executed after the tls hook.
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110})) dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: netpkg.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
default: default:
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)), Hook: netpkg.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
})) }))
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
} }
@ -213,11 +215,13 @@ func (c *defaultConnectorImpl) realConnect() (net.Conn, error) {
} }
func (c *defaultConnectorImpl) Close() error { func (c *defaultConnectorImpl) Close() error {
if c.quicConn != nil { c.closeOnce.Do(func() {
_ = c.quicConn.CloseWithError(0, "") if c.quicConn != nil {
} _ = c.quicConn.CloseWithError(0, "")
if c.muxSession != nil { }
_ = c.muxSession.Close() if c.muxSession != nil {
} _ = c.muxSession.Close()
}
})
return nil return nil
} }

View File

@ -28,39 +28,42 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
type SessionContext struct {
// The client common configuration.
Common *v1.ClientCommonConfig
// Unique ID obtained from frps.
// It should be attached to the login message when reconnecting.
RunID string
// Underlying control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
Conn net.Conn
// Indicates whether the connection is encrypted.
ConnEncrypted bool
// Sets authentication based on selected method
AuthSetter auth.Setter
// Connector is used to create new connections, which could be real TCP connections or virtual streams.
Connector Connector
}
type Control struct { type Control struct {
// service context // service context
ctx context.Context ctx context.Context
xl *xlog.Logger xl *xlog.Logger
// The client configuration // session context
clientCfg *v1.ClientCommonConfig sessionCtx *SessionContext
// sets authentication based on selected method
authSetter auth.Setter
// Unique ID obtained from frps.
// It should be attached to the login message when reconnecting.
runID string
// manage all proxies // manage all proxies
pxyCfgs []v1.ProxyConfigurer pm *proxy.Manager
pm *proxy.Manager
// manage all visitors // manage all visitors
vm *visitor.Manager vm *visitor.Manager
// control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
conn net.Conn
// use connector to create new connections, which could be real TCP connections or virtual streams.
connector Connector
doneCh chan struct{} doneCh chan struct{}
// of time.Time, last time got the Pong message // of time.Time, last time got the Pong message
@ -76,50 +79,41 @@ type Control struct {
msgDispatcher *msg.Dispatcher msgDispatcher *msg.Dispatcher
} }
func NewControl( func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
ctx context.Context, runID string, conn net.Conn, connector Connector,
clientCfg *v1.ClientCommonConfig,
pxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer,
authSetter auth.Setter,
) (*Control, error) {
// new xlog instance // new xlog instance
ctl := &Control{ ctl := &Control{
ctx: ctx, ctx: ctx,
xl: xlog.FromContextSafe(ctx), xl: xlog.FromContextSafe(ctx),
clientCfg: clientCfg, sessionCtx: sessionCtx,
authSetter: authSetter,
runID: runID,
pxyCfgs: pxyCfgs,
conn: conn,
connector: connector,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
} }
ctl.lastPong.Store(time.Now()) ctl.lastPong.Store(time.Now())
cryptoRW, err := utilnet.NewCryptoReadWriter(conn, []byte(clientCfg.Auth.Token)) if sessionCtx.ConnEncrypted {
if err != nil { cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, []byte(sessionCtx.Common.Auth.Token))
return nil, err if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
} else {
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
} }
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
ctl.registerMsgHandlers() ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel()) ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter) ctl.pm = proxy.NewManager(ctl.ctx, sessionCtx.Common, ctl.msgTransporter)
ctl.vm = visitor.NewManager(ctl.ctx, ctl.runID, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter) ctl.vm = visitor.NewManager(ctl.ctx, sessionCtx.RunID, sessionCtx.Common, ctl.connectServer, ctl.msgTransporter)
ctl.vm.Reload(visitorCfgs)
return ctl, nil return ctl, nil
} }
func (ctl *Control) Run() { func (ctl *Control) Run(proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) {
go ctl.worker() go ctl.worker()
// start all proxies // start all proxies
ctl.pm.Reload(ctl.pxyCfgs) ctl.pm.UpdateAll(proxyCfgs)
// start all visitors // start all visitors
go ctl.vm.Run() ctl.vm.UpdateAll(visitorCfgs)
} }
func (ctl *Control) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { func (ctl *Control) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
@ -135,9 +129,9 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
} }
m := &msg.NewWorkConn{ m := &msg.NewWorkConn{
RunID: ctl.runID, RunID: ctl.sessionCtx.RunID,
} }
if err = ctl.authSetter.SetNewWorkConn(m); err != nil { if err = ctl.sessionCtx.AuthSetter.SetNewWorkConn(m); err != nil {
xl.Warn("error during NewWorkConn authentication: %v", err) xl.Warn("error during NewWorkConn authentication: %v", err)
return return
} }
@ -193,13 +187,19 @@ func (ctl *Control) handlePong(m msg.Message) {
if inMsg.Error != "" { if inMsg.Error != "" {
xl.Error("Pong message contains error: %s", inMsg.Error) xl.Error("Pong message contains error: %s", inMsg.Error)
ctl.conn.Close() ctl.closeSession()
return return
} }
ctl.lastPong.Store(time.Now()) ctl.lastPong.Store(time.Now())
xl.Debug("receive heartbeat from server") xl.Debug("receive heartbeat from server")
} }
// closeSession closes the control connection.
func (ctl *Control) closeSession() {
ctl.sessionCtx.Conn.Close()
ctl.sessionCtx.Connector.Close()
}
func (ctl *Control) Close() error { func (ctl *Control) Close() error {
return ctl.GracefulClose(0) return ctl.GracefulClose(0)
} }
@ -210,8 +210,7 @@ func (ctl *Control) GracefulClose(d time.Duration) error {
time.Sleep(d) time.Sleep(d)
ctl.conn.Close() ctl.closeSession()
ctl.connector.Close()
return nil return nil
} }
@ -221,8 +220,8 @@ func (ctl *Control) Done() <-chan struct{} {
} }
// connectServer return a new connection to frps // connectServer return a new connection to frps
func (ctl *Control) connectServer() (conn net.Conn, err error) { func (ctl *Control) connectServer() (net.Conn, error) {
return ctl.connector.Connect() return ctl.sessionCtx.Connector.Connect()
} }
func (ctl *Control) registerMsgHandlers() { func (ctl *Control) registerMsgHandlers() {
@ -238,12 +237,12 @@ func (ctl *Control) heartbeatWorker() {
// TODO(fatedier): Change default value of HeartbeatInterval to -1 if tcpmux is enabled. // TODO(fatedier): Change default value of HeartbeatInterval to -1 if tcpmux is enabled.
// Users can still enable heartbeat feature by setting HeartbeatInterval to a positive value. // Users can still enable heartbeat feature by setting HeartbeatInterval to a positive value.
if ctl.clientCfg.Transport.HeartbeatInterval > 0 { if ctl.sessionCtx.Common.Transport.HeartbeatInterval > 0 {
// send heartbeat to server // send heartbeat to server
sendHeartBeat := func() error { sendHeartBeat := func() error {
xl.Debug("send heartbeat to server") xl.Debug("send heartbeat to server")
pingMsg := &msg.Ping{} pingMsg := &msg.Ping{}
if err := ctl.authSetter.SetPing(pingMsg); err != nil { if err := ctl.sessionCtx.AuthSetter.SetPing(pingMsg); err != nil {
xl.Warn("error during ping authentication: %v, skip sending ping message", err) xl.Warn("error during ping authentication: %v, skip sending ping message", err)
return err return err
} }
@ -253,24 +252,24 @@ func (ctl *Control) heartbeatWorker() {
go wait.BackoffUntil(sendHeartBeat, go wait.BackoffUntil(sendHeartBeat,
wait.NewFastBackoffManager(wait.FastBackoffOptions{ wait.NewFastBackoffManager(wait.FastBackoffOptions{
Duration: time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second, Duration: time.Duration(ctl.sessionCtx.Common.Transport.HeartbeatInterval) * time.Second,
InitDurationIfFail: time.Second, InitDurationIfFail: time.Second,
Factor: 2.0, Factor: 2.0,
Jitter: 0.1, Jitter: 0.1,
MaxDuration: time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second, MaxDuration: time.Duration(ctl.sessionCtx.Common.Transport.HeartbeatInterval) * time.Second,
}), }),
true, ctl.doneCh, true, ctl.doneCh,
) )
} }
// Check heartbeat timeout only if TCPMux is not enabled and users don't disable heartbeat feature. // Check heartbeat timeout only if TCPMux is not enabled and users don't disable heartbeat feature.
if ctl.clientCfg.Transport.HeartbeatInterval > 0 && ctl.clientCfg.Transport.HeartbeatTimeout > 0 && if ctl.sessionCtx.Common.Transport.HeartbeatInterval > 0 && ctl.sessionCtx.Common.Transport.HeartbeatTimeout > 0 &&
!lo.FromPtr(ctl.clientCfg.Transport.TCPMux) { !lo.FromPtr(ctl.sessionCtx.Common.Transport.TCPMux) {
go wait.Until(func() { go wait.Until(func() {
if time.Since(ctl.lastPong.Load().(time.Time)) > time.Duration(ctl.clientCfg.Transport.HeartbeatTimeout)*time.Second { if time.Since(ctl.lastPong.Load().(time.Time)) > time.Duration(ctl.sessionCtx.Common.Transport.HeartbeatTimeout)*time.Second {
xl.Warn("heartbeat timeout") xl.Warn("heartbeat timeout")
ctl.conn.Close() ctl.closeSession()
return return
} }
}, time.Second, ctl.doneCh) }, time.Second, ctl.doneCh)
@ -282,17 +281,15 @@ func (ctl *Control) worker() {
go ctl.msgDispatcher.Run() go ctl.msgDispatcher.Run()
<-ctl.msgDispatcher.Done() <-ctl.msgDispatcher.Done()
ctl.conn.Close() ctl.closeSession()
ctl.pm.Close() ctl.pm.Close()
ctl.vm.Close() ctl.vm.Close()
ctl.connector.Close()
close(ctl.doneCh) close(ctl.doneCh)
} }
func (ctl *Control) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error { func (ctl *Control) UpdateAllConfigurer(proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error {
ctl.vm.Reload(visitorCfgs) ctl.vm.UpdateAll(visitorCfgs)
ctl.pm.Reload(pxyCfgs) ctl.pm.UpdateAll(proxyCfgs)
return nil return nil
} }

View File

@ -120,9 +120,18 @@ func (pm *Manager) GetAllProxyStatus() []*WorkingStatus {
return ps return ps
} }
func (pm *Manager) Reload(pxyCfgs []v1.ProxyConfigurer) { func (pm *Manager) GetProxyStatus(name string) (*WorkingStatus, bool) {
pm.mu.RLock()
defer pm.mu.RUnlock()
if pxy, ok := pm.proxies[name]; ok {
return pxy.GetStatus(), true
}
return nil, false
}
func (pm *Manager) UpdateAll(proxyCfgs []v1.ProxyConfigurer) {
xl := xlog.FromContextSafe(pm.ctx) xl := xlog.FromContextSafe(pm.ctx)
pxyCfgsMap := lo.KeyBy(pxyCfgs, func(c v1.ProxyConfigurer) string { proxyCfgsMap := lo.KeyBy(proxyCfgs, func(c v1.ProxyConfigurer) string {
return c.GetBaseConfig().Name return c.GetBaseConfig().Name
}) })
pm.mu.Lock() pm.mu.Lock()
@ -131,7 +140,7 @@ func (pm *Manager) Reload(pxyCfgs []v1.ProxyConfigurer) {
delPxyNames := make([]string, 0) delPxyNames := make([]string, 0)
for name, pxy := range pm.proxies { for name, pxy := range pm.proxies {
del := false del := false
cfg, ok := pxyCfgsMap[name] cfg, ok := proxyCfgsMap[name]
if !ok || !reflect.DeepEqual(pxy.Cfg, cfg) { if !ok || !reflect.DeepEqual(pxy.Cfg, cfg) {
del = true del = true
} }
@ -147,7 +156,7 @@ func (pm *Manager) Reload(pxyCfgs []v1.ProxyConfigurer) {
} }
addPxyNames := make([]string, 0) addPxyNames := make([]string, 0)
for _, cfg := range pxyCfgs { for _, cfg := range proxyCfgs {
name := cfg.GetBaseConfig().Name name := cfg.GetBaseConfig().Name
if _, ok := pm.proxies[name]; !ok { if _, ok := pm.proxies[name]; !ok {
pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter) pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter)

View File

@ -31,7 +31,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -101,7 +101,7 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
if pxy.cfg.Transport.UseCompression { if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc) rwc = libio.WithCompression(rwc)
} }
conn = utilnet.WrapReadWriteCloserToConn(rwc, conn) conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
workConn := conn workConn := conn
readCh := make(chan *msg.UDPPacket, 1024) readCh := make(chan *msg.UDPPacket, 1024)

View File

@ -30,7 +30,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -112,7 +112,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
if pxy.cfg.Transport.UseCompression { if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc) rwc = libio.WithCompression(rwc)
} }
conn = utilnet.WrapReadWriteCloserToConn(rwc, conn) conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
pxy.mu.Lock() pxy.mu.Lock()
pxy.workConn = conn pxy.workConn = conn

View File

@ -29,7 +29,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/nathole" "github.com/fatedier/frp/pkg/nathole"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -133,7 +133,7 @@ func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, s
} }
defer lConn.Close() defer lConn.Close()
remote, err := utilnet.NewKCPConnFromUDP(lConn, true, raddr.String()) remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
if err != nil { if err != nil {
xl.Warn("create kcp connection from udp connection error: %v", err) xl.Warn("create kcp connection from udp connection error: %v", err)
return return
@ -194,6 +194,6 @@ func (pxy *XTCPProxy) listenByQUIC(listenConn *net.UDPConn, _ *net.UDPAddr, star
_ = c.CloseWithError(0, "") _ = c.CloseWithError(0, "")
return return
} }
go pxy.HandleTCPWorkConnection(utilnet.QuicStreamToNetConn(stream, c), startWorkConnMsg, []byte(pxy.cfg.Secretkey)) go pxy.HandleTCPWorkConnection(netpkg.QuicStreamToNetConn(stream, c), startWorkConnMsg, []byte(pxy.cfg.Secretkey))
} }
} }

View File

@ -20,18 +20,19 @@ import (
"fmt" "fmt"
"net" "net"
"runtime" "runtime"
"strconv"
"sync" "sync"
"time" "time"
"github.com/fatedier/golib/crypto" "github.com/fatedier/golib/crypto"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/fatedier/frp/assets" "github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/pkg/auth" "github.com/fatedier/frp/pkg/auth"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
@ -41,66 +42,106 @@ func init() {
crypto.DefaultSalt = "frp" crypto.DefaultSalt = "frp"
} }
// Service is a client service. // ServiceOptions contains options for creating a new client service.
type Service struct { type ServiceOptions struct {
// uniq id got from frps, attach it in loginMsg Common *v1.ClientCommonConfig
runID string ProxyCfgs []v1.ProxyConfigurer
VisitorCfgs []v1.VisitorConfigurer
// manager control connection with server // ConfigFilePath is the path to the configuration file used to initialize.
ctl *Control // If it is empty, it means that the configuration file is not used for initialization.
// It may be initialized using command line parameters or called directly.
ConfigFilePath string
// ClientSpec is the client specification that control the client behavior.
ClientSpec *msg.ClientSpec
// ConnectorCreator is a function that creates a new connector to make connections to the server.
// The Connector shields the underlying connection details, whether it is through TCP or QUIC connection,
// and regardless of whether multiplexing is used.
//
// If it is not set, the default frpc connector will be used.
// By using a custom Connector, it can be used to implement a VirtualClient, which connects to frps
// through a pipe instead of a real physical connection.
ConnectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
// HandleWorkConnCb is a callback function that is called when a new work connection is created.
//
// If it is not set, the default frpc implementation will be used.
HandleWorkConnCb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
}
// setServiceOptionsDefault sets the default values for ServiceOptions.
func setServiceOptionsDefault(options *ServiceOptions) {
if options.Common != nil {
options.Common.Complete()
}
if options.ConnectorCreator == nil {
options.ConnectorCreator = NewConnector
}
}
// Service is the client service that connects to frps and provides proxy services.
type Service struct {
ctlMu sync.RWMutex ctlMu sync.RWMutex
// manager control connection with server
ctl *Control
// Uniq id got from frps, it will be attached to loginMsg.
runID string
// Sets authentication based on selected method // Sets authentication based on selected method
authSetter auth.Setter authSetter auth.Setter
cfg *v1.ClientCommonConfig // web server for admin UI and apis
pxyCfgs []v1.ProxyConfigurer webServer *httppkg.Server
visitorCfgs []v1.VisitorConfigurer
cfgMu sync.RWMutex cfgMu sync.RWMutex
common *v1.ClientCommonConfig
proxyCfgs []v1.ProxyConfigurer
visitorCfgs []v1.VisitorConfigurer
clientSpec *msg.ClientSpec
// The configuration file used to initialize this client, or an empty // The configuration file used to initialize this client, or an empty
// string if no configuration file was used. // string if no configuration file was used.
cfgFile string configFilePath string
// service context // service context
ctx context.Context ctx context.Context
// call cancel to stop service // call cancel to stop service
cancel context.CancelFunc cancel context.CancelFunc
gracefulDuration time.Duration gracefulShutdownDuration time.Duration
connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool handleWorkConnCb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
} }
func NewService( func NewService(options ServiceOptions) (*Service, error) {
cfg *v1.ClientCommonConfig, setServiceOptionsDefault(&options)
pxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer, var webServer *httppkg.Server
cfgFile string, if options.Common.WebServer.Port > 0 {
) *Service { ws, err := httppkg.NewServer(options.Common.WebServer)
return &Service{ if err != nil {
authSetter: auth.NewAuthSetter(cfg.Auth), return nil, err
cfg: cfg, }
cfgFile: cfgFile, webServer = ws
pxyCfgs: pxyCfgs,
visitorCfgs: visitorCfgs,
ctx: context.Background(),
connectorCreator: NewConnector,
} }
} s := &Service{
ctx: context.Background(),
func (svr *Service) SetConnectorCreator(h func(context.Context, *v1.ClientCommonConfig) Connector) { authSetter: auth.NewAuthSetter(options.Common.Auth),
svr.connectorCreator = h webServer: webServer,
} common: options.Common,
configFilePath: options.ConfigFilePath,
func (svr *Service) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { proxyCfgs: options.ProxyCfgs,
svr.inWorkConnCallback = cb visitorCfgs: options.VisitorCfgs,
} clientSpec: options.ClientSpec,
connectorCreator: options.ConnectorCreator,
func (svr *Service) GetController() *Control { handleWorkConnCb: options.HandleWorkConnCb,
svr.ctlMu.RLock() }
defer svr.ctlMu.RUnlock() if webServer != nil {
return svr.ctl webServer.RouteRegister(s.registerRouteHandlers)
}
return s, nil
} }
func (svr *Service) Run(ctx context.Context) error { func (svr *Service) Run(ctx context.Context) error {
@ -109,38 +150,25 @@ func (svr *Service) Run(ctx context.Context) error {
svr.cancel = cancel svr.cancel = cancel
// set custom DNSServer // set custom DNSServer
if svr.cfg.DNSServer != "" { if svr.common.DNSServer != "" {
dnsAddr := svr.cfg.DNSServer netpkg.SetDefaultDNSAddress(svr.common.DNSServer)
if _, _, err := net.SplitHostPort(dnsAddr); err != nil {
dnsAddr = net.JoinHostPort(dnsAddr, "53")
}
// Change default dns server for frpc
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial("udp", dnsAddr)
},
}
} }
// login to frps // first login to frps
svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.cfg.LoginFailExit)) svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.common.LoginFailExit))
if svr.ctl == nil { if svr.ctl == nil {
return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled") return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled")
} }
go svr.keepControllerWorking() go svr.keepControllerWorking()
if svr.cfg.WebServer.Port != 0 { if svr.webServer != nil {
// Init admin server assets go func() {
assets.Load(svr.cfg.WebServer.AssetsDir) log.Info("admin server listen on %s", svr.webServer.Address())
if err := svr.webServer.Run(); err != nil {
address := net.JoinHostPort(svr.cfg.WebServer.Addr, strconv.Itoa(svr.cfg.WebServer.Port)) log.Warn("admin server exit with error: %v", err)
err := svr.RunAdminServer(address) }
if err != nil { }()
log.Warn("run admin server error: %v", err)
}
log.Info("admin server listen on %s:%d", svr.cfg.WebServer.Addr, svr.cfg.WebServer.Port)
} }
<-svr.ctx.Done() <-svr.ctx.Done()
svr.stop() svr.stop()
@ -158,8 +186,12 @@ func (svr *Service) keepControllerWorking() {
// loopLoginUntilSuccess is another layer of loop that will continuously attempt to // loopLoginUntilSuccess is another layer of loop that will continuously attempt to
// login to the server until successful. // login to the server until successful.
svr.loopLoginUntilSuccess(20*time.Second, false) svr.loopLoginUntilSuccess(20*time.Second, false)
<-svr.ctl.Done() if svr.ctl != nil {
return errors.New("control is closed and try another loop") <-svr.ctl.Done()
return errors.New("control is closed and try another loop")
}
// If the control is nil, it means that the login failed and the service is also closed.
return nil
}, wait.NewFastBackoffManager( }, wait.NewFastBackoffManager(
wait.FastBackoffOptions{ wait.FastBackoffOptions{
Duration: time.Second, Duration: time.Second,
@ -179,7 +211,7 @@ func (svr *Service) keepControllerWorking() {
// session: if it's not nil, using tcp mux // session: if it's not nil, using tcp mux
func (svr *Service) login() (conn net.Conn, connector Connector, err error) { func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
xl := xlog.FromContextSafe(svr.ctx) xl := xlog.FromContextSafe(svr.ctx)
connector = svr.connectorCreator(svr.ctx, svr.cfg) connector = svr.connectorCreator(svr.ctx, svr.common)
if err = connector.Open(); err != nil { if err = connector.Open(); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -198,12 +230,15 @@ func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
loginMsg := &msg.Login{ loginMsg := &msg.Login{
Arch: runtime.GOARCH, Arch: runtime.GOARCH,
Os: runtime.GOOS, Os: runtime.GOOS,
PoolCount: svr.cfg.Transport.PoolCount, PoolCount: svr.common.Transport.PoolCount,
User: svr.cfg.User, User: svr.common.User,
Version: version.Full(), Version: version.Full(),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
RunID: svr.runID, RunID: svr.runID,
Metas: svr.cfg.Metadatas, Metas: svr.common.Metadatas,
}
if svr.clientSpec != nil {
loginMsg.ClientSpec = *svr.clientSpec
} }
// Add auth // Add auth
@ -250,16 +285,31 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
return err return err
} }
ctl, err := NewControl(svr.ctx, svr.runID, conn, connector, svr.cfgMu.RLock()
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) proxyCfgs := svr.proxyCfgs
visitorCfgs := svr.visitorCfgs
svr.cfgMu.RUnlock()
connEncrypted := true
if svr.clientSpec != nil && svr.clientSpec.Type == "ssh-tunnel" {
connEncrypted = false
}
sessionCtx := &SessionContext{
Common: svr.common,
RunID: svr.runID,
Conn: conn,
ConnEncrypted: connEncrypted,
AuthSetter: svr.authSetter,
Connector: connector,
}
ctl, err := NewControl(svr.ctx, sessionCtx)
if err != nil { if err != nil {
conn.Close() conn.Close()
xl.Error("NewControl error: %v", err) xl.Error("NewControl error: %v", err)
return err return err
} }
ctl.SetInWorkConnCallback(svr.inWorkConnCallback) ctl.SetInWorkConnCallback(svr.handleWorkConnCb)
ctl.Run() ctl.Run(proxyCfgs, visitorCfgs)
// close and replace previous control // close and replace previous control
svr.ctlMu.Lock() svr.ctlMu.Lock()
if svr.ctl != nil { if svr.ctl != nil {
@ -284,9 +334,9 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
wait.MergeAndCloseOnAnyStopChannel(svr.ctx.Done(), successCh)) wait.MergeAndCloseOnAnyStopChannel(svr.ctx.Done(), successCh))
} }
func (svr *Service) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error { func (svr *Service) UpdateAllConfigurer(proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error {
svr.cfgMu.Lock() svr.cfgMu.Lock()
svr.pxyCfgs = pxyCfgs svr.proxyCfgs = proxyCfgs
svr.visitorCfgs = visitorCfgs svr.visitorCfgs = visitorCfgs
svr.cfgMu.Unlock() svr.cfgMu.Unlock()
@ -295,7 +345,7 @@ func (svr *Service) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.Vi
svr.ctlMu.RUnlock() svr.ctlMu.RUnlock()
if ctl != nil { if ctl != nil {
return svr.ctl.ReloadConf(pxyCfgs, visitorCfgs) return svr.ctl.UpdateAllConfigurer(proxyCfgs, visitorCfgs)
} }
return nil return nil
} }
@ -305,7 +355,7 @@ func (svr *Service) Close() {
} }
func (svr *Service) GracefulClose(d time.Duration) { func (svr *Service) GracefulClose(d time.Duration) {
svr.gracefulDuration = d svr.gracefulShutdownDuration = d
svr.cancel() svr.cancel()
} }
@ -313,7 +363,23 @@ func (svr *Service) stop() {
svr.ctlMu.Lock() svr.ctlMu.Lock()
defer svr.ctlMu.Unlock() defer svr.ctlMu.Unlock()
if svr.ctl != nil { if svr.ctl != nil {
svr.ctl.GracefulClose(svr.gracefulDuration) svr.ctl.GracefulClose(svr.gracefulShutdownDuration)
svr.ctl = nil svr.ctl = nil
} }
} }
// TODO(fatedier): Use StatusExporter to provide query interfaces instead of directly using methods from the Service.
func (svr *Service) GetProxyStatus(name string) (*proxy.WorkingStatus, error) {
svr.ctlMu.RLock()
ctl := svr.ctl
svr.ctlMu.RUnlock()
if ctl == nil {
return nil, fmt.Errorf("control is not running")
}
ws, ok := ctl.pm.GetProxyStatus(name)
if !ok {
return nil, fmt.Errorf("proxy [%s] is not found", name)
}
return ws, nil
}

View File

@ -28,7 +28,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -242,7 +242,7 @@ func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
if sv.cfg.Transport.UseCompression { if sv.cfg.Transport.UseCompression {
remote = libio.WithCompression(remote) remote = libio.WithCompression(remote)
} }
return utilnet.WrapReadWriteCloserToConn(remote, visitorConn), nil return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
} }
func (sv *SUDPVisitor) Close() { func (sv *SUDPVisitor) Close() {

View File

@ -21,7 +21,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -56,7 +56,7 @@ func NewVisitor(
clientCfg: clientCfg, clientCfg: clientCfg,
helper: helper, helper: helper,
ctx: xlog.NewContext(ctx, xl), ctx: xlog.NewContext(ctx, xl),
internalLn: utilnet.NewInternalListener(), internalLn: netpkg.NewInternalListener(),
} }
switch cfg := cfg.(type) { switch cfg := cfg.(type) {
case *v1.STCPVisitorConfig: case *v1.STCPVisitorConfig:
@ -84,7 +84,7 @@ type BaseVisitor struct {
clientCfg *v1.ClientCommonConfig clientCfg *v1.ClientCommonConfig
helper Helper helper Helper
l net.Listener l net.Listener
internalLn *utilnet.InternalListener internalLn *netpkg.InternalListener
mu sync.RWMutex mu sync.RWMutex
ctx context.Context ctx context.Context

View File

@ -35,7 +35,8 @@ type Manager struct {
visitors map[string]Visitor visitors map[string]Visitor
helper Helper helper Helper
checkInterval time.Duration checkInterval time.Duration
keepVisitorsRunningOnce sync.Once
mu sync.RWMutex mu sync.RWMutex
ctx context.Context ctx context.Context
@ -67,7 +68,9 @@ func NewManager(
return m return m
} }
func (vm *Manager) Run() { // keepVisitorsRunning checks all visitors' status periodically, if some visitor is not running, start it.
// It will only start after Reload is called and a new visitor is added.
func (vm *Manager) keepVisitorsRunning() {
xl := xlog.FromContextSafe(vm.ctx) xl := xlog.FromContextSafe(vm.ctx)
ticker := time.NewTicker(vm.checkInterval) ticker := time.NewTicker(vm.checkInterval)
@ -76,7 +79,7 @@ func (vm *Manager) Run() {
for { for {
select { select {
case <-vm.stopCh: case <-vm.stopCh:
xl.Info("gracefully shutdown visitor manager") xl.Trace("gracefully shutdown visitor manager")
return return
case <-ticker.C: case <-ticker.C:
vm.mu.Lock() vm.mu.Lock()
@ -120,7 +123,14 @@ func (vm *Manager) startVisitor(cfg v1.VisitorConfigurer) (err error) {
return return
} }
func (vm *Manager) Reload(cfgs []v1.VisitorConfigurer) { func (vm *Manager) UpdateAll(cfgs []v1.VisitorConfigurer) {
if len(cfgs) > 0 {
// Only start keepVisitorsRunning goroutine once and only when there is at least one visitor.
vm.keepVisitorsRunningOnce.Do(func() {
go vm.keepVisitorsRunning()
})
}
xl := xlog.FromContextSafe(vm.ctx) xl := xlog.FromContextSafe(vm.ctx)
cfgsMap := lo.KeyBy(cfgs, func(c v1.VisitorConfigurer) string { cfgsMap := lo.KeyBy(cfgs, func(c v1.VisitorConfigurer) string {
return c.GetBaseConfig().Name return c.GetBaseConfig().Name

View File

@ -33,7 +33,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/nathole" "github.com/fatedier/frp/pkg/nathole"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -349,7 +349,7 @@ func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) er
if err != nil { if err != nil {
return fmt.Errorf("dial udp error: %v", err) return fmt.Errorf("dial udp error: %v", err)
} }
remote, err := utilnet.NewKCPConnFromUDP(lConn, true, raddr.String()) remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
if err != nil { if err != nil {
return fmt.Errorf("create kcp connection from udp connection error: %v", err) return fmt.Errorf("create kcp connection from udp connection error: %v", err)
} }
@ -440,7 +440,7 @@ func (qs *QUICTunnelSession) OpenConn(ctx context.Context) (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return utilnet.QuicStreamToNetConn(stream, session), nil return netpkg.QuicStreamToNetConn(stream, session), nil
} }
func (qs *QUICTunnelSession) Close() { func (qs *QUICTunnelSession) Close() {

View File

@ -110,7 +110,7 @@ func handleTermSignal(svr *client.Service) {
} }
func runClient(cfgFilePath string) error { func runClient(cfgFilePath string) error {
cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath, strictConfigMode) cfg, proxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath, strictConfigMode)
if err != nil { if err != nil {
return err return err
} }
@ -119,19 +119,19 @@ func runClient(cfgFilePath string) error {
"please use yaml/json/toml format instead!\n") "please use yaml/json/toml format instead!\n")
} }
warning, err := validation.ValidateAllClientConfig(cfg, pxyCfgs, visitorCfgs) warning, err := validation.ValidateAllClientConfig(cfg, proxyCfgs, visitorCfgs)
if warning != nil { if warning != nil {
fmt.Printf("WARNING: %v\n", warning) fmt.Printf("WARNING: %v\n", warning)
} }
if err != nil { if err != nil {
return err return err
} }
return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath) return startService(cfg, proxyCfgs, visitorCfgs, cfgFilePath)
} }
func startService( func startService(
cfg *v1.ClientCommonConfig, cfg *v1.ClientCommonConfig,
pxyCfgs []v1.ProxyConfigurer, proxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer, visitorCfgs []v1.VisitorConfigurer,
cfgFile string, cfgFile string,
) error { ) error {
@ -141,7 +141,15 @@ func startService(
log.Info("start frpc service for config file [%s]", cfgFile) log.Info("start frpc service for config file [%s]", cfgFile)
defer log.Info("frpc service for config file [%s] stopped", cfgFile) defer log.Info("frpc service for config file [%s] stopped", cfgFile)
} }
svr := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile) svr, err := client.NewService(client.ServiceOptions{
Common: cfg,
ProxyCfgs: proxyCfgs,
VisitorCfgs: visitorCfgs,
ConfigFilePath: cfgFile,
})
if err != nil {
return err
}
shouldGracefulClose := cfg.Transport.Protocol == "kcp" || cfg.Transport.Protocol == "quic" shouldGracefulClose := cfg.Transport.Protocol == "kcp" || cfg.Transport.Protocol == "quic"
// Capture the exit signal if we use kcp or quic. // Capture the exit signal if we use kcp or quic.

View File

@ -37,12 +37,12 @@ var verifyCmd = &cobra.Command{
return nil return nil
} }
cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile, strictConfigMode) cliCfg, proxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile, strictConfigMode)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
warning, err := validation.ValidateAllClientConfig(cliCfg, pxyCfgs, visitorCfgs) warning, err := validation.ValidateAllClientConfig(cliCfg, proxyCfgs, visitorCfgs)
if warning != nil { if warning != nil {
fmt.Printf("WARNING: %v\n", warning) fmt.Printf("WARNING: %v\n", warning)
} }

31
pkg/auth/pass.go Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"github.com/fatedier/frp/pkg/msg"
)
var AlwaysPassVerifier = &alwaysPass{}
var _ Verifier = &alwaysPass{}
type alwaysPass struct{}
func (*alwaysPass) VerifyLogin(*msg.Login) error { return nil }
func (*alwaysPass) VerifyPing(*msg.Ping) error { return nil }
func (*alwaysPass) VerifyNewWorkConn(*msg.NewWorkConn) error { return nil }

View File

@ -59,12 +59,17 @@ func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer) {
case *v1.TCPMuxProxyConfig: case *v1.TCPMuxProxyConfig:
registerProxyDomainConfigFlags(cmd, &cc.DomainConfig) registerProxyDomainConfigFlags(cmd, &cc.DomainConfig)
cmd.Flags().StringVarP(&cc.Multiplexer, "mux", "", "", "multiplexer") cmd.Flags().StringVarP(&cc.Multiplexer, "mux", "", "", "multiplexer")
cmd.Flags().StringVarP(&cc.HTTPUser, "http_user", "", "", "http auth user")
cmd.Flags().StringVarP(&cc.HTTPPassword, "http_pwd", "", "", "http auth password")
case *v1.STCPProxyConfig: case *v1.STCPProxyConfig:
cmd.Flags().StringVarP(&cc.Secretkey, "sk", "", "", "secret key") cmd.Flags().StringVarP(&cc.Secretkey, "sk", "", "", "secret key")
cmd.Flags().StringSliceVarP(&cc.AllowUsers, "allow_users", "", []string{}, "allow visitor users")
case *v1.SUDPProxyConfig: case *v1.SUDPProxyConfig:
cmd.Flags().StringVarP(&cc.Secretkey, "sk", "", "", "secret key") cmd.Flags().StringVarP(&cc.Secretkey, "sk", "", "", "secret key")
cmd.Flags().StringSliceVarP(&cc.AllowUsers, "allow_users", "", []string{}, "allow visitor users")
case *v1.XTCPProxyConfig: case *v1.XTCPProxyConfig:
cmd.Flags().StringVarP(&cc.Secretkey, "sk", "", "", "secret key") cmd.Flags().StringVarP(&cc.Secretkey, "sk", "", "", "secret key")
cmd.Flags().StringSliceVarP(&cc.AllowUsers, "allow_users", "", []string{}, "allow visitor users")
} }
} }

View File

@ -23,7 +23,7 @@ import (
func ParseClientConfig(filePath string) ( func ParseClientConfig(filePath string) (
cfg ClientCommonConf, cfg ClientCommonConf,
pxyCfgs map[string]ProxyConf, proxyCfgs map[string]ProxyConf,
visitorCfgs map[string]VisitorConf, visitorCfgs map[string]VisitorConf,
err error, err error,
) { ) {
@ -56,7 +56,7 @@ func ParseClientConfig(filePath string) (
configBuffer.Write(buf) configBuffer.Write(buf)
// Parse all proxy and visitor configs. // Parse all proxy and visitor configs.
pxyCfgs, visitorCfgs, err = LoadAllProxyConfsFromIni(cfg.User, configBuffer.Bytes(), cfg.Start) proxyCfgs, visitorCfgs, err = LoadAllProxyConfsFromIni(cfg.User, configBuffer.Bytes(), cfg.Start)
if err != nil { if err != nil {
return return
} }

View File

@ -110,6 +110,7 @@ func LoadConfigureFromFile(path string, c any, strict bool) error {
// LoadConfigure loads configuration from bytes and unmarshal into c. // LoadConfigure loads configuration from bytes and unmarshal into c.
// Now it supports json, yaml and toml format. // Now it supports json, yaml and toml format.
// TODO(fatedier): strict is not valide for ProxyConfigurer/VisitorConfigurer/ClientPluginOptions.
func LoadConfigure(b []byte, c any, strict bool) error { func LoadConfigure(b []byte, c any, strict bool) error {
var tomlObj interface{} var tomlObj interface{}
// Try to unmarshal as TOML first; swallow errors from that (assume it's not valid TOML). // Try to unmarshal as TOML first; swallow errors from that (assume it's not valid TOML).
@ -188,19 +189,19 @@ func LoadClientConfig(path string, strict bool) (
) { ) {
var ( var (
cliCfg *v1.ClientCommonConfig cliCfg *v1.ClientCommonConfig
pxyCfgs = make([]v1.ProxyConfigurer, 0) proxyCfgs = make([]v1.ProxyConfigurer, 0)
visitorCfgs = make([]v1.VisitorConfigurer, 0) visitorCfgs = make([]v1.VisitorConfigurer, 0)
isLegacyFormat bool isLegacyFormat bool
) )
if DetectLegacyINIFormatFromFile(path) { if DetectLegacyINIFormatFromFile(path) {
legacyCommon, legacyPxyCfgs, legacyVisitorCfgs, err := legacy.ParseClientConfig(path) legacyCommon, legacyProxyCfgs, legacyVisitorCfgs, err := legacy.ParseClientConfig(path)
if err != nil { if err != nil {
return nil, nil, nil, true, err return nil, nil, nil, true, err
} }
cliCfg = legacy.Convert_ClientCommonConf_To_v1(&legacyCommon) cliCfg = legacy.Convert_ClientCommonConf_To_v1(&legacyCommon)
for _, c := range legacyPxyCfgs { for _, c := range legacyProxyCfgs {
pxyCfgs = append(pxyCfgs, legacy.Convert_ProxyConf_To_v1(c)) proxyCfgs = append(proxyCfgs, legacy.Convert_ProxyConf_To_v1(c))
} }
for _, c := range legacyVisitorCfgs { for _, c := range legacyVisitorCfgs {
visitorCfgs = append(visitorCfgs, legacy.Convert_VisitorConf_To_v1(c)) visitorCfgs = append(visitorCfgs, legacy.Convert_VisitorConf_To_v1(c))
@ -213,7 +214,7 @@ func LoadClientConfig(path string, strict bool) (
} }
cliCfg = &allCfg.ClientCommonConfig cliCfg = &allCfg.ClientCommonConfig
for _, c := range allCfg.Proxies { for _, c := range allCfg.Proxies {
pxyCfgs = append(pxyCfgs, c.ProxyConfigurer) proxyCfgs = append(proxyCfgs, c.ProxyConfigurer)
} }
for _, c := range allCfg.Visitors { for _, c := range allCfg.Visitors {
visitorCfgs = append(visitorCfgs, c.VisitorConfigurer) visitorCfgs = append(visitorCfgs, c.VisitorConfigurer)
@ -223,18 +224,18 @@ func LoadClientConfig(path string, strict bool) (
// Load additional config from includes. // Load additional config from includes.
// legacy ini format already handle this in ParseClientConfig. // legacy ini format already handle this in ParseClientConfig.
if len(cliCfg.IncludeConfigFiles) > 0 && !isLegacyFormat { if len(cliCfg.IncludeConfigFiles) > 0 && !isLegacyFormat {
extPxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat, strict) extProxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat, strict)
if err != nil { if err != nil {
return nil, nil, nil, isLegacyFormat, err return nil, nil, nil, isLegacyFormat, err
} }
pxyCfgs = append(pxyCfgs, extPxyCfgs...) proxyCfgs = append(proxyCfgs, extProxyCfgs...)
visitorCfgs = append(visitorCfgs, extVisitorCfgs...) visitorCfgs = append(visitorCfgs, extVisitorCfgs...)
} }
// Filter by start // Filter by start
if len(cliCfg.Start) > 0 { if len(cliCfg.Start) > 0 {
startSet := sets.New(cliCfg.Start...) startSet := sets.New(cliCfg.Start...)
pxyCfgs = lo.Filter(pxyCfgs, func(c v1.ProxyConfigurer, _ int) bool { proxyCfgs = lo.Filter(proxyCfgs, func(c v1.ProxyConfigurer, _ int) bool {
return startSet.Has(c.GetBaseConfig().Name) return startSet.Has(c.GetBaseConfig().Name)
}) })
visitorCfgs = lo.Filter(visitorCfgs, func(c v1.VisitorConfigurer, _ int) bool { visitorCfgs = lo.Filter(visitorCfgs, func(c v1.VisitorConfigurer, _ int) bool {
@ -245,17 +246,17 @@ func LoadClientConfig(path string, strict bool) (
if cliCfg != nil { if cliCfg != nil {
cliCfg.Complete() cliCfg.Complete()
} }
for _, c := range pxyCfgs { for _, c := range proxyCfgs {
c.Complete(cliCfg.User) c.Complete(cliCfg.User)
} }
for _, c := range visitorCfgs { for _, c := range visitorCfgs {
c.Complete(cliCfg) c.Complete(cliCfg)
} }
return cliCfg, pxyCfgs, visitorCfgs, isLegacyFormat, nil return cliCfg, proxyCfgs, visitorCfgs, isLegacyFormat, nil
} }
func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool, strict bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) { func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool, strict bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) {
pxyCfgs := make([]v1.ProxyConfigurer, 0) proxyCfgs := make([]v1.ProxyConfigurer, 0)
visitorCfgs := make([]v1.VisitorConfigurer, 0) visitorCfgs := make([]v1.VisitorConfigurer, 0)
for _, path := range paths { for _, path := range paths {
absDir, err := filepath.Abs(filepath.Dir(path)) absDir, err := filepath.Abs(filepath.Dir(path))
@ -281,7 +282,7 @@ func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool, strict boo
return nil, nil, fmt.Errorf("load additional config from %s error: %v", absFile, err) return nil, nil, fmt.Errorf("load additional config from %s error: %v", absFile, err)
} }
for _, c := range cfg.Proxies { for _, c := range cfg.Proxies {
pxyCfgs = append(pxyCfgs, c.ProxyConfigurer) proxyCfgs = append(proxyCfgs, c.ProxyConfigurer)
} }
for _, c := range cfg.Visitors { for _, c := range cfg.Visitors {
visitorCfgs = append(visitorCfgs, c.VisitorConfigurer) visitorCfgs = append(visitorCfgs, c.VisitorConfigurer)
@ -289,5 +290,5 @@ func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool, strict boo
} }
} }
} }
return pxyCfgs, visitorCfgs, nil return proxyCfgs, visitorCfgs, nil
} }

View File

@ -224,7 +224,9 @@ func NewProxyConfigurerByType(proxyType ProxyType) ProxyConfigurer {
if !ok { if !ok {
return nil return nil
} }
return reflect.New(v).Interface().(ProxyConfigurer) pc := reflect.New(v).Interface().(ProxyConfigurer)
pc.GetBaseConfig().Type = string(proxyType)
return pc
} }
var _ ProxyConfigurer = &TCPProxyConfig{} var _ ProxyConfigurer = &TCPProxyConfig{}

View File

@ -80,7 +80,7 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig) (Warning, error) {
return warnings, errs return warnings, errs
} }
func ValidateAllClientConfig(c *v1.ClientCommonConfig, pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) (Warning, error) { func ValidateAllClientConfig(c *v1.ClientCommonConfig, proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) (Warning, error) {
var warnings Warning var warnings Warning
if c != nil { if c != nil {
warning, err := ValidateClientCommonConfig(c) warning, err := ValidateClientCommonConfig(c)
@ -90,7 +90,7 @@ func ValidateAllClientConfig(c *v1.ClientCommonConfig, pxyCfgs []v1.ProxyConfigu
} }
} }
for _, c := range pxyCfgs { for _, c := range proxyCfgs {
if err := ValidateProxyConfigurerForClient(c); err != nil { if err := ValidateProxyConfigurerForClient(c); err != nil {
return warnings, fmt.Errorf("proxy %s: %v", c.GetBaseConfig().Name, err) return warnings, fmt.Errorf("proxy %s: %v", c.GetBaseConfig().Name, err)
} }

View File

@ -63,6 +63,15 @@ var msgTypeMap = map[byte]interface{}{
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name() var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name()
type ClientSpec struct {
// Due to the support of VirtualClient, frps needs to know the client type in order to
// differentiate the processing logic.
// Optional values: ssh-tunnel
Type string `json:"type,omitempty"`
// If the value is true, the client will not require authentication.
AlwaysAuthPass bool `json:"always_auth_pass,omitempty"`
}
// When frpc start, client send this message to login to server. // When frpc start, client send this message to login to server.
type Login struct { type Login struct {
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
@ -75,6 +84,9 @@ type Login struct {
RunID string `json:"run_id,omitempty"` RunID string `json:"run_id,omitempty"`
Metas map[string]string `json:"metas,omitempty"` Metas map[string]string `json:"metas,omitempty"`
// Currently only effective for VirtualClient.
ClientSpec ClientSpec `json:"client_spec,omitempty"`
// Some global configures. // Some global configures.
PoolCount int `json:"pool_count,omitempty"` PoolCount int `json:"pool_count,omitempty"`
} }

View File

@ -24,7 +24,7 @@ import (
"net/http/httputil" "net/http/httputil"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -79,7 +79,7 @@ func NewHTTP2HTTPSPlugin(options v1.ClientPluginOptions) (Plugin, error) {
} }
func (p *HTTP2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) { func (p *HTTP2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := netpkg.WrapReadWriteCloserToConn(conn, realConn)
_ = p.l.PutConn(wrapConn) _ = p.l.PutConn(wrapConn)
} }

View File

@ -29,7 +29,7 @@ import (
libnet "github.com/fatedier/golib/net" libnet "github.com/fatedier/golib/net"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
) )
@ -68,7 +68,7 @@ func (hp *HTTPProxy) Name() string {
} }
func (hp *HTTPProxy) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) { func (hp *HTTPProxy) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := netpkg.WrapReadWriteCloserToConn(conn, realConn)
sc, rd := libnet.NewSharedConn(wrapConn) sc, rd := libnet.NewSharedConn(wrapConn)
firstBytes := make([]byte, 7) firstBytes := make([]byte, 7)

View File

@ -26,7 +26,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -98,7 +98,7 @@ func (p *HTTPS2HTTPPlugin) genTLSConfig() (*tls.Config, error) {
} }
func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) { func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := netpkg.WrapReadWriteCloserToConn(conn, realConn)
_ = p.l.PutConn(wrapConn) _ = p.l.PutConn(wrapConn)
} }

View File

@ -26,7 +26,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -104,7 +104,7 @@ func (p *HTTPS2HTTPSPlugin) genTLSConfig() (*tls.Config, error) {
} }
func (p *HTTPS2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) { func (p *HTTPS2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := netpkg.WrapReadWriteCloserToConn(conn, realConn)
_ = p.l.PutConn(wrapConn) _ = p.l.PutConn(wrapConn)
} }

View File

@ -24,7 +24,7 @@ import (
gosocks5 "github.com/armon/go-socks5" gosocks5 "github.com/armon/go-socks5"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -52,7 +52,7 @@ func NewSocks5Plugin(options v1.ClientPluginOptions) (p Plugin, err error) {
func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) { func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) {
defer conn.Close() defer conn.Close()
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := netpkg.WrapReadWriteCloserToConn(conn, realConn)
_ = sp.Server.ServeConn(wrapConn) _ = sp.Server.ServeConn(wrapConn)
} }

View File

@ -25,7 +25,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
@ -57,8 +57,8 @@ func NewStaticFilePlugin(options v1.ClientPluginOptions) (Plugin, error) {
} }
router := mux.NewRouter() router := mux.NewRouter()
router.Use(utilnet.NewHTTPAuthMiddleware(opts.HTTPUser, opts.HTTPPassword).SetAuthFailDelay(200 * time.Millisecond).Middleware) router.Use(netpkg.NewHTTPAuthMiddleware(opts.HTTPUser, opts.HTTPPassword).SetAuthFailDelay(200 * time.Millisecond).Middleware)
router.PathPrefix(prefix).Handler(utilnet.MakeHTTPGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(opts.LocalPath))))).Methods("GET") router.PathPrefix(prefix).Handler(netpkg.MakeHTTPGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(opts.LocalPath))))).Methods("GET")
sp.s = &http.Server{ sp.s = &http.Server{
Handler: router, Handler: router,
} }
@ -69,7 +69,7 @@ func NewStaticFilePlugin(options v1.ClientPluginOptions) (Plugin, error) {
} }
func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) { func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, _ *ExtraInfo) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := netpkg.WrapReadWriteCloserToConn(conn, realConn)
_ = sp.l.PutConn(wrapConn) _ = sp.l.PutConn(wrapConn)
} }

View File

@ -11,7 +11,7 @@ import (
"strings" "strings"
"github.com/fatedier/frp/client" "github.com/fatedier/frp/client"
"github.com/fatedier/frp/pkg/util/util" httppkg "github.com/fatedier/frp/pkg/util/http"
) )
type Client struct { type Client struct {
@ -115,7 +115,7 @@ func (c *Client) UpdateConfig(content string) error {
func (c *Client) setAuthHeader(req *http.Request) { func (c *Client) setAuthHeader(req *http.Request) {
if c.authUser != "" || c.authPwd != "" { if c.authUser != "" || c.authPwd != "" {
req.Header.Set("Authorization", util.BasicAuth(c.authUser, c.authPwd)) req.Header.Set("Authorization", httppkg.BasicAuth(c.authUser, c.authPwd))
} }
} }

View File

@ -26,21 +26,21 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
type Gateway struct { type Gateway struct {
bindPort int bindPort int
ln net.Listener ln net.Listener
serverPeerListener *utilnet.InternalListener peerServerListener *netpkg.InternalListener
sshConfig *ssh.ServerConfig sshConfig *ssh.ServerConfig
} }
func NewGateway( func NewGateway(
cfg v1.SSHTunnelGateway, bindAddr string, cfg v1.SSHTunnelGateway, bindAddr string,
serverPeerListener *utilnet.InternalListener, peerServerListener *netpkg.InternalListener,
) (*Gateway, error) { ) (*Gateway, error) {
sshConfig := &ssh.ServerConfig{} sshConfig := &ssh.ServerConfig{}
@ -71,15 +71,8 @@ func NewGateway(
} }
sshConfig.AddHostKey(privateKey) sshConfig.AddHostKey(privateKey)
sshConfig.NoClientAuth = cfg.AuthorizedKeysFile == ""
sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if cfg.AuthorizedKeysFile == "" {
return &ssh.Permissions{
Extensions: map[string]string{
"user": "",
},
}, nil
}
authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile) authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("internal error") return nil, fmt.Errorf("internal error")
@ -103,7 +96,7 @@ func NewGateway(
return &Gateway{ return &Gateway{
bindPort: cfg.BindPort, bindPort: cfg.BindPort,
ln: ln, ln: ln,
serverPeerListener: serverPeerListener, peerServerListener: peerServerListener,
sshConfig: sshConfig, sshConfig: sshConfig,
}, nil }, nil
} }
@ -121,7 +114,7 @@ func (g *Gateway) Run() {
func (g *Gateway) handleConn(conn net.Conn) { func (g *Gateway) handleConn(conn net.Conn) {
defer conn.Close() defer conn.Close()
ts, err := NewTunnelServer(conn, g.sshConfig, g.serverPeerListener) ts, err := NewTunnelServer(conn, g.sshConfig, g.peerServerListener)
if err != nil { if err != nil {
return return
} }

View File

@ -17,9 +17,11 @@ package ssh
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
libio "github.com/fatedier/golib/io" libio "github.com/fatedier/golib/io"
@ -27,10 +29,12 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
utilnet "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/virtual" "github.com/fatedier/frp/pkg/virtual"
@ -64,15 +68,16 @@ type TunnelServer struct {
sc *ssh.ServerConfig sc *ssh.ServerConfig
vc *virtual.Client vc *virtual.Client
serverPeerListener *utilnet.InternalListener peerServerListener *netpkg.InternalListener
doneCh chan struct{} doneCh chan struct{}
closeDoneChOnce sync.Once
} }
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, serverPeerListener *utilnet.InternalListener) (*TunnelServer, error) { func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, peerServerListener *netpkg.InternalListener) (*TunnelServer, error) {
s := &TunnelServer{ s := &TunnelServer{
underlyingConn: conn, underlyingConn: conn,
sc: sc, sc: sc,
serverPeerListener: serverPeerListener, peerServerListener: peerServerListener,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
} }
return s, nil return s, nil
@ -94,19 +99,35 @@ func (s *TunnelServer) Run() error {
if err != nil { if err != nil {
return err return err
} }
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User) clientCfg.Complete()
if sshConn.Permissions != nil {
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
}
pc.Complete(clientCfg.User) pc.Complete(clientCfg.User)
s.vc = virtual.NewClient(clientCfg) vc, err := virtual.NewClient(virtual.ClientOptions{
// join workConn and ssh channel Common: clientCfg,
s.vc.SetInWorkConnCallback(func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool { Spec: &msg.ClientSpec{
c, err := s.openConn(addr) Type: "ssh-tunnel",
if err != nil { // If ssh does not require authentication, then the virtual client needs to authenticate through a token.
// Otherwise, once ssh authentication is passed, the virtual client does not need to authenticate again.
AlwaysAuthPass: !s.sc.NoClientAuth,
},
HandleWorkConnCb: func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
// join workConn and ssh channel
c, err := s.openConn(addr)
if err != nil {
return false
}
libio.Join(c, workConn)
return false return false
} },
libio.Join(c, workConn)
return false
}) })
if err != nil {
return err
}
s.vc = vc
// transfer connection from virtual client to server peer listener // transfer connection from virtual client to server peer listener
go func() { go func() {
l := s.vc.PeerListener() l := s.vc.PeerListener()
@ -115,21 +136,35 @@ func (s *TunnelServer) Run() error {
if err != nil { if err != nil {
return return
} }
_ = s.serverPeerListener.PutConn(conn) _ = s.peerServerListener.PutConn(conn)
} }
}() }()
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100}) xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
ctx := xlog.NewContext(context.Background(), xl) ctx := xlog.NewContext(context.Background(), xl)
go func() { go func() {
_ = s.vc.Run(ctx) _ = s.vc.Run(ctx)
// If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed.
// One scenario is that the virtual client exits due to login failure.
s.closeDoneChOnce.Do(func() {
_ = sshConn.Close()
close(s.doneCh)
})
}() }()
s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc}) s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})
_ = sshConn.Wait() if err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
_ = sshConn.Close() log.Warn("wait proxy status ready error: %v", err)
} else {
_ = sshConn.Wait()
}
s.vc.Close() s.vc.Close()
close(s.doneCh) log.Trace("ssh tunnel connection from %v closed", sshConn.RemoteAddr())
s.closeDoneChOnce.Do(func() {
_ = sshConn.Close()
close(s.doneCh)
})
return nil return nil
} }
@ -217,6 +252,14 @@ func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPaylo
if err := cmd.ParseFlags(args); err != nil { if err := cmd.ParseFlags(args); err != nil {
return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err) return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err)
} }
// if name is not set, generate a random one
if pc.GetBaseConfig().Name == "" {
id, err := util.RandIDWithLen(8)
if err != nil {
return nil, nil, fmt.Errorf("generate random id error: %v", err)
}
pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id)
}
return &clientCfg, pc, nil return &clientCfg, pc, nil
} }
@ -274,6 +317,34 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
} }
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
conn := utilnet.WrapReadWriteCloserToConn(channel, s.underlyingConn) conn := netpkg.WrapReadWriteCloserToConn(channel, s.underlyingConn)
return conn, nil return conn, nil
} }
func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) error {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case <-ticker.C:
ps, err := s.vc.Service().GetProxyStatus(name)
if err != nil {
continue
}
switch ps.Phase {
case proxy.ProxyPhaseRunning:
return nil
case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed:
return errors.New(ps.Err)
}
case <-timer.C:
return fmt.Errorf("wait proxy status ready timeout")
case <-s.doneCh:
return fmt.Errorf("ssh tunnel server closed")
}
}
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package util package http
import ( import (
"encoding/base64" "encoding/base64"

128
pkg/util/http/server.go Normal file
View File

@ -0,0 +1,128 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package http
import (
"crypto/tls"
"net"
"net/http"
"net/http/pprof"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/fatedier/frp/assets"
v1 "github.com/fatedier/frp/pkg/config/v1"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
var (
defaultReadTimeout = 60 * time.Second
defaultWriteTimeout = 60 * time.Second
)
type Server struct {
addr string
ln net.Listener
tlsCfg *tls.Config
router *mux.Router
hs *http.Server
authMiddleware mux.MiddlewareFunc
}
func NewServer(cfg v1.WebServerConfig) (*Server, error) {
if cfg.AssetsDir != "" {
assets.Load(cfg.AssetsDir)
}
addr := net.JoinHostPort(cfg.Addr, strconv.Itoa(cfg.Port))
if addr == ":" {
addr = ":http"
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
router := mux.NewRouter()
hs := &http.Server{
Addr: addr,
Handler: router,
ReadTimeout: defaultReadTimeout,
WriteTimeout: defaultWriteTimeout,
}
s := &Server{
addr: addr,
ln: ln,
hs: hs,
router: router,
}
if cfg.PprofEnable {
s.registerPprofHandlers()
}
if cfg.TLS != nil {
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertFile, cfg.TLS.KeyFile)
if err != nil {
return nil, err
}
s.tlsCfg = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
s.authMiddleware = netpkg.NewHTTPAuthMiddleware(cfg.User, cfg.Password).SetAuthFailDelay(200 * time.Millisecond).Middleware
return s, nil
}
func (s *Server) Address() string {
return s.addr
}
func (s *Server) Run() error {
ln := s.ln
if s.tlsCfg != nil {
ln = tls.NewListener(ln, s.tlsCfg)
}
return s.hs.Serve(ln)
}
func (s *Server) Close() error {
return s.hs.Close()
}
type RouterRegisterHelper struct {
Router *mux.Router
AssetsFS http.FileSystem
AuthMiddleware mux.MiddlewareFunc
}
func (s *Server) RouteRegister(register func(helper *RouterRegisterHelper)) {
register(&RouterRegisterHelper{
Router: s.router,
AssetsFS: assets.FileSystem,
AuthMiddleware: s.authMiddleware,
})
}
func (s *Server) registerPprofHandlers() {
s.router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
s.router.HandleFunc("/debug/pprof/profile", pprof.Profile)
s.router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
s.router.HandleFunc("/debug/pprof/trace", pprof.Trace)
s.router.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}

33
pkg/util/net/dns.go Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"context"
"net"
)
func SetDefaultDNSAddress(dnsAddress string) {
if _, _, err := net.SplitHostPort(dnsAddress); err != nil {
dnsAddress = net.JoinHostPort(dnsAddress, "53")
}
// Change default dns server
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial("udp", dnsAddress)
},
}
}

View File

@ -52,7 +52,10 @@ func (l *InternalListener) PutConn(conn net.Conn) error {
conn.Close() conn.Close()
} }
}) })
return err if err != nil {
return fmt.Errorf("put conn error: listener is closed")
}
return nil
} }
func (l *InternalListener) Close() error { func (l *InternalListener) Close() error {

View File

@ -24,7 +24,7 @@ import (
libnet "github.com/fatedier/golib/net" libnet "github.com/fatedier/golib/net"
"github.com/fatedier/frp/pkg/util/util" httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
) )
@ -59,10 +59,10 @@ func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, ht
return return
} }
host, _ = util.CanonicalHost(req.Host) host, _ = httppkg.CanonicalHost(req.Host)
proxyAuth := req.Header.Get("Proxy-Authorization") proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" { if proxyAuth != "" {
httpUser, httpPwd, _ = util.ParseBasicAuth(proxyAuth) httpUser, httpPwd, _ = httppkg.ParseBasicAuth(proxyAuth)
} }
return return
} }
@ -71,7 +71,7 @@ func (muxer *HTTPConnectTCPMuxer) sendConnectResponse(c net.Conn, _ map[string]s
if muxer.passthrough { if muxer.passthrough {
return nil return nil
} }
res := util.OkResponse() res := httppkg.OkResponse()
if res.Body != nil { if res.Body != nil {
defer res.Body.Close() defer res.Body.Close()
} }
@ -85,7 +85,7 @@ func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, re
return true, nil return true, nil
} }
resp := util.ProxyUnauthorizedResponse() resp := httppkg.ProxyUnauthorizedResponse()
if resp.Body != nil { if resp.Body != nil {
defer resp.Body.Close() defer resp.Body.Close()
} }

View File

@ -31,8 +31,8 @@ import (
libio "github.com/fatedier/golib/io" libio "github.com/fatedier/golib/io"
"github.com/fatedier/golib/pool" "github.com/fatedier/golib/pool"
frpLog "github.com/fatedier/frp/pkg/util/log" httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/util" logpkg "github.com/fatedier/frp/pkg/util/log"
) )
var ErrNoRouteFound = errors.New("no route found") var ErrNoRouteFound = errors.New("no route found")
@ -61,7 +61,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
Director: func(req *http.Request) { Director: func(req *http.Request) {
req.URL.Scheme = "http" req.URL.Scheme = "http"
reqRouteInfo := req.Context().Value(RouteInfoKey).(*RequestRouteInfo) reqRouteInfo := req.Context().Value(RouteInfoKey).(*RequestRouteInfo)
oldHost, _ := util.CanonicalHost(reqRouteInfo.Host) oldHost, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
rc := rp.GetRouteConfig(oldHost, reqRouteInfo.URL, reqRouteInfo.HTTPUser) rc := rp.GetRouteConfig(oldHost, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
if rc != nil { if rc != nil {
@ -74,7 +74,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
// ignore error here, it will use CreateConnFn instead later // ignore error here, it will use CreateConnFn instead later
endpoint, _ = rc.ChooseEndpointFn() endpoint, _ = rc.ChooseEndpointFn()
reqRouteInfo.Endpoint = endpoint reqRouteInfo.Endpoint = endpoint
frpLog.Trace("choose endpoint name [%s] for http request host [%s] path [%s] httpuser [%s]", logpkg.Trace("choose endpoint name [%s] for http request host [%s] path [%s] httpuser [%s]",
endpoint, oldHost, reqRouteInfo.URL, reqRouteInfo.HTTPUser) endpoint, oldHost, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
} }
// Set {domain}.{location}.{routeByHTTPUser}.{endpoint} as URL host here to let http transport reuse connections. // Set {domain}.{location}.{routeByHTTPUser}.{endpoint} as URL host here to let http transport reuse connections.
@ -116,7 +116,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
BufferPool: newWrapPool(), BufferPool: newWrapPool(),
ErrorLog: log.New(newWrapLogger(), "", 0), ErrorLog: log.New(newWrapLogger(), "", 0),
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
frpLog.Warn("do http proxy request [host: %s] error: %v", req.Host, err) logpkg.Warn("do http proxy request [host: %s] error: %v", req.Host, err)
rw.WriteHeader(http.StatusNotFound) rw.WriteHeader(http.StatusNotFound)
_, _ = rw.Write(getNotFoundPageContent()) _, _ = rw.Write(getNotFoundPageContent())
}, },
@ -143,7 +143,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) {
func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig { func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
vr, ok := rp.getVhost(domain, location, routeByHTTPUser) vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
if ok { if ok {
frpLog.Debug("get new HTTP request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser) logpkg.Debug("get new HTTP request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
return vr.payload.(*RouteConfig) return vr.payload.(*RouteConfig)
} }
return nil return nil
@ -159,7 +159,7 @@ func (rp *HTTPReverseProxy) GetHeaders(domain, location, routeByHTTPUser string)
// CreateConnection create a new connection by route config // CreateConnection create a new connection by route config
func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) { func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) {
host, _ := util.CanonicalHost(reqRouteInfo.Host) host, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser) vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
if ok { if ok {
if byEndpoint { if byEndpoint {
@ -303,7 +303,7 @@ func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Requ
} }
func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
domain, _ := util.CanonicalHost(req.Host) domain, _ := httppkg.CanonicalHost(req.Host)
location := req.URL.Path location := req.URL.Path
user, passwd, _ := req.BasicAuth() user, passwd, _ := req.BasicAuth()
if !rp.CheckAuth(domain, location, user, user, passwd) { if !rp.CheckAuth(domain, location, user, user, passwd) {
@ -333,6 +333,6 @@ type wrapLogger struct{}
func newWrapLogger() *wrapLogger { return &wrapLogger{} } func newWrapLogger() *wrapLogger { return &wrapLogger{} }
func (l *wrapLogger) Write(p []byte) (n int, err error) { func (l *wrapLogger) Write(p []byte) (n int, err error) {
frpLog.Warn("%s", string(bytes.TrimRight(p, "\n"))) logpkg.Warn("%s", string(bytes.TrimRight(p, "\n")))
return len(p), nil return len(p), nil
} }

View File

@ -20,7 +20,7 @@ import (
"net/http" "net/http"
"os" "os"
frpLog "github.com/fatedier/frp/pkg/util/log" logpkg "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
) )
@ -58,7 +58,7 @@ func getNotFoundPageContent() []byte {
if NotFoundPagePath != "" { if NotFoundPagePath != "" {
buf, err = os.ReadFile(NotFoundPagePath) buf, err = os.ReadFile(NotFoundPagePath)
if err != nil { if err != nil {
frpLog.Warn("read custom 404 page error: %v", err) logpkg.Warn("read custom 404 page error: %v", err)
buf = []byte(NotFound) buf = []byte(NotFound)
} }
} else { } else {

View File

@ -22,7 +22,7 @@ import (
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -284,7 +284,7 @@ func (l *Listener) Accept() (net.Conn, error) {
xl.Debug("rewrite host to [%s] success", l.rewriteHost) xl.Debug("rewrite host to [%s] success", l.rewriteHost)
conn = sConn conn = sConn
} }
return utilnet.NewContextConn(l.ctx, conn), nil return netpkg.NewContextConn(l.ctx, conn), nil
} }
func (l *Listener) Close() error { func (l *Listener) Close() error {

View File

@ -21,55 +21,70 @@ import (
"github.com/fatedier/frp/client" "github.com/fatedier/frp/client"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
type ClientOptions struct {
Common *v1.ClientCommonConfig
Spec *msg.ClientSpec
HandleWorkConnCb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
}
type Client struct { type Client struct {
l *utilnet.InternalListener l *netpkg.InternalListener
svr *client.Service svr *client.Service
} }
func NewClient(cfg *v1.ClientCommonConfig) *Client { func NewClient(options ClientOptions) (*Client, error) {
cfg.Complete() if options.Common != nil {
options.Common.Complete()
}
ln := utilnet.NewInternalListener() ln := netpkg.NewInternalListener()
svr := client.NewService(cfg, nil, nil, "")
svr.SetConnectorCreator(func(context.Context, *v1.ClientCommonConfig) client.Connector {
return &pipeConnector{
peerListener: ln,
}
})
serviceOptions := client.ServiceOptions{
Common: options.Common,
ClientSpec: options.Spec,
ConnectorCreator: func(context.Context, *v1.ClientCommonConfig) client.Connector {
return &pipeConnector{
peerListener: ln,
}
},
HandleWorkConnCb: options.HandleWorkConnCb,
}
svr, err := client.NewService(serviceOptions)
if err != nil {
return nil, err
}
return &Client{ return &Client{
l: ln, l: ln,
svr: svr, svr: svr,
} }, nil
} }
func (c *Client) PeerListener() net.Listener { func (c *Client) PeerListener() net.Listener {
return c.l return c.l
} }
func (c *Client) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
c.svr.SetInWorkConnCallback(cb)
}
func (c *Client) UpdateProxyConfigurer(proxyCfgs []v1.ProxyConfigurer) { func (c *Client) UpdateProxyConfigurer(proxyCfgs []v1.ProxyConfigurer) {
_ = c.svr.ReloadConf(proxyCfgs, nil) _ = c.svr.UpdateAllConfigurer(proxyCfgs, nil)
} }
func (c *Client) Run(ctx context.Context) error { func (c *Client) Run(ctx context.Context) error {
return c.svr.Run(ctx) return c.svr.Run(ctx)
} }
func (c *Client) Service() *client.Service {
return c.svr
}
func (c *Client) Close() { func (c *Client) Close() {
c.l.Close()
c.svr.Close() c.svr.Close()
c.l.Close()
} }
type pipeConnector struct { type pipeConnector struct {
peerListener *utilnet.InternalListener peerListener *netpkg.InternalListener
} }
func (pc *pipeConnector) Open() error { func (pc *pipeConnector) Open() error {

View File

@ -32,7 +32,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/server" plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/wait"
@ -150,6 +150,7 @@ type Control struct {
doneCh chan struct{} doneCh chan struct{}
} }
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
func NewControl( func NewControl(
ctx context.Context, ctx context.Context,
rc *controller.ResourceController, rc *controller.ResourceController,
@ -157,6 +158,7 @@ func NewControl(
pluginManager *plugin.Manager, pluginManager *plugin.Manager,
authVerifier auth.Verifier, authVerifier auth.Verifier,
ctlConn net.Conn, ctlConn net.Conn,
ctlConnEncrypted bool,
loginMsg *msg.Login, loginMsg *msg.Login,
serverCfg *v1.ServerConfig, serverCfg *v1.ServerConfig,
) (*Control, error) { ) (*Control, error) {
@ -183,11 +185,15 @@ func NewControl(
} }
ctl.lastPing.Store(time.Now()) ctl.lastPing.Store(time.Now())
cryptoRW, err := utilnet.NewCryptoReadWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token)) if ctlConnEncrypted {
if err != nil { cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token))
return nil, err if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
} else {
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
} }
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
ctl.registerMsgHandlers() ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel()) ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
return ctl, nil return ctl, nil
@ -300,6 +306,7 @@ func (ctl *Control) heartbeatWorker() {
go wait.Until(func() { go wait.Until(func() {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second { if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warn("heartbeat timeout") xl.Warn("heartbeat timeout")
ctl.conn.Close()
return return
} }
}, time.Second, ctl.doneCh) }, time.Second, ctl.doneCh)
@ -555,6 +562,5 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
go func() { go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent) _ = ctl.pluginManager.CloseProxy(notifyContent)
}() }()
return return
} }

View File

@ -1,99 +0,0 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"crypto/tls"
"net"
"net/http"
"net/http/pprof"
"time"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/fatedier/frp/assets"
utilnet "github.com/fatedier/frp/pkg/util/net"
)
var (
httpServerReadTimeout = 60 * time.Second
httpServerWriteTimeout = 60 * time.Second
)
func (svr *Service) RunDashboardServer(address string) (err error) {
// url router
router := mux.NewRouter()
router.HandleFunc("/healthz", svr.Healthz)
// debug
if svr.cfg.WebServer.PprofEnable {
router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
router.HandleFunc("/debug/pprof/profile", pprof.Profile)
router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
router.HandleFunc("/debug/pprof/trace", pprof.Trace)
router.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}
subRouter := router.NewRoute().Subrouter()
user, passwd := svr.cfg.WebServer.User, svr.cfg.WebServer.Password
subRouter.Use(utilnet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware)
// metrics
if svr.cfg.EnablePrometheus {
subRouter.Handle("/metrics", promhttp.Handler())
}
// api, see dashboard_api.go
subRouter.HandleFunc("/api/serverinfo", svr.APIServerInfo).Methods("GET")
subRouter.HandleFunc("/api/proxy/{type}", svr.APIProxyByType).Methods("GET")
subRouter.HandleFunc("/api/proxy/{type}/{name}", svr.APIProxyByTypeAndName).Methods("GET")
subRouter.HandleFunc("/api/traffic/{name}", svr.APIProxyTraffic).Methods("GET")
// view
subRouter.Handle("/favicon.ico", http.FileServer(assets.FileSystem)).Methods("GET")
subRouter.PathPrefix("/static/").Handler(utilnet.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)))).Methods("GET")
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})
server := &http.Server{
Addr: address,
Handler: router,
ReadTimeout: httpServerReadTimeout,
WriteTimeout: httpServerWriteTimeout,
}
ln, err := net.Listen("tcp", address)
if err != nil {
return err
}
if svr.cfg.WebServer.TLS != nil {
cert, err := tls.LoadX509KeyPair(svr.cfg.WebServer.TLS.CertFile, svr.cfg.WebServer.TLS.KeyFile)
if err != nil {
return err
}
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{cert},
}
ln = tls.NewListener(ln, tlsCfg)
}
go func() {
_ = server.Serve(ln)
}()
return
}

View File

@ -19,19 +19,52 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/fatedier/frp/pkg/config/types" "github.com/fatedier/frp/pkg/config/types"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/metrics/mem" "github.com/fatedier/frp/pkg/metrics/mem"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
) )
// TODO(fatedier): add an API to clean status of all offline proxies.
type GeneralResponse struct { type GeneralResponse struct {
Code int Code int
Msg string Msg string
} }
func (svr *Service) registerRouteHandlers(helper *httppkg.RouterRegisterHelper) {
helper.Router.HandleFunc("/healthz", svr.healthz)
subRouter := helper.Router.NewRoute().Subrouter()
subRouter.Use(helper.AuthMiddleware.Middleware)
// metrics
if svr.cfg.EnablePrometheus {
subRouter.Handle("/metrics", promhttp.Handler())
}
// apis
subRouter.HandleFunc("/api/serverinfo", svr.apiServerInfo).Methods("GET")
subRouter.HandleFunc("/api/proxy/{type}", svr.apiProxyByType).Methods("GET")
subRouter.HandleFunc("/api/proxy/{type}/{name}", svr.apiProxyByTypeAndName).Methods("GET")
subRouter.HandleFunc("/api/traffic/{name}", svr.apiProxyTraffic).Methods("GET")
// view
subRouter.Handle("/favicon.ico", http.FileServer(helper.AssetsFS)).Methods("GET")
subRouter.PathPrefix("/static/").Handler(
netpkg.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(helper.AssetsFS))),
).Methods("GET")
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})
}
type serverInfoResp struct { type serverInfoResp struct {
Version string `json:"version"` Version string `json:"version"`
BindPort int `json:"bindPort"` BindPort int `json:"bindPort"`
@ -55,12 +88,12 @@ type serverInfoResp struct {
} }
// /healthz // /healthz
func (svr *Service) Healthz(w http.ResponseWriter, _ *http.Request) { func (svr *Service) healthz(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200) w.WriteHeader(200)
} }
// /api/serverinfo // /api/serverinfo
func (svr *Service) APIServerInfo(w http.ResponseWriter, r *http.Request) { func (svr *Service) apiServerInfo(w http.ResponseWriter, r *http.Request) {
res := GeneralResponse{Code: 200} res := GeneralResponse{Code: 200}
defer func() { defer func() {
log.Info("Http response [%s]: code [%d]", r.URL.Path, res.Code) log.Info("Http response [%s]: code [%d]", r.URL.Path, res.Code)
@ -177,7 +210,7 @@ type GetProxyInfoResp struct {
} }
// /api/proxy/:type // /api/proxy/:type
func (svr *Service) APIProxyByType(w http.ResponseWriter, r *http.Request) { func (svr *Service) apiProxyByType(w http.ResponseWriter, r *http.Request) {
res := GeneralResponse{Code: 200} res := GeneralResponse{Code: 200}
params := mux.Vars(r) params := mux.Vars(r)
proxyType := params["type"] proxyType := params["type"]
@ -245,7 +278,7 @@ type GetProxyStatsResp struct {
} }
// /api/proxy/:type/:name // /api/proxy/:type/:name
func (svr *Service) APIProxyByTypeAndName(w http.ResponseWriter, r *http.Request) { func (svr *Service) apiProxyByTypeAndName(w http.ResponseWriter, r *http.Request) {
res := GeneralResponse{Code: 200} res := GeneralResponse{Code: 200}
params := mux.Vars(r) params := mux.Vars(r)
proxyType := params["type"] proxyType := params["type"]
@ -314,7 +347,7 @@ type GetProxyTrafficResp struct {
TrafficOut []int64 `json:"trafficOut"` TrafficOut []int64 `json:"trafficOut"`
} }
func (svr *Service) APIProxyTraffic(w http.ResponseWriter, r *http.Request) { func (svr *Service) apiProxyTraffic(w http.ResponseWriter, r *http.Request) {
res := GeneralResponse{Code: 200} res := GeneralResponse{Code: 200}
params := mux.Vars(r) params := mux.Vars(r)
name := params["name"] name := params["name"]

View File

@ -24,7 +24,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
"github.com/fatedier/frp/server/metrics" "github.com/fatedier/frp/server/metrics"
@ -180,8 +180,8 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err
}) })
} }
workConn = utilnet.WrapReadWriteCloserToConn(rwc, tmpConn) workConn = netpkg.WrapReadWriteCloserToConn(rwc, tmpConn)
workConn = utilnet.WrapStatsConn(workConn, pxy.updateStatsAfterClosedConn) workConn = netpkg.WrapStatsConn(workConn, pxy.updateStatsAfterClosedConn)
metrics.Server.OpenConnection(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type) metrics.Server.OpenConnection(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
return return
} }

View File

@ -32,7 +32,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/server" plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/server/controller" "github.com/fatedier/frp/server/controller"
"github.com/fatedier/frp/server/metrics" "github.com/fatedier/frp/server/metrics"
@ -130,7 +130,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
} }
xl.Debug("get a new work connection: [%s]", workConn.RemoteAddr().String()) xl.Debug("get a new work connection: [%s]", workConn.RemoteAddr().String())
xl.Spawn().AppendPrefix(pxy.GetName()) xl.Spawn().AppendPrefix(pxy.GetName())
workConn = utilnet.NewContextConn(pxy.ctx, workConn) workConn = netpkg.NewContextConn(pxy.ctx, workConn)
var ( var (
srcAddr string srcAddr string

View File

@ -30,7 +30,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/server/metrics" "github.com/fatedier/frp/server/metrics"
) )
@ -222,7 +222,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
}) })
} }
pxy.workConn = utilnet.WrapReadWriteCloserToConn(rwc, workConn) pxy.workConn = netpkg.WrapReadWriteCloserToConn(rwc, workConn)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go workConnReaderFn(pxy.workConn) go workConnReaderFn(pxy.workConn)
go workConnSenderFn(pxy.workConn, ctx) go workConnSenderFn(pxy.workConn, ctx)

View File

@ -30,7 +30,6 @@ import (
quic "github.com/quic-go/quic-go" quic "github.com/quic-go/quic-go"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/fatedier/frp/assets"
"github.com/fatedier/frp/pkg/auth" "github.com/fatedier/frp/pkg/auth"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
modelmetrics "github.com/fatedier/frp/pkg/metrics" modelmetrics "github.com/fatedier/frp/pkg/metrics"
@ -39,8 +38,9 @@ import (
plugin "github.com/fatedier/frp/pkg/plugin/server" plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/ssh" "github.com/fatedier/frp/pkg/ssh"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/tcpmux" "github.com/fatedier/frp/pkg/util/tcpmux"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
@ -79,7 +79,8 @@ type Service struct {
// Accept frp tls connections // Accept frp tls connections
tlsListener net.Listener tlsListener net.Listener
virtualListener *utilnet.InternalListener // Accept pipe connections from ssh tunnel gateway
sshTunnelListener *netpkg.InternalListener
// Manage all controllers // Manage all controllers
ctlManager *ControlManager ctlManager *ControlManager
@ -96,6 +97,9 @@ type Service struct {
// All resource managers and controllers // All resource managers and controllers
rc *controller.ResourceController rc *controller.ResourceController
// web server for dashboard UI and apis
webServer *httppkg.Server
sshTunnelGateway *ssh.Gateway sshTunnelGateway *ssh.Gateway
// Verifies authentication based on selected method // Verifies authentication based on selected method
@ -111,16 +115,30 @@ type Service struct {
cancel context.CancelFunc cancel context.CancelFunc
} }
func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { func NewService(cfg *v1.ServerConfig) (*Service, error) {
tlsConfig, err := transport.NewServerTLSConfig( tlsConfig, err := transport.NewServerTLSConfig(
cfg.Transport.TLS.CertFile, cfg.Transport.TLS.CertFile,
cfg.Transport.TLS.KeyFile, cfg.Transport.TLS.KeyFile,
cfg.Transport.TLS.TrustedCaFile) cfg.Transport.TLS.TrustedCaFile)
if err != nil { if err != nil {
return return nil, err
} }
svr = &Service{ var webServer *httppkg.Server
if cfg.WebServer.Port > 0 {
ws, err := httppkg.NewServer(cfg.WebServer)
if err != nil {
return nil, err
}
webServer = ws
modelmetrics.EnableMem()
if cfg.EnablePrometheus {
modelmetrics.EnablePrometheus()
}
}
svr := &Service{
ctlManager: NewControlManager(), ctlManager: NewControlManager(),
pxyManager: proxy.NewManager(), pxyManager: proxy.NewManager(),
pluginManager: plugin.NewManager(), pluginManager: plugin.NewManager(),
@ -129,12 +147,16 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
TCPPortManager: ports.NewManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), TCPPortManager: ports.NewManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
UDPPortManager: ports.NewManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), UDPPortManager: ports.NewManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
}, },
virtualListener: utilnet.NewInternalListener(), sshTunnelListener: netpkg.NewInternalListener(),
httpVhostRouter: vhost.NewRouters(), httpVhostRouter: vhost.NewRouters(),
authVerifier: auth.NewAuthVerifier(cfg.Auth), authVerifier: auth.NewAuthVerifier(cfg.Auth),
tlsConfig: tlsConfig, webServer: webServer,
cfg: cfg, tlsConfig: tlsConfig,
ctx: context.Background(), cfg: cfg,
ctx: context.Background(),
}
if webServer != nil {
webServer.RouteRegister(svr.registerRouteHandlers)
} }
// Create tcpmux httpconnect multiplexer. // Create tcpmux httpconnect multiplexer.
@ -143,14 +165,12 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.TCPMuxHTTPConnectPort)) address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.TCPMuxHTTPConnectPort))
l, err = net.Listen("tcp", address) l, err = net.Listen("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("create server listener error, %v", err) return nil, fmt.Errorf("create server listener error, %v", err)
return
} }
svr.rc.TCPMuxHTTPConnectMuxer, err = tcpmux.NewHTTPConnectTCPMuxer(l, cfg.TCPMuxPassthrough, vhostReadWriteTimeout) svr.rc.TCPMuxHTTPConnectMuxer, err = tcpmux.NewHTTPConnectTCPMuxer(l, cfg.TCPMuxPassthrough, vhostReadWriteTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("create vhost tcpMuxer error, %v", err) return nil, fmt.Errorf("create vhost tcpMuxer error, %v", err)
return
} }
log.Info("tcpmux httpconnect multiplexer listen on %s, passthough: %v", address, cfg.TCPMuxPassthrough) log.Info("tcpmux httpconnect multiplexer listen on %s, passthough: %v", address, cfg.TCPMuxPassthrough)
} }
@ -191,8 +211,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.BindPort)) address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.BindPort))
ln, err := net.Listen("tcp", address) ln, err := net.Listen("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("create server listener error, %v", err) return nil, fmt.Errorf("create server listener error, %v", err)
return
} }
svr.muxer = mux.NewMux(ln) svr.muxer = mux.NewMux(ln)
@ -208,10 +227,9 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
// Listen for accepting connections from client using kcp protocol. // Listen for accepting connections from client using kcp protocol.
if cfg.KCPBindPort > 0 { if cfg.KCPBindPort > 0 {
address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort)) address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort))
svr.kcpListener, err = utilnet.ListenKcp(address) svr.kcpListener, err = netpkg.ListenKcp(address)
if err != nil { if err != nil {
err = fmt.Errorf("listen on kcp udp address %s error: %v", address, err) return nil, fmt.Errorf("listen on kcp udp address %s error: %v", address, err)
return
} }
log.Info("frps kcp listen on udp %s", address) log.Info("frps kcp listen on udp %s", address)
} }
@ -226,28 +244,26 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
KeepAlivePeriod: time.Duration(cfg.Transport.QUIC.KeepalivePeriod) * time.Second, KeepAlivePeriod: time.Duration(cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
}) })
if err != nil { if err != nil {
err = fmt.Errorf("listen on quic udp address %s error: %v", address, err) return nil, fmt.Errorf("listen on quic udp address %s error: %v", address, err)
return
} }
log.Info("frps quic listen on %s", address) log.Info("frps quic listen on %s", address)
} }
if cfg.SSHTunnelGateway.BindPort > 0 { if cfg.SSHTunnelGateway.BindPort > 0 {
sshGateway, err := ssh.NewGateway(cfg.SSHTunnelGateway, cfg.ProxyBindAddr, svr.virtualListener) sshGateway, err := ssh.NewGateway(cfg.SSHTunnelGateway, cfg.ProxyBindAddr, svr.sshTunnelListener)
if err != nil { if err != nil {
err = fmt.Errorf("create ssh gateway error: %v", err) return nil, fmt.Errorf("create ssh gateway error: %v", err)
return nil, err
} }
svr.sshTunnelGateway = sshGateway svr.sshTunnelGateway = sshGateway
log.Info("frps sshTunnelGateway listen on port %d", cfg.SSHTunnelGateway.BindPort) log.Info("frps sshTunnelGateway listen on port %d", cfg.SSHTunnelGateway.BindPort)
} }
// Listen for accepting connections from client using websocket protocol. // Listen for accepting connections from client using websocket protocol.
websocketPrefix := []byte("GET " + utilnet.FrpWebsocketPath) websocketPrefix := []byte("GET " + netpkg.FrpWebsocketPath)
websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool { websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool {
return bytes.Equal(data, websocketPrefix) return bytes.Equal(data, websocketPrefix)
}) })
svr.websocketListener = utilnet.NewWebsocketListener(websocketLn) svr.websocketListener = netpkg.NewWebsocketListener(websocketLn)
// Create http vhost muxer. // Create http vhost muxer.
if cfg.VhostHTTPPort > 0 { if cfg.VhostHTTPPort > 0 {
@ -267,8 +283,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
} else { } else {
l, err = net.Listen("tcp", address) l, err = net.Listen("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("create vhost http listener error, %v", err) return nil, fmt.Errorf("create vhost http listener error, %v", err)
return
} }
} }
go func() { go func() {
@ -286,55 +301,30 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.VhostHTTPSPort)) address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.VhostHTTPSPort))
l, err = net.Listen("tcp", address) l, err = net.Listen("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("create server listener error, %v", err) return nil, fmt.Errorf("create server listener error, %v", err)
return
} }
log.Info("https service listen on %s", address) log.Info("https service listen on %s", address)
} }
svr.rc.VhostHTTPSMuxer, err = vhost.NewHTTPSMuxer(l, vhostReadWriteTimeout) svr.rc.VhostHTTPSMuxer, err = vhost.NewHTTPSMuxer(l, vhostReadWriteTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("create vhost httpsMuxer error, %v", err) return nil, fmt.Errorf("create vhost httpsMuxer error, %v", err)
return
} }
} }
// frp tls listener // frp tls listener
svr.tlsListener = svr.muxer.Listen(2, 1, func(data []byte) bool { svr.tlsListener = svr.muxer.Listen(2, 1, func(data []byte) bool {
// tls first byte can be 0x16 only when vhost https port is not same with bind port // tls first byte can be 0x16 only when vhost https port is not same with bind port
return int(data[0]) == utilnet.FRPTLSHeadByte || int(data[0]) == 0x16 return int(data[0]) == netpkg.FRPTLSHeadByte || int(data[0]) == 0x16
}) })
// Create nat hole controller. // Create nat hole controller.
nc, err := nathole.NewController(time.Duration(cfg.NatHoleAnalysisDataReserveHours) * time.Hour) nc, err := nathole.NewController(time.Duration(cfg.NatHoleAnalysisDataReserveHours) * time.Hour)
if err != nil { if err != nil {
err = fmt.Errorf("create nat hole controller error, %v", err) return nil, fmt.Errorf("create nat hole controller error, %v", err)
return
} }
svr.rc.NatHoleController = nc svr.rc.NatHoleController = nc
return svr, nil
var statsEnable bool
// Create dashboard web server.
if cfg.WebServer.Port > 0 {
// Init dashboard assets
assets.Load(cfg.WebServer.AssetsDir)
address := net.JoinHostPort(cfg.WebServer.Addr, strconv.Itoa(cfg.WebServer.Port))
err = svr.RunDashboardServer(address)
if err != nil {
err = fmt.Errorf("create dashboard web server error, %v", err)
return
}
log.Info("Dashboard listen on %s", address)
statsEnable = true
}
if statsEnable {
modelmetrics.EnableMem()
if cfg.EnablePrometheus {
modelmetrics.EnablePrometheus()
}
}
return
} }
func (svr *Service) Run(ctx context.Context) { func (svr *Service) Run(ctx context.Context) {
@ -342,7 +332,17 @@ func (svr *Service) Run(ctx context.Context) {
svr.ctx = ctx svr.ctx = ctx
svr.cancel = cancel svr.cancel = cancel
go svr.HandleListener(svr.virtualListener, true) // run dashboard web server.
if svr.webServer != nil {
go func() {
log.Info("dashboard listen on %s", svr.webServer.Address())
if err := svr.webServer.Run(); err != nil {
log.Warn("dashboard server exit with error: %v", err)
}
}()
}
go svr.HandleListener(svr.sshTunnelListener, true)
if svr.kcpListener != nil { if svr.kcpListener != nil {
go svr.HandleListener(svr.kcpListener, false) go svr.HandleListener(svr.kcpListener, false)
@ -398,7 +398,7 @@ func (svr *Service) Close() error {
return nil return nil
} }
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) { func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, internal bool) {
xl := xlog.FromContextSafe(ctx) xl := xlog.FromContextSafe(ctx)
var ( var (
@ -424,7 +424,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
retContent, err := svr.pluginManager.Login(content) retContent, err := svr.pluginManager.Login(content)
if err == nil { if err == nil {
m = &retContent.Login m = &retContent.Login
err = svr.RegisterControl(conn, m) err = svr.RegisterControl(conn, m, internal)
} }
// If login failed, send error message there. // If login failed, send error message there.
@ -461,6 +461,9 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
} }
} }
// HandleListener accepts connections from client and call handleConnection to handle them.
// If internal is true, it means that this listener is used for internal communication like ssh tunnel gateway.
// TODO(fatedier): Pass some parameters of listener/connection through context to avoid passing too many parameters.
func (svr *Service) HandleListener(l net.Listener, internal bool) { func (svr *Service) HandleListener(l net.Listener, internal bool) {
// Listen for incoming connections from client. // Listen for incoming connections from client.
for { for {
@ -473,19 +476,21 @@ func (svr *Service) HandleListener(l net.Listener, internal bool) {
xl := xlog.New() xl := xlog.New()
ctx := context.Background() ctx := context.Background()
c = utilnet.NewContextConn(xlog.NewContext(ctx, xl), c) c = netpkg.NewContextConn(xlog.NewContext(ctx, xl), c)
log.Trace("start check TLS connection...") if !internal {
originConn := c log.Trace("start check TLS connection...")
forceTLS := svr.cfg.Transport.TLS.Force && !internal originConn := c
var isTLS, custom bool forceTLS := svr.cfg.Transport.TLS.Force
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, forceTLS, connReadTimeout) var isTLS, custom bool
if err != nil { c, isTLS, custom, err = netpkg.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, forceTLS, connReadTimeout)
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err) if err != nil {
originConn.Close() log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
continue originConn.Close()
continue
}
log.Trace("check TLS connection success, isTLS: %v custom: %v internal: %v", isTLS, custom, internal)
} }
log.Trace("check TLS connection success, isTLS: %v custom: %v", isTLS, custom)
// Start a new goroutine to handle connection. // Start a new goroutine to handle connection.
go func(ctx context.Context, frpConn net.Conn) { go func(ctx context.Context, frpConn net.Conn) {
@ -508,10 +513,10 @@ func (svr *Service) HandleListener(l net.Listener, internal bool) {
session.Close() session.Close()
return return
} }
go svr.handleConnection(ctx, stream) go svr.handleConnection(ctx, stream, internal)
} }
} else { } else {
svr.handleConnection(ctx, frpConn) svr.handleConnection(ctx, frpConn, internal)
} }
}(ctx, c) }(ctx, c)
} }
@ -534,13 +539,13 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) {
_ = frpConn.CloseWithError(0, "") _ = frpConn.CloseWithError(0, "")
return return
} }
go svr.handleConnection(ctx, utilnet.QuicStreamToNetConn(stream, frpConn)) go svr.handleConnection(ctx, netpkg.QuicStreamToNetConn(stream, frpConn), false)
} }
}(context.Background(), c) }(context.Background(), c)
} }
} }
func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) error { func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, internal bool) error {
// If client's RunID is empty, it's a new client, we just create a new controller. // If client's RunID is empty, it's a new client, we just create a new controller.
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one. // Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
var err error var err error
@ -551,7 +556,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) error
} }
} }
ctx := utilnet.NewContextFromConn(ctlConn) ctx := netpkg.NewContextFromConn(ctlConn)
xl := xlog.FromContextSafe(ctx) xl := xlog.FromContextSafe(ctx)
xl.AppendPrefix(loginMsg.RunID) xl.AppendPrefix(loginMsg.RunID)
ctx = xlog.NewContext(ctx, xl) ctx = xlog.NewContext(ctx, xl)
@ -559,11 +564,16 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) error
ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch) ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch)
// Check auth. // Check auth.
if err := svr.authVerifier.VerifyLogin(loginMsg); err != nil { authVerifier := svr.authVerifier
if internal && loginMsg.ClientSpec.AlwaysAuthPass {
authVerifier = auth.AlwaysPassVerifier
}
if err := authVerifier.VerifyLogin(loginMsg); err != nil {
return err return err
} }
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.authVerifier, ctlConn, loginMsg, svr.cfg) // TODO(fatedier): use SessionContext
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, ctlConn, !internal, loginMsg, svr.cfg)
if err != nil { if err != nil {
xl.Warn("create new controller error: %v", err) xl.Warn("create new controller error: %v", err)
// don't return detailed errors to client // don't return detailed errors to client
@ -588,7 +598,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) error
// RegisterWorkConn register a new work connection to control and proxies need it. // RegisterWorkConn register a new work connection to control and proxies need it.
func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error { func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error {
xl := utilnet.NewLogFromConn(workConn) xl := netpkg.NewLogFromConn(workConn)
ctl, exist := svr.ctlManager.GetByID(newMsg.RunID) ctl, exist := svr.ctlManager.GetByID(newMsg.RunID)
if !exist { if !exist {
xl.Warn("No client control found for run id [%s]", newMsg.RunID) xl.Warn("No client control found for run id [%s]", newMsg.RunID)
@ -607,7 +617,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
if err == nil { if err == nil {
newMsg = &retContent.NewWorkConn newMsg = &retContent.NewWorkConn
// Check auth. // Check auth.
err = svr.authVerifier.VerifyNewWorkConn(newMsg) err = ctl.authVerifier.VerifyNewWorkConn(newMsg)
} }
if err != nil { if err != nil {
xl.Warn("invalid NewWorkConn with run id [%s]", newMsg.RunID) xl.Warn("invalid NewWorkConn with run id [%s]", newMsg.RunID)

View File

@ -23,12 +23,12 @@ import (
libio "github.com/fatedier/golib/io" libio "github.com/fatedier/golib/io"
"github.com/samber/lo" "github.com/samber/lo"
utilnet "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
) )
type listenerBundle struct { type listenerBundle struct {
l *utilnet.InternalListener l *netpkg.InternalListener
sk string sk string
allowUsers []string allowUsers []string
} }
@ -46,22 +46,21 @@ func NewManager() *Manager {
} }
} }
func (vm *Manager) Listen(name string, sk string, allowUsers []string) (l *utilnet.InternalListener, err error) { func (vm *Manager) Listen(name string, sk string, allowUsers []string) (*netpkg.InternalListener, error) {
vm.mu.Lock() vm.mu.Lock()
defer vm.mu.Unlock() defer vm.mu.Unlock()
if _, ok := vm.listeners[name]; ok { if _, ok := vm.listeners[name]; ok {
err = fmt.Errorf("custom listener for [%s] is repeated", name) return nil, fmt.Errorf("custom listener for [%s] is repeated", name)
return
} }
l = utilnet.NewInternalListener() l := netpkg.NewInternalListener()
vm.listeners[name] = &listenerBundle{ vm.listeners[name] = &listenerBundle{
l: l, l: l,
sk: sk, sk: sk,
allowUsers: allowUsers, allowUsers: allowUsers,
} }
return return l, nil
} }
func (vm *Manager) NewConn(name string, conn net.Conn, timestamp int64, signKey string, func (vm *Manager) NewConn(name string, conn net.Conn, timestamp int64, signKey string,
@ -91,7 +90,7 @@ func (vm *Manager) NewConn(name string, conn net.Conn, timestamp int64, signKey
if useCompression { if useCompression {
rwc = libio.WithCompression(rwc) rwc = libio.WithCompression(rwc)
} }
err = l.l.PutConn(utilnet.WrapReadWriteCloserToConn(rwc, conn)) err = l.l.PutConn(netpkg.WrapReadWriteCloserToConn(rwc, conn))
} else { } else {
err = fmt.Errorf("custom listener for [%s] doesn't exist", name) err = fmt.Errorf("custom listener for [%s] doesn't exist", name)
return return

View File

@ -8,7 +8,7 @@ import (
"github.com/onsi/ginkgo/v2" "github.com/onsi/ginkgo/v2"
"github.com/fatedier/frp/pkg/util/util" httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/test/e2e/framework" "github.com/fatedier/frp/test/e2e/framework"
"github.com/fatedier/frp/test/e2e/framework/consts" "github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/mock/server/streamserver" "github.com/fatedier/frp/test/e2e/mock/server/streamserver"
@ -176,7 +176,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
connectRequestHost = req.Host connectRequestHost = req.Host
// return ok response // return ok response
res := util.OkResponse() res := httppkg.OkResponse()
if res.Body != nil { if res.Body != nil {
defer res.Body.Close() defer res.Body.Close()
} }

View File

@ -14,7 +14,7 @@ import (
libdial "github.com/fatedier/golib/net/dial" libdial "github.com/fatedier/golib/net/dial"
"github.com/fatedier/frp/pkg/util/util" httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/test/e2e/pkg/rpc" "github.com/fatedier/frp/test/e2e/pkg/rpc"
) )
@ -115,7 +115,7 @@ func (r *Request) HTTPHeaders(headers map[string]string) *Request {
} }
func (r *Request) HTTPAuth(user, password string) *Request { func (r *Request) HTTPAuth(user, password string) *Request {
r.authValue = util.BasicAuth(user, password) r.authValue = httppkg.BasicAuth(user, password)
return r return r
} }

View File

@ -8,7 +8,7 @@ import (
"github.com/onsi/ginkgo/v2" "github.com/onsi/ginkgo/v2"
"github.com/fatedier/frp/pkg/util/util" httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/test/e2e/framework" "github.com/fatedier/frp/test/e2e/framework"
"github.com/fatedier/frp/test/e2e/framework/consts" "github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/mock/server/streamserver" "github.com/fatedier/frp/test/e2e/mock/server/streamserver"
@ -180,7 +180,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
connectRequestHost = req.Host connectRequestHost = req.Host
// return ok response // return ok response
res := util.OkResponse() res := httppkg.OkResponse()
if res.Body != nil { if res.Body != nil {
defer res.Body.Close() defer res.Body.Close()
} }