frp/pkg/ssh/server.go

386 lines
9.9 KiB
Go
Raw Normal View History

2023-11-21 11:19:35 +08:00
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ssh
import (
"context"
"encoding/binary"
2023-11-27 15:47:49 +08:00
"errors"
2023-11-21 11:19:35 +08:00
"fmt"
"net"
"strings"
2023-11-27 15:47:49 +08:00
"sync"
2023-11-21 11:19:35 +08:00
"time"
libio "github.com/fatedier/golib/io"
"github.com/samber/lo"
"github.com/spf13/cobra"
flag "github.com/spf13/pflag"
2023-11-21 11:19:35 +08:00
"golang.org/x/crypto/ssh"
2023-11-27 15:47:49 +08:00
"github.com/fatedier/frp/client/proxy"
2023-11-21 11:19:35 +08:00
"github.com/fatedier/frp/pkg/config"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
2023-11-27 15:47:49 +08:00
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
2023-11-21 11:19:35 +08:00
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/virtual"
)
// ChannelTypeServerOpenChannel is the channel type the server opens back to
// the SSH client for each forwarded connection.
// See https://datatracker.ietf.org/doc/html/rfc4254#page-16
const ChannelTypeServerOpenChannel = "forwarded-tcpip"

// RequestTypeForward is the global request type a client sends to ask the
// server to forward an address on its behalf.
const RequestTypeForward = "tcpip-forward"
// tcpipForward is the payload of the client's "tcpip-forward" global request.
// It is decoded with ssh.Unmarshal in waitForwardAddrAndExtraPayload, which
// fills fields in declaration order, so the field order is the wire order and
// must not be changed.
type tcpipForward struct {
	// Host is the address the client asked to forward; openConn echoes it back
	// when opening "forwarded-tcpip" channels.
	Host string
	// Port is the port the client asked to forward; openConn echoes it back too.
	Port uint32
}
// forwardedTCPPayload is the extra data attached when the server opens a
// "forwarded-tcpip" channel back to the client (see openConn). It is encoded
// with ssh.Marshal, so the field order is the wire order and must not change.
// See https://datatracker.ietf.org/doc/html/rfc4254#page-16
type forwardedTCPPayload struct {
	// Addr and Port identify the forwarded address (copied from tcpipForward).
	Addr string
	Port uint32
	// OriginAddr and OriginPort are required by the message format; openConn
	// fills them with the forward address, not the real source address.
	OriginAddr string
	OriginPort uint32
}
type TunnelServer struct {
underlyingConn net.Conn
sshConn *ssh.ServerConn
sc *ssh.ServerConfig
firstChannel ssh.Channel
2023-11-21 11:19:35 +08:00
vc *virtual.Client
2023-11-27 15:47:49 +08:00
peerServerListener *netpkg.InternalListener
2023-11-21 11:19:35 +08:00
doneCh chan struct{}
2023-11-27 15:47:49 +08:00
closeDoneChOnce sync.Once
2023-11-21 11:19:35 +08:00
}
2023-11-27 15:47:49 +08:00
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, peerServerListener *netpkg.InternalListener) (*TunnelServer, error) {
2023-11-21 11:19:35 +08:00
s := &TunnelServer{
underlyingConn: conn,
sc: sc,
2023-11-27 15:47:49 +08:00
peerServerListener: peerServerListener,
2023-11-21 11:19:35 +08:00
doneCh: make(chan struct{}),
}
return s, nil
}
func (s *TunnelServer) Run() error {
sshConn, channels, requests, err := ssh.NewServerConn(s.underlyingConn, s.sc)
if err != nil {
return err
}
2023-11-21 11:19:35 +08:00
s.sshConn = sshConn
addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second)
if err != nil {
return err
}
clientCfg, pc, helpMessage, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
2023-11-21 11:19:35 +08:00
if err != nil {
if errors.Is(err, flag.ErrHelp) {
s.writeToClient(helpMessage)
return nil
}
s.writeToClient(err.Error())
return fmt.Errorf("parse flags from ssh client error: %v", err)
2023-11-21 11:19:35 +08:00
}
2023-11-27 15:47:49 +08:00
clientCfg.Complete()
if sshConn.Permissions != nil {
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
}
2023-11-21 11:19:35 +08:00
pc.Complete(clientCfg.User)
2023-11-27 15:47:49 +08:00
vc, err := virtual.NewClient(virtual.ClientOptions{
Common: clientCfg,
Spec: &msg.ClientSpec{
Type: "ssh-tunnel",
// If ssh does not require authentication, then the virtual client needs to authenticate through a token.
// Otherwise, once ssh authentication is passed, the virtual client does not need to authenticate again.
AlwaysAuthPass: !s.sc.NoClientAuth,
},
HandleWorkConnCb: func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
// join workConn and ssh channel
c, err := s.openConn(addr)
if err != nil {
2023-11-28 13:48:32 +08:00
log.Trace("open conn error: %v", err)
workConn.Close()
2023-11-27 15:47:49 +08:00
return false
}
libio.Join(c, workConn)
2023-11-21 11:19:35 +08:00
return false
2023-11-27 15:47:49 +08:00
},
2023-11-21 11:19:35 +08:00
})
2023-11-27 15:47:49 +08:00
if err != nil {
return err
}
s.vc = vc
2023-11-21 11:19:35 +08:00
// transfer connection from virtual client to server peer listener
go func() {
l := s.vc.PeerListener()
for {
conn, err := l.Accept()
if err != nil {
return
}
2023-11-27 15:47:49 +08:00
_ = s.peerServerListener.PutConn(conn)
2023-11-21 11:19:35 +08:00
}
}()
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
ctx := xlog.NewContext(context.Background(), xl)
go func() {
vcErr := s.vc.Run(ctx)
if vcErr != nil {
s.writeToClient(vcErr.Error())
}
2023-11-27 15:47:49 +08:00
// If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed.
// One scenario is that the virtual client exits due to login failure.
s.closeDoneChOnce.Do(func() {
_ = sshConn.Close()
close(s.doneCh)
})
2023-11-21 11:19:35 +08:00
}()
s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})
if ps, err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
s.writeToClient(err.Error())
2023-11-27 15:47:49 +08:00
log.Warn("wait proxy status ready error: %v", err)
} else {
// success
s.writeToClient(createSuccessInfo(clientCfg.User, pc, ps))
2023-11-27 15:47:49 +08:00
_ = sshConn.Wait()
}
2023-11-21 11:19:35 +08:00
s.vc.Close()
2023-11-27 15:47:49 +08:00
log.Trace("ssh tunnel connection from %v closed", sshConn.RemoteAddr())
s.closeDoneChOnce.Do(func() {
_ = sshConn.Close()
close(s.doneCh)
})
2023-11-21 11:19:35 +08:00
return nil
}
// writeToClient sends data, terminated by a newline, on the client's first
// session channel. It does nothing until such a channel has been accepted, and
// write errors are deliberately ignored.
func (s *TunnelServer) writeToClient(data string) {
	ch := s.firstChannel
	if ch == nil {
		return
	}
	line := append([]byte(data), '\n')
	_, _ = ch.Write(line)
}
// waitForwardAddrAndExtraPayload collects what a tunnel needs from the client:
// the address of its "tcpip-forward" global request and the command string of
// an "exec" request on a session channel (the extra payload). It returns once
// both are known, or an error if they do not both arrive within timeout.
//
// Global requests and new channels are drained by background goroutines for
// the lifetime of the connection, because golang.org/x/crypto/ssh requires
// both streams to be serviced or the connection stalls. Only the first
// well-formed forward request is used; every request wanting a reply gets one.
func (s *TunnelServer) waitForwardAddrAndExtraPayload(
	channels <-chan ssh.NewChannel,
	requests <-chan *ssh.Request,
	timeout time.Duration,
) (*tcpipForward, string, error) {
	addrCh := make(chan *tcpipForward, 1)
	extraPayloadCh := make(chan string, 1)

	// Service global requests and capture the forward address.
	go func() {
		addrGot := false
		for req := range requests {
			accepted := true
			if req.Type == RequestTypeForward && !addrGot {
				payload := tcpipForward{}
				if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
					// Reject the malformed request but keep draining; returning
					// here would leave the request channel unserviced and could
					// hang the connection.
					accepted = false
				} else {
					addrGot = true
					addrCh <- &payload
				}
			}
			if req.WantReply {
				_ = req.Reply(accepted, nil)
			}
		}
	}()

	// Accept channels; the exec command of a session is sent to extraPayloadCh.
	go func() {
		for newChannel := range channels {
			go s.handleNewChannel(newChannel, extraPayloadCh)
		}
	}()

	var (
		addr         *tcpipForward
		extraPayload string
	)
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for addr == nil || extraPayload == "" {
		select {
		case addr = <-addrCh:
		case extraPayload = <-extraPayloadCh:
		case <-timer.C:
			return nil, "", fmt.Errorf("get addr and extra payload timeout")
		}
	}
	return addr, extraPayload, nil
}
func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, string, error) {
helpMessage := ""
cmd := &cobra.Command{
Use: "ssh v0@{address} [command]",
Short: "ssh v0@{address} [command]",
Run: func(*cobra.Command, []string) {},
}
cmd.SetGlobalNormalizationFunc(config.WordSepNormalizeFunc)
2023-11-21 11:19:35 +08:00
args := strings.Split(extraPayload, " ")
if len(args) < 1 {
return nil, nil, helpMessage, fmt.Errorf("invalid extra payload")
2023-11-21 11:19:35 +08:00
}
proxyType := strings.TrimSpace(args[0])
supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"}
if !lo.Contains(supportTypes, proxyType) {
return nil, nil, helpMessage, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
2023-11-21 11:19:35 +08:00
}
pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType))
if pc == nil {
return nil, nil, helpMessage, fmt.Errorf("new proxy configurer error")
2023-11-21 11:19:35 +08:00
}
config.RegisterProxyFlags(cmd, pc, config.WithSSHMode())
2023-11-21 11:19:35 +08:00
clientCfg := v1.ClientCommonConfig{}
config.RegisterClientCommonConfigFlags(cmd, &clientCfg, config.WithSSHMode())
2023-11-21 11:19:35 +08:00
cmd.InitDefaultHelpCmd()
2023-11-21 11:19:35 +08:00
if err := cmd.ParseFlags(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
helpMessage = cmd.UsageString()
}
return nil, nil, helpMessage, err
2023-11-21 11:19:35 +08:00
}
2023-11-27 15:47:49 +08:00
// if name is not set, generate a random one
if pc.GetBaseConfig().Name == "" {
id, err := util.RandIDWithLen(8)
if err != nil {
return nil, nil, helpMessage, fmt.Errorf("generate random id error: %v", err)
2023-11-27 15:47:49 +08:00
}
pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id)
}
return &clientCfg, pc, helpMessage, nil
2023-11-21 11:19:35 +08:00
}
func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) {
ch, reqs, err := channel.Accept()
if err != nil {
return
}
if s.firstChannel == nil {
s.firstChannel = ch
}
2023-11-21 11:19:35 +08:00
go s.keepAlive(ch)
for req := range reqs {
2023-11-28 13:48:32 +08:00
if req.WantReply {
_ = req.Reply(true, nil)
2023-11-21 11:19:35 +08:00
}
2023-11-28 13:48:32 +08:00
if req.Type != "exec" || len(req.Payload) <= 4 {
2023-11-21 11:19:35 +08:00
continue
}
end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
if len(req.Payload) < int(end) {
continue
}
extraPayload := string(req.Payload[4:end])
select {
case extraPayloadCh <- extraPayload:
default:
}
}
}
func (s *TunnelServer) keepAlive(ch ssh.Channel) {
tk := time.NewTicker(time.Second * 30)
defer tk.Stop()
for {
select {
case <-tk.C:
_, err := ch.SendRequest("heartbeat", false, nil)
if err != nil {
return
}
case <-s.doneCh:
return
}
}
}
func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
payload := forwardedTCPPayload{
Addr: addr.Host,
Port: addr.Port,
2023-11-28 13:48:32 +08:00
// Note: Here is just for compatibility, not the real source address.
OriginAddr: addr.Host,
OriginPort: addr.Port,
2023-11-21 11:19:35 +08:00
}
channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
if err != nil {
return nil, fmt.Errorf("open ssh channel error: %v", err)
}
go ssh.DiscardRequests(reqs)
2023-11-27 15:47:49 +08:00
conn := netpkg.WrapReadWriteCloserToConn(channel, s.underlyingConn)
2023-11-21 11:19:35 +08:00
return conn, nil
}
2023-11-27 15:47:49 +08:00
func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) (*proxy.WorkingStatus, error) {
2023-11-27 15:47:49 +08:00
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case <-ticker.C:
ps, err := s.vc.Service().GetProxyStatus(name)
if err != nil {
continue
}
switch ps.Phase {
case proxy.ProxyPhaseRunning:
return ps, nil
2023-11-27 15:47:49 +08:00
case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed:
return ps, errors.New(ps.Err)
2023-11-27 15:47:49 +08:00
}
case <-timer.C:
return nil, fmt.Errorf("wait proxy status ready timeout")
2023-11-27 15:47:49 +08:00
case <-s.doneCh:
return nil, fmt.Errorf("ssh tunnel server closed")
2023-11-27 15:47:49 +08:00
}
}
}