mirror of
https://gitee.com/IrisVega/frp.git
synced 2024-11-01 22:31:29 +08:00
add e2e tests for ssh tunnel (#3805)
This commit is contained in:
parent
69ae2b0b69
commit
7c799ee921
1
.gitignore
vendored
1
.gitignore
vendored
@ -33,6 +33,7 @@ lastversion/
|
|||||||
dist/
|
dist/
|
||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
.autogen_ssh_key
|
||||||
|
|
||||||
# Cache
|
# Cache
|
||||||
*.swp
|
*.swp
|
||||||
|
@ -56,8 +56,6 @@ type forwardedTCPPayload struct {
|
|||||||
Addr string
|
Addr string
|
||||||
Port uint32
|
Port uint32
|
||||||
|
|
||||||
// can be default empty value but do not delete it
|
|
||||||
// because ssh protocol shoule be reserved
|
|
||||||
OriginAddr string
|
OriginAddr string
|
||||||
OriginPort uint32
|
OriginPort uint32
|
||||||
}
|
}
|
||||||
@ -117,6 +115,8 @@ func (s *TunnelServer) Run() error {
|
|||||||
// join workConn and ssh channel
|
// join workConn and ssh channel
|
||||||
c, err := s.openConn(addr)
|
c, err := s.openConn(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Trace("open conn error: %v", err)
|
||||||
|
workConn.Close()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
libio.Join(c, workConn)
|
libio.Join(c, workConn)
|
||||||
@ -180,9 +180,7 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload(
|
|||||||
go func() {
|
go func() {
|
||||||
addrGot := false
|
addrGot := false
|
||||||
for req := range requests {
|
for req := range requests {
|
||||||
switch req.Type {
|
if req.Type == RequestTypeForward && !addrGot {
|
||||||
case RequestTypeForward:
|
|
||||||
if !addrGot {
|
|
||||||
payload := tcpipForward{}
|
payload := tcpipForward{}
|
||||||
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
|
||||||
return
|
return
|
||||||
@ -190,12 +188,10 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload(
|
|||||||
addrGot = true
|
addrGot = true
|
||||||
addrCh <- &payload
|
addrCh <- &payload
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
if req.WantReply {
|
if req.WantReply {
|
||||||
_ = req.Reply(true, nil)
|
_ = req.Reply(true, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// get extra payload
|
// get extra payload
|
||||||
@ -271,10 +267,10 @@ func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh c
|
|||||||
go s.keepAlive(ch)
|
go s.keepAlive(ch)
|
||||||
|
|
||||||
for req := range reqs {
|
for req := range reqs {
|
||||||
if req.Type != "exec" {
|
if req.WantReply {
|
||||||
continue
|
_ = req.Reply(true, nil)
|
||||||
}
|
}
|
||||||
if len(req.Payload) <= 4 {
|
if req.Type != "exec" || len(req.Payload) <= 4 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
|
end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
|
||||||
@ -310,6 +306,9 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
|
|||||||
payload := forwardedTCPPayload{
|
payload := forwardedTCPPayload{
|
||||||
Addr: addr.Host,
|
Addr: addr.Host,
|
||||||
Port: addr.Port,
|
Port: addr.Port,
|
||||||
|
// Note: Here is just for compatibility, not the real source address.
|
||||||
|
OriginAddr: addr.Host,
|
||||||
|
OriginPort: addr.Port,
|
||||||
}
|
}
|
||||||
channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
|
channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
89
test/e2e/pkg/ssh/client.go
Normal file
89
test/e2e/pkg/ssh/client.go
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
libio "github.com/fatedier/golib/io"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TunnelClient struct {
|
||||||
|
localAddr string
|
||||||
|
sshServer string
|
||||||
|
commands string
|
||||||
|
|
||||||
|
sshConn *ssh.Client
|
||||||
|
ln net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTunnelClient(localAddr string, sshServer string, commands string) *TunnelClient {
|
||||||
|
return &TunnelClient{
|
||||||
|
localAddr: localAddr,
|
||||||
|
sshServer: sshServer,
|
||||||
|
commands: commands,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TunnelClient) Start() error {
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: "v0",
|
||||||
|
HostKeyCallback: func(string, net.Addr, ssh.PublicKey) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := ssh.Dial("tcp", c.sshServer, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.sshConn = conn
|
||||||
|
|
||||||
|
l, err := conn.Listen("tcp", "0.0.0.0:80")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.ln = l
|
||||||
|
ch, req, err := conn.OpenChannel("direct", []byte(""))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer ch.Close()
|
||||||
|
go ssh.DiscardRequests(req)
|
||||||
|
|
||||||
|
type command struct {
|
||||||
|
Cmd string
|
||||||
|
}
|
||||||
|
_, err = ch.SendRequest("exec", false, ssh.Marshal(command{Cmd: c.commands}))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go c.serveListener()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TunnelClient) Close() {
|
||||||
|
if c.sshConn != nil {
|
||||||
|
_ = c.sshConn.Close()
|
||||||
|
}
|
||||||
|
if c.ln != nil {
|
||||||
|
_ = c.ln.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TunnelClient) serveListener() {
|
||||||
|
for {
|
||||||
|
conn, err := c.ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go c.hanldeConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TunnelClient) hanldeConn(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
local, err := net.Dial("tcp", c.localAddr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _, _ = libio.Join(local, conn)
|
||||||
|
}
|
193
test/e2e/v1/features/ssh_tunnel.go
Normal file
193
test/e2e/v1/features/ssh_tunnel.go
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
package features
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/onsi/ginkgo/v2"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
|
"github.com/fatedier/frp/test/e2e/framework"
|
||||||
|
"github.com/fatedier/frp/test/e2e/framework/consts"
|
||||||
|
"github.com/fatedier/frp/test/e2e/mock/server/httpserver"
|
||||||
|
"github.com/fatedier/frp/test/e2e/mock/server/streamserver"
|
||||||
|
"github.com/fatedier/frp/test/e2e/pkg/request"
|
||||||
|
"github.com/fatedier/frp/test/e2e/pkg/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = ginkgo.Describe("[Feature: SSH Tunnel]", func() {
|
||||||
|
f := framework.NewDefaultFramework()
|
||||||
|
|
||||||
|
ginkgo.It("tcp", func() {
|
||||||
|
sshPort := f.AllocPort()
|
||||||
|
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||||
|
sshTunnelGateway.bindPort = %d
|
||||||
|
`, sshPort)
|
||||||
|
|
||||||
|
f.RunProcesses([]string{serverConf}, nil)
|
||||||
|
|
||||||
|
localPort := f.PortByName(framework.TCPEchoServerPort)
|
||||||
|
remotePort := f.AllocPort()
|
||||||
|
tc := ssh.NewTunnelClient(
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", localPort),
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", sshPort),
|
||||||
|
fmt.Sprintf("tcp --remote_port %d", remotePort),
|
||||||
|
)
|
||||||
|
framework.ExpectNoError(tc.Start())
|
||||||
|
defer tc.Close()
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||||
|
})
|
||||||
|
|
||||||
|
ginkgo.It("http", func() {
|
||||||
|
sshPort := f.AllocPort()
|
||||||
|
vhostPort := f.AllocPort()
|
||||||
|
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||||
|
vhostHTTPPort = %d
|
||||||
|
sshTunnelGateway.bindPort = %d
|
||||||
|
`, vhostPort, sshPort)
|
||||||
|
|
||||||
|
f.RunProcesses([]string{serverConf}, nil)
|
||||||
|
|
||||||
|
localPort := f.PortByName(framework.HTTPSimpleServerPort)
|
||||||
|
tc := ssh.NewTunnelClient(
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", localPort),
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", sshPort),
|
||||||
|
"http --custom_domain test.example.com",
|
||||||
|
)
|
||||||
|
framework.ExpectNoError(tc.Start())
|
||||||
|
defer tc.Close()
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
framework.NewRequestExpect(f).Port(vhostPort).
|
||||||
|
RequestModify(func(r *request.Request) {
|
||||||
|
r.HTTP().HTTPHost("test.example.com")
|
||||||
|
}).
|
||||||
|
Ensure()
|
||||||
|
})
|
||||||
|
|
||||||
|
ginkgo.It("https", func() {
|
||||||
|
sshPort := f.AllocPort()
|
||||||
|
vhostPort := f.AllocPort()
|
||||||
|
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||||
|
vhostHTTPSPort = %d
|
||||||
|
sshTunnelGateway.bindPort = %d
|
||||||
|
`, vhostPort, sshPort)
|
||||||
|
|
||||||
|
f.RunProcesses([]string{serverConf}, nil)
|
||||||
|
|
||||||
|
localPort := f.AllocPort()
|
||||||
|
testDomain := "test.example.com"
|
||||||
|
tc := ssh.NewTunnelClient(
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", localPort),
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", sshPort),
|
||||||
|
fmt.Sprintf("https --custom_domain %s", testDomain),
|
||||||
|
)
|
||||||
|
framework.ExpectNoError(tc.Start())
|
||||||
|
defer tc.Close()
|
||||||
|
|
||||||
|
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
|
||||||
|
framework.ExpectNoError(err)
|
||||||
|
localServer := httpserver.New(
|
||||||
|
httpserver.WithBindPort(localPort),
|
||||||
|
httpserver.WithTLSConfig(tlsConfig),
|
||||||
|
httpserver.WithResponse([]byte("test")),
|
||||||
|
)
|
||||||
|
f.RunServer("", localServer)
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
framework.NewRequestExpect(f).
|
||||||
|
Port(vhostPort).
|
||||||
|
RequestModify(func(r *request.Request) {
|
||||||
|
r.HTTPS().HTTPHost(testDomain).TLSConfig(&tls.Config{
|
||||||
|
ServerName: testDomain,
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
}).
|
||||||
|
ExpectResp([]byte("test")).
|
||||||
|
Ensure()
|
||||||
|
})
|
||||||
|
|
||||||
|
ginkgo.It("tcpmux", func() {
|
||||||
|
sshPort := f.AllocPort()
|
||||||
|
tcpmuxPort := f.AllocPort()
|
||||||
|
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||||
|
tcpmuxHTTPConnectPort = %d
|
||||||
|
sshTunnelGateway.bindPort = %d
|
||||||
|
`, tcpmuxPort, sshPort)
|
||||||
|
|
||||||
|
f.RunProcesses([]string{serverConf}, nil)
|
||||||
|
|
||||||
|
localPort := f.AllocPort()
|
||||||
|
testDomain := "test.example.com"
|
||||||
|
tc := ssh.NewTunnelClient(
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", localPort),
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", sshPort),
|
||||||
|
fmt.Sprintf("tcpmux --mux=httpconnect --custom_domain %s", testDomain),
|
||||||
|
)
|
||||||
|
framework.ExpectNoError(tc.Start())
|
||||||
|
defer tc.Close()
|
||||||
|
|
||||||
|
localServer := streamserver.New(
|
||||||
|
streamserver.TCP,
|
||||||
|
streamserver.WithBindPort(localPort),
|
||||||
|
streamserver.WithRespContent([]byte("test")),
|
||||||
|
)
|
||||||
|
f.RunServer("", localServer)
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
// Request without HTTP connect should get error
|
||||||
|
framework.NewRequestExpect(f).
|
||||||
|
Port(tcpmuxPort).
|
||||||
|
ExpectError(true).
|
||||||
|
Explain("request without HTTP connect expect error").
|
||||||
|
Ensure()
|
||||||
|
|
||||||
|
proxyURL := fmt.Sprintf("http://127.0.0.1:%d", tcpmuxPort)
|
||||||
|
// Request with incorrect connect hostname
|
||||||
|
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||||
|
r.Addr("invalid").Proxy(proxyURL)
|
||||||
|
}).ExpectError(true).Explain("request without HTTP connect expect error").Ensure()
|
||||||
|
|
||||||
|
// Request with correct connect hostname
|
||||||
|
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||||
|
r.Addr(testDomain).Proxy(proxyURL)
|
||||||
|
}).ExpectResp([]byte("test")).Ensure()
|
||||||
|
})
|
||||||
|
|
||||||
|
ginkgo.It("stcp", func() {
|
||||||
|
sshPort := f.AllocPort()
|
||||||
|
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||||
|
sshTunnelGateway.bindPort = %d
|
||||||
|
`, sshPort)
|
||||||
|
|
||||||
|
bindPort := f.AllocPort()
|
||||||
|
visitorConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||||
|
[[visitors]]
|
||||||
|
name = "stcp-test-visitor"
|
||||||
|
type = "stcp"
|
||||||
|
serverName = "stcp-test"
|
||||||
|
secretKey = "abcdefg"
|
||||||
|
bindPort = %d
|
||||||
|
`, bindPort)
|
||||||
|
|
||||||
|
f.RunProcesses([]string{serverConf}, []string{visitorConf})
|
||||||
|
|
||||||
|
localPort := f.PortByName(framework.TCPEchoServerPort)
|
||||||
|
tc := ssh.NewTunnelClient(
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", localPort),
|
||||||
|
fmt.Sprintf("127.0.0.1:%d", sshPort),
|
||||||
|
"stcp -n stcp-test --sk=abcdefg --allow_users=\"*\"",
|
||||||
|
)
|
||||||
|
framework.ExpectNoError(tc.Start())
|
||||||
|
defer tc.Close()
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
framework.NewRequestExpect(f).
|
||||||
|
Port(bindPort).
|
||||||
|
Ensure()
|
||||||
|
})
|
||||||
|
})
|
Loading…
Reference in New Issue
Block a user