add e2e tests for ssh tunnel (#3805)

This commit is contained in:
fatedier 2023-11-28 13:48:32 +08:00 committed by GitHub
parent 69ae2b0b69
commit 7c799ee921
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 300 additions and 18 deletions

1
.gitignore vendored
View File

@ -33,6 +33,7 @@ lastversion/
dist/
.idea/
.vscode/
.autogen_ssh_key
# Cache
*.swp

View File

@ -56,8 +56,6 @@ type forwardedTCPPayload struct {
Addr string
Port uint32
// can be default empty value but do not delete it
// because ssh protocol shoule be reserved
OriginAddr string
OriginPort uint32
}
@ -117,6 +115,8 @@ func (s *TunnelServer) Run() error {
// join workConn and ssh channel
c, err := s.openConn(addr)
if err != nil {
log.Trace("open conn error: %v", err)
workConn.Close()
return false
}
libio.Join(c, workConn)
@ -180,9 +180,7 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload(
go func() {
addrGot := false
for req := range requests {
switch req.Type {
case RequestTypeForward:
if !addrGot {
if req.Type == RequestTypeForward && !addrGot {
payload := tcpipForward{}
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
return
@ -190,12 +188,10 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload(
addrGot = true
addrCh <- &payload
}
default:
if req.WantReply {
_ = req.Reply(true, nil)
}
}
}
}()
// get extra payload
@ -271,10 +267,10 @@ func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh c
go s.keepAlive(ch)
for req := range reqs {
if req.Type != "exec" {
continue
if req.WantReply {
_ = req.Reply(true, nil)
}
if len(req.Payload) <= 4 {
if req.Type != "exec" || len(req.Payload) <= 4 {
continue
}
end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
@ -310,6 +306,9 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
payload := forwardedTCPPayload{
Addr: addr.Host,
Port: addr.Port,
// Note: Here is just for compatibility, not the real source address.
OriginAddr: addr.Host,
OriginPort: addr.Port,
}
channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
if err != nil {

View File

@ -0,0 +1,89 @@
package ssh
import (
"net"
libio "github.com/fatedier/golib/io"
"golang.org/x/crypto/ssh"
)
type TunnelClient struct {
localAddr string
sshServer string
commands string
sshConn *ssh.Client
ln net.Listener
}
func NewTunnelClient(localAddr string, sshServer string, commands string) *TunnelClient {
return &TunnelClient{
localAddr: localAddr,
sshServer: sshServer,
commands: commands,
}
}
func (c *TunnelClient) Start() error {
config := &ssh.ClientConfig{
User: "v0",
HostKeyCallback: func(string, net.Addr, ssh.PublicKey) error { return nil },
}
conn, err := ssh.Dial("tcp", c.sshServer, config)
if err != nil {
return err
}
c.sshConn = conn
l, err := conn.Listen("tcp", "0.0.0.0:80")
if err != nil {
return err
}
c.ln = l
ch, req, err := conn.OpenChannel("direct", []byte(""))
if err != nil {
return err
}
defer ch.Close()
go ssh.DiscardRequests(req)
type command struct {
Cmd string
}
_, err = ch.SendRequest("exec", false, ssh.Marshal(command{Cmd: c.commands}))
if err != nil {
return err
}
go c.serveListener()
return nil
}
func (c *TunnelClient) Close() {
if c.sshConn != nil {
_ = c.sshConn.Close()
}
if c.ln != nil {
_ = c.ln.Close()
}
}
func (c *TunnelClient) serveListener() {
for {
conn, err := c.ln.Accept()
if err != nil {
return
}
go c.hanldeConn(conn)
}
}
func (c *TunnelClient) hanldeConn(conn net.Conn) {
defer conn.Close()
local, err := net.Dial("tcp", c.localAddr)
if err != nil {
return
}
_, _, _ = libio.Join(local, conn)
}

View File

@ -0,0 +1,193 @@
package features
import (
"crypto/tls"
"fmt"
"time"
"github.com/onsi/ginkgo/v2"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/test/e2e/framework"
"github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/mock/server/httpserver"
"github.com/fatedier/frp/test/e2e/mock/server/streamserver"
"github.com/fatedier/frp/test/e2e/pkg/request"
"github.com/fatedier/frp/test/e2e/pkg/ssh"
)
var _ = ginkgo.Describe("[Feature: SSH Tunnel]", func() {
f := framework.NewDefaultFramework()
ginkgo.It("tcp", func() {
sshPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
sshTunnelGateway.bindPort = %d
`, sshPort)
f.RunProcesses([]string{serverConf}, nil)
localPort := f.PortByName(framework.TCPEchoServerPort)
remotePort := f.AllocPort()
tc := ssh.NewTunnelClient(
fmt.Sprintf("127.0.0.1:%d", localPort),
fmt.Sprintf("127.0.0.1:%d", sshPort),
fmt.Sprintf("tcp --remote_port %d", remotePort),
)
framework.ExpectNoError(tc.Start())
defer tc.Close()
time.Sleep(time.Second)
framework.NewRequestExpect(f).Port(remotePort).Ensure()
})
ginkgo.It("http", func() {
sshPort := f.AllocPort()
vhostPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
vhostHTTPPort = %d
sshTunnelGateway.bindPort = %d
`, vhostPort, sshPort)
f.RunProcesses([]string{serverConf}, nil)
localPort := f.PortByName(framework.HTTPSimpleServerPort)
tc := ssh.NewTunnelClient(
fmt.Sprintf("127.0.0.1:%d", localPort),
fmt.Sprintf("127.0.0.1:%d", sshPort),
"http --custom_domain test.example.com",
)
framework.ExpectNoError(tc.Start())
defer tc.Close()
time.Sleep(time.Second)
framework.NewRequestExpect(f).Port(vhostPort).
RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("test.example.com")
}).
Ensure()
})
ginkgo.It("https", func() {
sshPort := f.AllocPort()
vhostPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
vhostHTTPSPort = %d
sshTunnelGateway.bindPort = %d
`, vhostPort, sshPort)
f.RunProcesses([]string{serverConf}, nil)
localPort := f.AllocPort()
testDomain := "test.example.com"
tc := ssh.NewTunnelClient(
fmt.Sprintf("127.0.0.1:%d", localPort),
fmt.Sprintf("127.0.0.1:%d", sshPort),
fmt.Sprintf("https --custom_domain %s", testDomain),
)
framework.ExpectNoError(tc.Start())
defer tc.Close()
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
framework.ExpectNoError(err)
localServer := httpserver.New(
httpserver.WithBindPort(localPort),
httpserver.WithTLSConfig(tlsConfig),
httpserver.WithResponse([]byte("test")),
)
f.RunServer("", localServer)
time.Sleep(time.Second)
framework.NewRequestExpect(f).
Port(vhostPort).
RequestModify(func(r *request.Request) {
r.HTTPS().HTTPHost(testDomain).TLSConfig(&tls.Config{
ServerName: testDomain,
InsecureSkipVerify: true,
})
}).
ExpectResp([]byte("test")).
Ensure()
})
ginkgo.It("tcpmux", func() {
sshPort := f.AllocPort()
tcpmuxPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
tcpmuxHTTPConnectPort = %d
sshTunnelGateway.bindPort = %d
`, tcpmuxPort, sshPort)
f.RunProcesses([]string{serverConf}, nil)
localPort := f.AllocPort()
testDomain := "test.example.com"
tc := ssh.NewTunnelClient(
fmt.Sprintf("127.0.0.1:%d", localPort),
fmt.Sprintf("127.0.0.1:%d", sshPort),
fmt.Sprintf("tcpmux --mux=httpconnect --custom_domain %s", testDomain),
)
framework.ExpectNoError(tc.Start())
defer tc.Close()
localServer := streamserver.New(
streamserver.TCP,
streamserver.WithBindPort(localPort),
streamserver.WithRespContent([]byte("test")),
)
f.RunServer("", localServer)
time.Sleep(time.Second)
// Request without HTTP connect should get error
framework.NewRequestExpect(f).
Port(tcpmuxPort).
ExpectError(true).
Explain("request without HTTP connect expect error").
Ensure()
proxyURL := fmt.Sprintf("http://127.0.0.1:%d", tcpmuxPort)
// Request with incorrect connect hostname
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.Addr("invalid").Proxy(proxyURL)
}).ExpectError(true).Explain("request without HTTP connect expect error").Ensure()
// Request with correct connect hostname
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.Addr(testDomain).Proxy(proxyURL)
}).ExpectResp([]byte("test")).Ensure()
})
ginkgo.It("stcp", func() {
sshPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
sshTunnelGateway.bindPort = %d
`, sshPort)
bindPort := f.AllocPort()
visitorConf := consts.DefaultClientConfig + fmt.Sprintf(`
[[visitors]]
name = "stcp-test-visitor"
type = "stcp"
serverName = "stcp-test"
secretKey = "abcdefg"
bindPort = %d
`, bindPort)
f.RunProcesses([]string{serverConf}, []string{visitorConf})
localPort := f.PortByName(framework.TCPEchoServerPort)
tc := ssh.NewTunnelClient(
fmt.Sprintf("127.0.0.1:%d", localPort),
fmt.Sprintf("127.0.0.1:%d", sshPort),
"stcp -n stcp-test --sk=abcdefg --allow_users=\"*\"",
)
framework.ExpectNoError(tc.Start())
defer tc.Close()
time.Sleep(time.Second)
framework.NewRequestExpect(f).
Port(bindPort).
Ensure()
})
})