diff --git a/go.mod b/go.mod index 8d27e52..8d0055e 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/samber/lo v1.38.1 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 + golang.org/x/crypto v0.15.0 golang.org/x/net v0.17.0 golang.org/x/oauth2 v0.10.0 golang.org/x/sync v0.3.0 @@ -64,11 +65,10 @@ require ( github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161 // indirect github.com/templexxx/xor v0.0.0-20191217153810-f85b25db303b // indirect github.com/tjfoc/gmsm v1.4.1 // indirect - golang.org/x/crypto v0.14.0 // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.10.0 // indirect - golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/sys v0.14.0 // indirect + golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.9.3 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.31.0 // indirect diff --git a/go.sum b/go.sum index af509c3..49cef0b 100644 --- a/go.sum +++ b/go.sum @@ -157,8 +157,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= @@ -210,20 +210,21 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= +golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go index e49921e..f562be8 100644 --- a/pkg/config/v1/server.go +++ b/pkg/config/v1/server.go @@ -16,11 +16,21 @@ package v1 import ( "github.com/samber/lo" + "golang.org/x/crypto/ssh" "github.com/fatedier/frp/pkg/config/types" "github.com/fatedier/frp/pkg/util/util" ) +type SSHTunnelGateway struct { + BindPort int `json:"bindPort,omitempty" validate:"gte=0,lte=65535"` + PrivateKeyFilePath string `json:"privateKeyFilePath,omitempty"` + PublicKeyFilesPath string `json:"publicKeyFilesPath,omitempty"` + + // store all public key file. load all when init + PublicKeyFilesMap map[string]ssh.PublicKey +} + type ServerConfig struct { APIMetadata @@ -31,6 +41,9 @@ type ServerConfig struct { // BindPort specifies the port that the server listens on. By default, this // value is 7000. BindPort int `json:"bindPort,omitempty"` + + SSHTunnelGateway SSHTunnelGateway `json:"sshGatewayConfig,omitempty"` + // KCPBindPort specifies the KCP port that the server listens on. If this // value is 0, the server will not listen for KCP connections. KCPBindPort int `json:"kcpBindPort,omitempty"` diff --git a/pkg/config/v1/ssh.go b/pkg/config/v1/ssh.go new file mode 100644 index 0000000..440305d --- /dev/null +++ b/pkg/config/v1/ssh.go @@ -0,0 +1,72 @@ +package v1 + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "os" + "path/filepath" + + "golang.org/x/crypto/ssh" +) + +const ( + // custom define + SSHClientLoginUserPrefix = "_frpc_ssh_client_" +) + +// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format +func GeneratePrivateKey() ([]byte, error) { + privateKey, err := generatePrivateKey() + if err != nil { + return nil, errors.New("gen private key error") + } + + privBlock := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + + return pem.EncodeToMemory(&privBlock), nil +} + +// generatePrivateKey creates a RSA Private Key of specified byte size +func generatePrivateKey() (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, err + } + + err = privateKey.Validate() + if err != nil { + return nil, err + } + return privateKey, nil +} + +func LoadSSHPublicKeyFilesInDir(dirPath string) (map[string]ssh.PublicKey, error) { + fileMap := make(map[string]ssh.PublicKey) + files, err := os.ReadDir(dirPath) + if err != nil { + return nil, err + } + + for _, file := range files { + filePath := filepath.Join(dirPath, file.Name()) + content, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } + + parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(content) + if err != nil { + continue + } + fileMap[ssh.FingerprintSHA256(parsedAuthorizedKey)] = parsedAuthorizedKey + } + + return fileMap, nil +} diff --git a/pkg/ssh/service.go b/pkg/ssh/service.go new file mode 100644 index 0000000..ce0bc52 --- /dev/null +++ b/pkg/ssh/service.go @@ -0,0 +1,497 @@ +package ssh + +import ( + "encoding/binary" + "errors" + "flag" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + gerror "github.com/fatedier/golib/errors" + "golang.org/x/crypto/ssh" + + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/util/log" +) + +const ( + // ssh protocol define + // https://datatracker.ietf.org/doc/html/rfc4254#page-16 + ChannelTypeServerOpenChannel = "forwarded-tcpip" + RequestTypeForward = "tcpip-forward" + + // golang ssh package define. + // https://pkg.go.dev/golang.org/x/crypto/ssh + RequestTypeHeartbeat = "keepalive@openssh.com" +) + +// 当 proxy 失败会返回该错误 +type VProxyError struct{} + +// ssh protocol define +// https://datatracker.ietf.org/doc/html/rfc4254#page-16 +// parse ssh client cmds input +type forwardedTCPPayload struct { + Addr string + Port uint32 + + // can be default empty value but do not delete it + // because ssh protocol shoule be reserved + OriginAddr string + OriginPort uint32 +} + +// custom define +// parse ssh client cmds input +type CmdPayload struct { + Address string + Port uint32 +} + +// custom define +// with frp control cmds +type ExtraPayload struct { + Type string + + // TODO port can be set by extra message and priority to ssh raw cmd + Address string + Port uint32 +} + +type Service struct { + tcpConn net.Conn + cfg *ssh.ServerConfig + + sshConn *ssh.ServerConn + gChannel <-chan ssh.NewChannel + gReq <-chan *ssh.Request + + addrPayloadCh chan CmdPayload + extraPayloadCh chan ExtraPayload + + proxyPayloadCh chan v1.ProxyConfigurer + replyCh chan interface{} + + closeCh chan struct{} + exit int32 +} + +func NewSSHService( + tcpConn net.Conn, + cfg *ssh.ServerConfig, + proxyPayloadCh chan v1.ProxyConfigurer, + replyCh chan interface{}, +) (ss *Service, err error) { + ss = &Service{ + tcpConn: tcpConn, + cfg: cfg, + + addrPayloadCh: make(chan CmdPayload), + extraPayloadCh: make(chan ExtraPayload), + + proxyPayloadCh: proxyPayloadCh, + replyCh: replyCh, + + closeCh: make(chan struct{}), + exit: 0, + } + + ss.sshConn, ss.gChannel, ss.gReq, err = ssh.NewServerConn(tcpConn, cfg) + if err != nil { + log.Error("ssh handshake error: %v", err) + return nil, err + } + + log.Info("ssh connection success") + + return ss, nil +} + +func (ss *Service) Run() { + go ss.loopGenerateProxy() + go ss.loopParseCmdPayload() + go ss.loopParseExtraPayload() + go ss.loopReply() +} + +func (ss *Service) Exit() <-chan struct{} { + return ss.closeCh +} + +func (ss *Service) Close() { + if atomic.LoadInt32(&ss.exit) == 1 { + return + } + + select { + case <-ss.closeCh: + return + default: + } + + close(ss.closeCh) + close(ss.addrPayloadCh) + close(ss.extraPayloadCh) + + _ = ss.sshConn.Wait() + + ss.sshConn.Close() + ss.tcpConn.Close() + + atomic.StoreInt32(&ss.exit, 1) + + log.Info("ssh service close") +} + +func (ss *Service) loopParseCmdPayload() { + for { + select { + case req, ok := <-ss.gReq: + if !ok { + log.Info("global request is close") + ss.Close() + return + } + + switch req.Type { + case RequestTypeForward: + var addrPayload CmdPayload + if err := ssh.Unmarshal(req.Payload, &addrPayload); err != nil { + log.Error("ssh unmarshal error: %v", err) + return + } + _ = gerror.PanicToError(func() { + ss.addrPayloadCh <- addrPayload + }) + default: + if req.Type == RequestTypeHeartbeat { + log.Debug("ssh heartbeat data") + } else { + log.Info("default req, data: %v", req) + } + } + if req.WantReply { + err := req.Reply(true, nil) + if err != nil { + log.Error("reply to ssh client error: %v", err) + } + } + case <-ss.closeCh: + log.Info("loop parse cmd payload close") + return + } + } +} + +func (ss *Service) loopSendHeartbeat(ch ssh.Channel) { + tk := time.NewTicker(time.Second * 60) + defer tk.Stop() + + for { + select { + case <-tk.C: + ok, err := ch.SendRequest("heartbeat", false, nil) + if err != nil { + log.Error("channel send req error: %v", err) + if err == io.EOF { + ss.Close() + return + } + continue + } + log.Debug("heartbeat send success, ok: %v", ok) + case <-ss.closeCh: + return + } + } +} + +func (ss *Service) loopParseExtraPayload() { + log.Info("loop parse extra payload start") + + for newChannel := range ss.gChannel { + ch, req, err := newChannel.Accept() + if err != nil { + log.Error("channel accept error: %v", err) + return + } + + go ss.loopSendHeartbeat(ch) + + go func(req <-chan *ssh.Request) { + for r := range req { + if len(r.Payload) <= 4 { + log.Info("r.payload is less than 4") + continue + } + if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") { + log.Info("ssh protocol exchange data") + continue + } + + // [4byte data_len|data] + end := 4 + binary.BigEndian.Uint32(r.Payload[:4]) + if end > uint32(len(r.Payload)) { + end = uint32(len(r.Payload)) + } + p := string(r.Payload[4:end]) + + msg, err := parseSSHExtraMessage(p) + if err != nil { + log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload) + continue + } + _ = gerror.PanicToError(func() { + ss.extraPayloadCh <- msg + }) + return + } + }(req) + } +} + +func (ss *Service) SSHConn() *ssh.ServerConn { + return ss.sshConn +} + +func (ss *Service) TCPConn() net.Conn { + return ss.tcpConn +} + +func (ss *Service) loopReply() { + for { + select { + case <-ss.closeCh: + log.Info("loop reply close") + return + case req := <-ss.replyCh: + switch req.(type) { + case *VProxyError: + log.Error("run frp proxy error, close ssh service") + ss.Close() + default: + // TODO + } + } + } +} + +func (ss *Service) loopGenerateProxy() { + log.Info("loop generate proxy start") + + for { + if atomic.LoadInt32(&ss.exit) == 1 { + return + } + + wg := new(sync.WaitGroup) + wg.Add(2) + + var p1 CmdPayload + var p2 ExtraPayload + + go func() { + defer wg.Done() + for { + select { + case <-ss.closeCh: + return + case p1 = <-ss.addrPayloadCh: + return + } + } + }() + + go func() { + defer wg.Done() + for { + select { + case <-ss.closeCh: + return + case p2 = <-ss.extraPayloadCh: + return + } + } + }() + + wg.Wait() + + if atomic.LoadInt32(&ss.exit) == 1 { + return + } + + switch p2.Type { + case "http": + case "tcp": + ss.proxyPayloadCh <- &v1.TCPProxyConfig{ + ProxyBaseConfig: v1.ProxyBaseConfig{ + Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()), + Type: p2.Type, + + ProxyBackend: v1.ProxyBackend{ + LocalIP: p1.Address, + }, + }, + RemotePort: int(p1.Port), + } + default: + log.Warn("invalid frp proxy type: %v", p2.Type) + } + } +} + +func parseSSHExtraMessage(s string) (p ExtraPayload, err error) { + sn := len(s) + + log.Info("parse ssh extra message: %v", s) + + ss := strings.Fields(s) + if len(ss) == 0 { + if sn != 0 { + ss = append(ss, s) + } else { + return p, fmt.Errorf("invalid ssh input, args: %v", ss) + } + } + + for i, v := range ss { + ss[i] = strings.TrimSpace(v) + } + + if ss[0] != "tcp" && ss[0] != "http" { + return p, fmt.Errorf("only support tcp/http now") + } + + switch ss[0] { + case "tcp": + tcpCmd, err := ParseTCPCommand(ss) + if err != nil { + return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) + } + + port, _ := strconv.Atoi(tcpCmd.Port) + + p = ExtraPayload{ + Type: "tcp", + Address: tcpCmd.Address, + Port: uint32(port), + } + case "http": + httpCmd, err := ParseHTTPCommand(ss) + if err != nil { + return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) + } + + _ = httpCmd + + p = ExtraPayload{ + Type: "http", + } + } + + return p, nil +} + +type HTTPCommand struct { + Domain string + BasicAuthUser string + BasicAuthPass string +} + +func ParseHTTPCommand(params []string) (*HTTPCommand, error) { + if len(params) < 2 { + return nil, errors.New("invalid HTTP command") + } + + var ( + basicAuth string + domainURL string + basicAuthUser string + basicAuthPass string + ) + + fs := flag.NewFlagSet("http", flag.ContinueOnError) + fs.StringVar(&basicAuth, "basic-auth", "", "") + fs.StringVar(&domainURL, "domain", "", "") + + fs.SetOutput(&nullWriter{}) // Disables usage output + + err := fs.Parse(params[2:]) + if err != nil { + if !errors.Is(err, flag.ErrHelp) { + return nil, err + } + } + + if basicAuth != "" { + authParts := strings.SplitN(basicAuth, ":", 2) + basicAuthUser = authParts[0] + if len(authParts) > 1 { + basicAuthPass = authParts[1] + } + } + + httpCmd := &HTTPCommand{ + Domain: domainURL, + BasicAuthUser: basicAuthUser, + BasicAuthPass: basicAuthPass, + } + return httpCmd, nil +} + +type TCPCommand struct { + Address string + Port string +} + +func ParseTCPCommand(params []string) (*TCPCommand, error) { + if len(params) == 0 || params[0] != "tcp" { + return nil, errors.New("invalid TCP command") + } + + if len(params) == 1 { + return &TCPCommand{}, nil + } + + var ( + address string + port string + ) + + fs := flag.NewFlagSet("tcp", flag.ContinueOnError) + fs.StringVar(&address, "address", "", "The IP address to listen on") + fs.StringVar(&port, "port", "", "The port to listen on") + fs.SetOutput(&nullWriter{}) // Disables usage output + + args := params[1:] + err := fs.Parse(args) + if err != nil { + if !errors.Is(err, flag.ErrHelp) { + return nil, err + } + } + + parsedAddr, err := net.ResolveIPAddr("ip", address) + if err != nil { + return nil, err + } + if _, err := net.LookupPort("tcp", port); err != nil { + return nil, err + } + + tcpCmd := &TCPCommand{ + Address: parsedAddr.String(), + Port: port, + } + return tcpCmd, nil +} + +type nullWriter struct{} + +func (w *nullWriter) Write(p []byte) (n int, err error) { return len(p), nil } diff --git a/pkg/ssh/vclient.go b/pkg/ssh/vclient.go new file mode 100644 index 0000000..e78c828 --- /dev/null +++ b/pkg/ssh/vclient.go @@ -0,0 +1,185 @@ +package ssh + +import ( + "context" + "fmt" + "net" + "sync/atomic" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/fatedier/frp/pkg/config" + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" + plugin "github.com/fatedier/frp/pkg/plugin/server" + "github.com/fatedier/frp/pkg/util/log" + frp_net "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/util" + "github.com/fatedier/frp/pkg/util/xlog" + "github.com/fatedier/frp/server/controller" + "github.com/fatedier/frp/server/proxy" +) + +// VirtualService is a client VirtualService run in frps +type VirtualService struct { + clientCfg v1.ClientCommonConfig + pxyCfg v1.ProxyConfigurer + serverCfg v1.ServerConfig + + sshSvc *Service + + // uniq id got from frps, attach it in loginMsg + runID string + loginMsg *msg.Login + + // All resource managers and controllers + rc *controller.ResourceController + + exit uint32 // 0 means not exit + // SSHService context + ctx context.Context + // call cancel to stop SSHService + cancel context.CancelFunc + + replyCh chan interface{} + pxy proxy.Proxy +} + +func NewVirtualService( + ctx context.Context, + clientCfg v1.ClientCommonConfig, + serverCfg v1.ServerConfig, + logMsg msg.Login, + rc *controller.ResourceController, + pxyCfg v1.ProxyConfigurer, + sshSvc *Service, + replyCh chan interface{}, +) (svr *VirtualService, err error) { + svr = &VirtualService{ + clientCfg: clientCfg, + serverCfg: serverCfg, + rc: rc, + + loginMsg: &logMsg, + + sshSvc: sshSvc, + pxyCfg: pxyCfg, + + ctx: ctx, + exit: 0, + + replyCh: replyCh, + } + + svr.runID, err = util.RandID() + if err != nil { + return nil, err + } + + go svr.loopCheck() + + return +} + +func (svr *VirtualService) Run(ctx context.Context) (err error) { + ctx, cancel := context.WithCancel(ctx) + svr.ctx = xlog.NewContext(ctx, xlog.New()) + svr.cancel = cancel + + remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{ + ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name, + ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type, + RemotePort: svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort, + }) + if err != nil { + return err + } + + log.Info("run a reverse proxy on port: %v", remoteAddr) + + return nil +} + +func (svr *VirtualService) Close() { + svr.GracefulClose(time.Duration(0)) +} + +func (svr *VirtualService) GracefulClose(d time.Duration) { + atomic.StoreUint32(&svr.exit, 1) + svr.pxy.Close() + + if svr.cancel != nil { + svr.cancel() + } + + svr.replyCh <- &VProxyError{} +} + +func (svr *VirtualService) loopCheck() { + <-svr.sshSvc.Exit() + svr.pxy.Close() + log.Info("virtual client service close") +} + +func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { + var pxyConf v1.ProxyConfigurer + pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, &svr.serverCfg) + if err != nil { + return + } + + // User info + userInfo := plugin.UserInfo{ + User: svr.loginMsg.User, + Metas: svr.loginMsg.Metas, + RunID: svr.runID, + } + + svr.pxy, err = proxy.NewProxy(svr.ctx, &proxy.Options{ + LoginMsg: svr.loginMsg, + UserInfo: userInfo, + Configurer: pxyConf, + ResourceController: svr.rc, + + GetWorkConnFn: svr.GetWorkConn, + PoolCount: 10, + + ServerCfg: &svr.serverCfg, + }) + if err != nil { + return remoteAddr, err + } + + remoteAddr, err = svr.pxy.Run() + if err != nil { + log.Warn("proxy run error: %v", err) + return + } + + defer func() { + if err != nil { + log.Warn("proxy close") + svr.pxy.Close() + } + }() + + return +} + +func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) { + // tell ssh client open a new stream for work + payload := forwardedTCPPayload{ + Addr: svr.serverCfg.BindAddr, // TODO refine + Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort), + } + + channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload)) + if err != nil { + return nil, fmt.Errorf("open ssh channel error: %v", err) + } + go ssh.DiscardRequests(reqs) + + workConn = frp_net.WrapReadWriteCloserToConn(channel, svr.sshSvc.tcpConn) + return workConn, nil +} diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index fe6f781..5ea99f1 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -21,6 +21,7 @@ import ( "net" "reflect" "strconv" + "strings" "sync" "time" @@ -229,8 +230,14 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) { return } + var workConn net.Conn + // try all connections from the pool - workConn, err := pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr()) + if strings.HasPrefix(pxy.GetLoginMsg().User, v1.SSHClientLoginUserPrefix) { + workConn, err = pxy.getWorkConnFn() + } else { + workConn, err = pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr()) + } if err != nil { return } diff --git a/server/service.go b/server/service.go index 7478b97..2ca501b 100644 --- a/server/service.go +++ b/server/service.go @@ -18,10 +18,13 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "io" "net" "net/http" + "os" + "reflect" "strconv" "time" @@ -29,6 +32,7 @@ import ( fmux "github.com/hashicorp/yamux" quic "github.com/quic-go/quic-go" "github.com/samber/lo" + "golang.org/x/crypto/ssh" "github.com/fatedier/frp/assets" "github.com/fatedier/frp/pkg/auth" @@ -37,6 +41,7 @@ import ( "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/nathole" plugin "github.com/fatedier/frp/pkg/plugin/server" + frpssh "github.com/fatedier/frp/pkg/ssh" "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/log" utilnet "github.com/fatedier/frp/pkg/util/net" @@ -66,6 +71,10 @@ type Service struct { // Accept connections from client listener net.Listener + // Accept connections using ssh + sshListener net.Listener + sshConfig *ssh.ServerConfig + // Accept connections using kcp kcpListener net.Listener @@ -199,6 +208,67 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { svr.listener = ln log.Info("frps tcp listen on %s", address) + if cfg.SSHTunnelGateway.BindPort > 0 { + + if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" { + cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadSSHPublicKeyFilesInDir(cfg.SSHTunnelGateway.PublicKeyFilesPath) + if err != nil { + return nil, fmt.Errorf("load ssh all public key files error: %v", err) + } + log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath) + } + + svr.sshConfig = &ssh.ServerConfig{ + NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false), + + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + parsedAuthorizedKey, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)] + if !ok { + return nil, errors.New("cannot find public key file") + } + + if key.Type() == parsedAuthorizedKey.Type() && reflect.DeepEqual(parsedAuthorizedKey, key) { + return &ssh.Permissions{ + Extensions: map[string]string{}, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", conn.User()) + }, + } + + var privateBytes []byte + if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" { + privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath) + if err != nil { + log.Error("Failed to load private key") + return nil, err + } + log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath) + } else { + privateBytes, err = v1.GeneratePrivateKey() + if err != nil { + log.Error("Failed to load private key") + return nil, err + } + log.Info("auto gen private key file success") + } + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Error("Failed to parse private key, error: %v", err) + return nil, err + } + + svr.sshConfig.AddHostKey(private) + + sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort)) + svr.sshListener, err = net.Listen("tcp", sshAddr) + if err != nil { + log.Error("Failed to listen on %v, error: %v", sshAddr, err) + return nil, err + } + log.Info("ssh server listening on %v", sshAddr) + } + // Listen for accepting connections from client using kcp protocol. if cfg.KCPBindPort > 0 { address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort)) @@ -326,6 +396,10 @@ func (svr *Service) Run(ctx context.Context) { svr.ctx = ctx svr.cancel = cancel + if svr.sshListener != nil { + go svr.HandleSSHListener(svr.sshListener) + } + if svr.kcpListener != nil { go svr.HandleListener(svr.kcpListener) } @@ -348,6 +422,10 @@ func (svr *Service) Run(ctx context.Context) { } func (svr *Service) Close() error { + if svr.sshListener != nil { + svr.sshListener.Close() + svr.sshListener = nil + } if svr.kcpListener != nil { svr.kcpListener.Close() svr.kcpListener = nil @@ -493,6 +571,52 @@ func (svr *Service) HandleListener(l net.Listener) { } } +func (svr *Service) HandleSSHListener(listener net.Listener) { + for { + tcpConn, err := listener.Accept() + if err != nil { + log.Error("failed to accept incoming ssh connection (%s)", err) + return + } + log.Info("new tcp conn connected: %v", tcpConn.RemoteAddr().String()) + + pxyPayloadCh := make(chan v1.ProxyConfigurer) + replyCh := make(chan interface{}) + + ss, err := frpssh.NewSSHService(tcpConn, svr.sshConfig, pxyPayloadCh, replyCh) + if err != nil { + log.Error("new ssh service error: %v", err) + continue + } + ss.Run() + + go func() { + for { + pxyCfg := <-pxyPayloadCh + + ctx := context.Background() + + // TODO fill client common config and login msg + vs, err := frpssh.NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg, + msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()}, + svr.rc, pxyCfg, ss, replyCh) + if err != nil { + log.Error("new virtual service error: %v", err) + ss.Close() + return + } + + err = vs.Run(ctx) + if err != nil { + log.Error("proxy run error: %v", err) + vs.Close() + return + } + } + }() + } +} + func (svr *Service) HandleQUICListener(l *quic.Listener) { // Listen for incoming connections from client. for {