frp/client/proxy/proxy.go
2023-05-28 16:50:43 +08:00

369 lines
8.9 KiB
Go

// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"bytes"
"context"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
frpIo "github.com/fatedier/golib/io"
libdial "github.com/fatedier/golib/net/dial"
pp "github.com/pires/go-proxyproto"
"golang.org/x/time/rate"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/client"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/limit"
"github.com/fatedier/frp/pkg/util/xlog"
)
// Proxy defines how work connections are handled for one proxy type.
// NewProxy picks the implementation from the concrete type of the proxy
// configuration.
type Proxy interface {
// Run prepares the proxy and reports any setup error. In the
// implementations in this file it creates the configured plugin, if any.
Run() error
// InWorkConn accepts a work connection registered to the server, together
// with the StartWorkConn message passed alongside it.
InWorkConn(net.Conn, *msg.StartWorkConn)
// Close shuts the proxy down. In the implementations in this file it
// closes the plugin created by Run, if there is one.
Close()
}
// NewProxy returns the Proxy implementation that matches the concrete type
// of pxyConf, or nil when that type is not one of the cases handled here.
//
// Every returned proxy embeds the same freshly built *BaseProxy, which
// carries clientCfg, msgTransporter, ctx and the logger taken from ctx. A rate
// limiter is attached only when the configured bandwidth limit is positive
// and the limit mode is config.BandwidthLimitModeClient; both its rate
// (bytes per second) and its burst are set to the configured byte count.
func NewProxy(
	ctx context.Context,
	pxyConf config.ProxyConf,
	clientCfg config.ClientCommonConf,
	msgTransporter transport.MessageTransporter,
) (pxy Proxy) {
	var lim *rate.Limiter
	if n := pxyConf.GetBaseInfo().BandwidthLimit.Bytes(); n > 0 &&
		pxyConf.GetBaseInfo().BandwidthLimitMode == config.BandwidthLimitModeClient {
		lim = rate.NewLimiter(rate.Limit(float64(n)), int(n))
	}

	base := &BaseProxy{
		clientCfg:      clientCfg,
		limiter:        lim,
		msgTransporter: msgTransporter,
		xl:             xlog.FromContextSafe(ctx),
		ctx:            ctx,
	}

	switch c := pxyConf.(type) {
	case *config.TCPProxyConf:
		return &TCPProxy{BaseProxy: base, cfg: c}
	case *config.TCPMuxProxyConf:
		return &TCPMuxProxy{BaseProxy: base, cfg: c}
	case *config.UDPProxyConf:
		return &UDPProxy{BaseProxy: base, cfg: c}
	case *config.HTTPProxyConf:
		return &HTTPProxy{BaseProxy: base, cfg: c}
	case *config.HTTPSProxyConf:
		return &HTTPSProxy{BaseProxy: base, cfg: c}
	case *config.STCPProxyConf:
		return &STCPProxy{BaseProxy: base, cfg: c}
	case *config.XTCPProxyConf:
		return &XTCPProxy{BaseProxy: base, cfg: c}
	case *config.SUDPProxyConf:
		return &SUDPProxy{BaseProxy: base, cfg: c, closeCh: make(chan struct{})}
	}
	// Unknown configuration type: callers receive a nil Proxy.
	return nil
}
// BaseProxy holds the state shared by the proxy implementations in this
// package. NewProxy builds one and each concrete proxy type embeds a pointer
// to it.
type BaseProxy struct {
// closed and mu are not referenced in the part of the file visible here.
// NOTE(review): presumably mu guards closed — confirm against their users.
closed bool
// clientCfg is the client's common configuration. The TCP-style proxies in
// this file pass its Token as the encryption key for work connections.
clientCfg config.ClientCommonConf
msgTransporter transport.MessageTransporter
// limiter is nil unless NewProxy found a positive bandwidth limit
// configured in client mode.
limiter *rate.Limiter
mu sync.RWMutex
// xl is the logger NewProxy obtained from ctx via xlog.FromContextSafe.
xl *xlog.Logger
ctx context.Context
}
// TCPProxy is the client-side proxy built by NewProxy for a
// *config.TCPProxyConf. It passes every work connection to
// HandleTCPWorkConnection.
type TCPProxy struct {
	*BaseProxy

	cfg *config.TCPProxyConf
	// proxyPlugin is set by Run and remains nil when no plugin is configured.
	proxyPlugin plugin.Plugin
}

// Run creates the plugin named in the configuration, if any, and returns the
// error reported by plugin.Create. Without a configured plugin it does
// nothing and returns nil.
func (pxy *TCPProxy) Run() (err error) {
	if pxy.cfg.Plugin == "" {
		return nil
	}
	pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
	return err
}

// Close closes the plugin created by Run, if there is one.
func (pxy *TCPProxy) Close() {
	if p := pxy.proxyPlugin; p != nil {
		p.Close()
	}
}

// InWorkConn hands the work connection to HandleTCPWorkConnection along with
// this proxy's local service settings, plugin, base configuration and rate
// limiter. The client token is used as the encryption key.
func (pxy *TCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
	encKey := []byte(pxy.clientCfg.Token)
	HandleTCPWorkConnection(
		pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
		conn, encKey, m,
	)
}
// TCPMuxProxy is the client-side proxy built by NewProxy for a
// *config.TCPMuxProxyConf (TCP multiplexer). It passes every work connection
// to HandleTCPWorkConnection.
type TCPMuxProxy struct {
	*BaseProxy

	cfg *config.TCPMuxProxyConf
	// proxyPlugin is set by Run and remains nil when no plugin is configured.
	proxyPlugin plugin.Plugin
}

// Run creates the plugin named in the configuration, if any, and returns the
// error reported by plugin.Create. Without a configured plugin it does
// nothing and returns nil.
func (pxy *TCPMuxProxy) Run() (err error) {
	if pxy.cfg.Plugin == "" {
		return nil
	}
	pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
	return err
}

// Close closes the plugin created by Run, if there is one.
func (pxy *TCPMuxProxy) Close() {
	if p := pxy.proxyPlugin; p != nil {
		p.Close()
	}
}

// InWorkConn hands the work connection to HandleTCPWorkConnection along with
// this proxy's local service settings, plugin, base configuration and rate
// limiter. The client token is used as the encryption key.
func (pxy *TCPMuxProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
	encKey := []byte(pxy.clientCfg.Token)
	HandleTCPWorkConnection(
		pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
		conn, encKey, m,
	)
}
// HTTPProxy is the client-side proxy built by NewProxy for a
// *config.HTTPProxyConf. It passes every work connection to
// HandleTCPWorkConnection.
type HTTPProxy struct {
	*BaseProxy

	cfg *config.HTTPProxyConf
	// proxyPlugin is set by Run and remains nil when no plugin is configured.
	proxyPlugin plugin.Plugin
}

// Run creates the plugin named in the configuration, if any, and returns the
// error reported by plugin.Create. Without a configured plugin it does
// nothing and returns nil.
func (pxy *HTTPProxy) Run() (err error) {
	if pxy.cfg.Plugin == "" {
		return nil
	}
	pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
	return err
}

// Close closes the plugin created by Run, if there is one.
func (pxy *HTTPProxy) Close() {
	if p := pxy.proxyPlugin; p != nil {
		p.Close()
	}
}

// InWorkConn hands the work connection to HandleTCPWorkConnection along with
// this proxy's local service settings, plugin, base configuration and rate
// limiter. The client token is used as the encryption key.
func (pxy *HTTPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
	encKey := []byte(pxy.clientCfg.Token)
	HandleTCPWorkConnection(
		pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
		conn, encKey, m,
	)
}
// HTTPSProxy is the client-side proxy built by NewProxy for a
// *config.HTTPSProxyConf. It passes every work connection to
// HandleTCPWorkConnection.
type HTTPSProxy struct {
	*BaseProxy

	cfg *config.HTTPSProxyConf
	// proxyPlugin is set by Run and remains nil when no plugin is configured.
	proxyPlugin plugin.Plugin
}

// Run creates the plugin named in the configuration, if any, and returns the
// error reported by plugin.Create. Without a configured plugin it does
// nothing and returns nil.
func (pxy *HTTPSProxy) Run() (err error) {
	if pxy.cfg.Plugin == "" {
		return nil
	}
	pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
	return err
}

// Close closes the plugin created by Run, if there is one.
func (pxy *HTTPSProxy) Close() {
	if p := pxy.proxyPlugin; p != nil {
		p.Close()
	}
}

// InWorkConn hands the work connection to HandleTCPWorkConnection along with
// this proxy's local service settings, plugin, base configuration and rate
// limiter. The client token is used as the encryption key.
func (pxy *HTTPSProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
	encKey := []byte(pxy.clientCfg.Token)
	HandleTCPWorkConnection(
		pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
		conn, encKey, m,
	)
}
// STCPProxy is the client-side proxy built by NewProxy for a
// *config.STCPProxyConf. It passes every work connection to
// HandleTCPWorkConnection.
type STCPProxy struct {
	*BaseProxy

	cfg *config.STCPProxyConf
	// proxyPlugin is set by Run and remains nil when no plugin is configured.
	proxyPlugin plugin.Plugin
}

// Run creates the plugin named in the configuration, if any, and returns the
// error reported by plugin.Create. Without a configured plugin it does
// nothing and returns nil.
func (pxy *STCPProxy) Run() (err error) {
	if pxy.cfg.Plugin == "" {
		return nil
	}
	pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
	return err
}

// Close closes the plugin created by Run, if there is one.
func (pxy *STCPProxy) Close() {
	if p := pxy.proxyPlugin; p != nil {
		p.Close()
	}
}

// InWorkConn hands the work connection to HandleTCPWorkConnection along with
// this proxy's local service settings, plugin, base configuration and rate
// limiter. The client token is used as the encryption key.
func (pxy *STCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
	encKey := []byte(pxy.clientCfg.Token)
	HandleTCPWorkConnection(
		pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
		conn, encKey, m,
	)
}
// HandleTCPWorkConnection is the common handler for TCP work connections.
//
// It wraps workConn with, in this order, the rate limiter (when limiter is
// non-nil), encryption keyed by encKey (when baseInfo.UseEncryption is set)
// and compression (when baseInfo.UseCompression is set). When a proxy
// protocol version is configured and m carries a source address and port, a
// PROXY header is built from the addresses in m. The wrapped connection is
// then either passed to proxyPlugin, or joined with a newly dialed connection
// to the local service described by localInfo after the PROXY header has been
// written to that local connection.
//
// The function returns once the plugin's Handle call or frpIo.Join returns.
// Failures are logged through the logger carried by ctx and are not returned
// to the caller. If m.DstAddr is empty it is set to "127.0.0.1" on m itself.
func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf, proxyPlugin plugin.Plugin,
	baseInfo *config.BaseProxyConf, limiter *rate.Limiter, workConn net.Conn, encKey []byte, m *msg.StartWorkConn,
) {
	xl := xlog.FromContextSafe(ctx)
	var (
		remote io.ReadWriteCloser
		err    error
	)
	remote = workConn
	if limiter != nil {
		remote = frpIo.WrapReadWriteCloser(limit.NewReader(workConn, limiter), limit.NewWriter(workConn, limiter), func() error {
			return workConn.Close()
		})
	}

	xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t",
		baseInfo.UseEncryption, baseInfo.UseCompression)
	if baseInfo.UseEncryption {
		remote, err = frpIo.WithEncryption(remote, encKey)
		if err != nil {
			workConn.Close()
			xl.Error("create encryption stream error: %v", err)
			return
		}
	}
	if baseInfo.UseCompression {
		remote = frpIo.WithCompression(remote)
	}

	// check if we need to send proxy protocol info
	var extraInfo []byte
	if baseInfo.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
		if m.DstAddr == "" {
			m.DstAddr = "127.0.0.1"
		}
		srcAddr, srcErr := net.ResolveTCPAddr("tcp", net.JoinHostPort(m.SrcAddr, strconv.Itoa(int(m.SrcPort))))
		dstAddr, dstErr := net.ResolveTCPAddr("tcp", net.JoinHostPort(m.DstAddr, strconv.Itoa(int(m.DstPort))))

		// A failed resolve yields a nil *net.TCPAddr; putting that into the
		// header risks a nil dereference while it is formatted. In that case
		// the header is skipped and the connection proceeds without it, the
		// same outcome as when no source address is supplied.
		switch {
		case srcErr != nil:
			xl.Error("resolve proxy protocol source address error: %v", srcErr)
		case dstErr != nil:
			xl.Error("resolve proxy protocol destination address error: %v", dstErr)
		default:
			h := &pp.Header{
				Command:         pp.PROXY,
				SourceAddr:      srcAddr,
				DestinationAddr: dstAddr,
			}

			if strings.Contains(m.SrcAddr, ".") {
				h.TransportProtocol = pp.TCPv4
			} else {
				h.TransportProtocol = pp.TCPv6
			}

			switch baseInfo.ProxyProtocolVersion {
			case "v1":
				h.Version = 1
			case "v2":
				h.Version = 2
			}

			buf := bytes.NewBuffer(nil)
			if _, err := h.WriteTo(buf); err != nil {
				xl.Error("build proxy protocol header error: %v", err)
			} else {
				extraInfo = buf.Bytes()
			}
		}
	}

	if proxyPlugin != nil {
		// if plugin is set, let plugin handle connections first
		xl.Debug("handle by plugin: %s", proxyPlugin.Name())
		proxyPlugin.Handle(remote, workConn, extraInfo)
		xl.Debug("handle by plugin finished")
		return
	}

	localConn, err := libdial.Dial(
		net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)),
		libdial.WithTimeout(10*time.Second),
	)
	if err != nil {
		workConn.Close()
		xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)
		return
	}

	xl.Debug("join connections, localConn(l[%s] r[%s]) workConn(l[%s] r[%s])", localConn.LocalAddr().String(),
		localConn.RemoteAddr().String(), workConn.LocalAddr().String(), workConn.RemoteAddr().String())

	if len(extraInfo) > 0 {
		if _, err := localConn.Write(extraInfo); err != nil {
			// Close both ends: localConn was dialed above and would leak if
			// only the work connection were closed here.
			localConn.Close()
			workConn.Close()
			xl.Error("write extraInfo to local conn error: %v", err)
			return
		}
	}

	_, _, errs := frpIo.Join(localConn, remote)
	xl.Debug("join connections closed")
	if len(errs) > 0 {
		xl.Trace("join connections errors: %v", errs)
	}
}