diff --git a/.gitignore b/.gitignore index f6df315..c9480d5 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,7 @@ lastversion/ dist/ .idea/ .vscode/ +.autogen_ssh_key # Cache *.swp diff --git a/pkg/ssh/server.go b/pkg/ssh/server.go index 042f676..30e79c6 100644 --- a/pkg/ssh/server.go +++ b/pkg/ssh/server.go @@ -56,8 +56,6 @@ type forwardedTCPPayload struct { Addr string Port uint32 - // can be default empty value but do not delete it - // because ssh protocol shoule be reserved OriginAddr string OriginPort uint32 } @@ -117,6 +115,8 @@ func (s *TunnelServer) Run() error { // join workConn and ssh channel c, err := s.openConn(addr) if err != nil { + log.Trace("open conn error: %v", err) + workConn.Close() return false } libio.Join(c, workConn) @@ -180,20 +180,16 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload( go func() { addrGot := false for req := range requests { - switch req.Type { - case RequestTypeForward: - if !addrGot { - payload := tcpipForward{} - if err := ssh.Unmarshal(req.Payload, &payload); err != nil { - return - } - addrGot = true - addrCh <- &payload - } - default: - if req.WantReply { - _ = req.Reply(true, nil) + if req.Type == RequestTypeForward && !addrGot { + payload := tcpipForward{} + if err := ssh.Unmarshal(req.Payload, &payload); err != nil { + return } + addrGot = true + addrCh <- &payload + } + if req.WantReply { + _ = req.Reply(true, nil) } } }() @@ -271,10 +267,10 @@ func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh c go s.keepAlive(ch) for req := range reqs { - if req.Type != "exec" { - continue + if req.WantReply { + _ = req.Reply(true, nil) } - if len(req.Payload) <= 4 { + if req.Type != "exec" || len(req.Payload) <= 4 { continue } end := 4 + binary.BigEndian.Uint32(req.Payload[:4]) @@ -310,6 +306,9 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) { payload := forwardedTCPPayload{ Addr: addr.Host, Port: addr.Port, + // Note: Here is just for compatibility, not the real source address. + OriginAddr: addr.Host, + OriginPort: addr.Port, } channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload)) if err != nil { diff --git a/test/e2e/pkg/ssh/client.go b/test/e2e/pkg/ssh/client.go new file mode 100644 index 0000000..1a923e9 --- /dev/null +++ b/test/e2e/pkg/ssh/client.go @@ -0,0 +1,89 @@ +package ssh + +import ( + "net" + + libio "github.com/fatedier/golib/io" + "golang.org/x/crypto/ssh" +) + +type TunnelClient struct { + localAddr string + sshServer string + commands string + + sshConn *ssh.Client + ln net.Listener +} + +func NewTunnelClient(localAddr string, sshServer string, commands string) *TunnelClient { + return &TunnelClient{ + localAddr: localAddr, + sshServer: sshServer, + commands: commands, + } +} + +func (c *TunnelClient) Start() error { + config := &ssh.ClientConfig{ + User: "v0", + HostKeyCallback: func(string, net.Addr, ssh.PublicKey) error { return nil }, + } + + conn, err := ssh.Dial("tcp", c.sshServer, config) + if err != nil { + return err + } + c.sshConn = conn + + l, err := conn.Listen("tcp", "0.0.0.0:80") + if err != nil { + return err + } + c.ln = l + ch, req, err := conn.OpenChannel("direct", []byte("")) + if err != nil { + return err + } + defer ch.Close() + go ssh.DiscardRequests(req) + + type command struct { + Cmd string + } + _, err = ch.SendRequest("exec", false, ssh.Marshal(command{Cmd: c.commands})) + if err != nil { + return err + } + + go c.serveListener() + return nil +} + +func (c *TunnelClient) Close() { + if c.sshConn != nil { + _ = c.sshConn.Close() + } + if c.ln != nil { + _ = c.ln.Close() + } +} + +func (c *TunnelClient) serveListener() { + for { + conn, err := c.ln.Accept() + if err != nil { + return + } + go c.hanldeConn(conn) + } +} + +func (c *TunnelClient) hanldeConn(conn net.Conn) { + defer conn.Close() + local, err := net.Dial("tcp", c.localAddr) + if err != nil { + return + } + _, _, _ = libio.Join(local, conn) +} diff --git a/test/e2e/v1/features/ssh_tunnel.go b/test/e2e/v1/features/ssh_tunnel.go new file mode 100644 index 0000000..f67d87a --- /dev/null +++ b/test/e2e/v1/features/ssh_tunnel.go @@ -0,0 +1,193 @@ +package features + +import ( + "crypto/tls" + "fmt" + "time" + + "github.com/onsi/ginkgo/v2" + + "github.com/fatedier/frp/pkg/transport" + "github.com/fatedier/frp/test/e2e/framework" + "github.com/fatedier/frp/test/e2e/framework/consts" + "github.com/fatedier/frp/test/e2e/mock/server/httpserver" + "github.com/fatedier/frp/test/e2e/mock/server/streamserver" + "github.com/fatedier/frp/test/e2e/pkg/request" + "github.com/fatedier/frp/test/e2e/pkg/ssh" +) + +var _ = ginkgo.Describe("[Feature: SSH Tunnel]", func() { + f := framework.NewDefaultFramework() + + ginkgo.It("tcp", func() { + sshPort := f.AllocPort() + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` + sshTunnelGateway.bindPort = %d + `, sshPort) + + f.RunProcesses([]string{serverConf}, nil) + + localPort := f.PortByName(framework.TCPEchoServerPort) + remotePort := f.AllocPort() + tc := ssh.NewTunnelClient( + fmt.Sprintf("127.0.0.1:%d", localPort), + fmt.Sprintf("127.0.0.1:%d", sshPort), + fmt.Sprintf("tcp --remote_port %d", remotePort), + ) + framework.ExpectNoError(tc.Start()) + defer tc.Close() + + time.Sleep(time.Second) + framework.NewRequestExpect(f).Port(remotePort).Ensure() + }) + + ginkgo.It("http", func() { + sshPort := f.AllocPort() + vhostPort := f.AllocPort() + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` + vhostHTTPPort = %d + sshTunnelGateway.bindPort = %d + `, vhostPort, sshPort) + + f.RunProcesses([]string{serverConf}, nil) + + localPort := f.PortByName(framework.HTTPSimpleServerPort) + tc := ssh.NewTunnelClient( + fmt.Sprintf("127.0.0.1:%d", localPort), + fmt.Sprintf("127.0.0.1:%d", sshPort), + "http --custom_domain test.example.com", + ) + framework.ExpectNoError(tc.Start()) + defer tc.Close() + + time.Sleep(time.Second) + framework.NewRequestExpect(f).Port(vhostPort). + RequestModify(func(r *request.Request) { + r.HTTP().HTTPHost("test.example.com") + }). + Ensure() + }) + + ginkgo.It("https", func() { + sshPort := f.AllocPort() + vhostPort := f.AllocPort() + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` + vhostHTTPSPort = %d + sshTunnelGateway.bindPort = %d + `, vhostPort, sshPort) + + f.RunProcesses([]string{serverConf}, nil) + + localPort := f.AllocPort() + testDomain := "test.example.com" + tc := ssh.NewTunnelClient( + fmt.Sprintf("127.0.0.1:%d", localPort), + fmt.Sprintf("127.0.0.1:%d", sshPort), + fmt.Sprintf("https --custom_domain %s", testDomain), + ) + framework.ExpectNoError(tc.Start()) + defer tc.Close() + + tlsConfig, err := transport.NewServerTLSConfig("", "", "") + framework.ExpectNoError(err) + localServer := httpserver.New( + httpserver.WithBindPort(localPort), + httpserver.WithTLSConfig(tlsConfig), + httpserver.WithResponse([]byte("test")), + ) + f.RunServer("", localServer) + + time.Sleep(time.Second) + framework.NewRequestExpect(f). + Port(vhostPort). + RequestModify(func(r *request.Request) { + r.HTTPS().HTTPHost(testDomain).TLSConfig(&tls.Config{ + ServerName: testDomain, + InsecureSkipVerify: true, + }) + }). + ExpectResp([]byte("test")). + Ensure() + }) + + ginkgo.It("tcpmux", func() { + sshPort := f.AllocPort() + tcpmuxPort := f.AllocPort() + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` + tcpmuxHTTPConnectPort = %d + sshTunnelGateway.bindPort = %d + `, tcpmuxPort, sshPort) + + f.RunProcesses([]string{serverConf}, nil) + + localPort := f.AllocPort() + testDomain := "test.example.com" + tc := ssh.NewTunnelClient( + fmt.Sprintf("127.0.0.1:%d", localPort), + fmt.Sprintf("127.0.0.1:%d", sshPort), + fmt.Sprintf("tcpmux --mux=httpconnect --custom_domain %s", testDomain), + ) + framework.ExpectNoError(tc.Start()) + defer tc.Close() + + localServer := streamserver.New( + streamserver.TCP, + streamserver.WithBindPort(localPort), + streamserver.WithRespContent([]byte("test")), + ) + f.RunServer("", localServer) + + time.Sleep(time.Second) + // Request without HTTP connect should get error + framework.NewRequestExpect(f). + Port(tcpmuxPort). + ExpectError(true). + Explain("request without HTTP connect expect error"). + Ensure() + + proxyURL := fmt.Sprintf("http://127.0.0.1:%d", tcpmuxPort) + // Request with incorrect connect hostname + framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { + r.Addr("invalid").Proxy(proxyURL) + }).ExpectError(true).Explain("request without HTTP connect expect error").Ensure() + + // Request with correct connect hostname + framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { + r.Addr(testDomain).Proxy(proxyURL) + }).ExpectResp([]byte("test")).Ensure() + }) + + ginkgo.It("stcp", func() { + sshPort := f.AllocPort() + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` + sshTunnelGateway.bindPort = %d + `, sshPort) + + bindPort := f.AllocPort() + visitorConf := consts.DefaultClientConfig + fmt.Sprintf(` + [[visitors]] + name = "stcp-test-visitor" + type = "stcp" + serverName = "stcp-test" + secretKey = "abcdefg" + bindPort = %d + `, bindPort) + + f.RunProcesses([]string{serverConf}, []string{visitorConf}) + + localPort := f.PortByName(framework.TCPEchoServerPort) + tc := ssh.NewTunnelClient( + fmt.Sprintf("127.0.0.1:%d", localPort), + fmt.Sprintf("127.0.0.1:%d", sshPort), + "stcp -n stcp-test --sk=abcdefg --allow_users=\"*\"", + ) + framework.ExpectNoError(tc.Start()) + defer tc.Close() + + time.Sleep(time.Second) + + framework.NewRequestExpect(f). + Port(bindPort). + Ensure() + }) +})