Merge pull request #1121 from fatedier/new

new feature
This commit is contained in:
fatedier 2019-03-11 16:05:18 +08:00 committed by GitHub
commit 8b216b0ca9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 494 additions and 112 deletions

View File

@ -2,8 +2,8 @@ sudo: false
language: go language: go
go: go:
- 1.10.x
- 1.11.x - 1.11.x
- 1.12.x
install: install:
- make - make

View File

@ -15,6 +15,7 @@
package client package client
import ( import (
"crypto/tls"
"fmt" "fmt"
"io" "io"
"runtime/debug" "runtime/debug"
@ -166,8 +167,14 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) {
} }
conn = frpNet.WrapConn(stream) conn = frpNet.WrapConn(stream)
} else { } else {
conn, err = frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, var tlsConfig *tls.Config
fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort)) if g.GlbClientCfg.TLSEnable {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
conn, err = frpNet.ConnectServerByProxyWithTLS(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort), tlsConfig)
if err != nil { if err != nil {
ctl.Warn("start new connection to server error: %v", err) ctl.Warn("start new connection to server error: %v", err)
return return

View File

@ -20,6 +20,8 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"strconv"
"strings"
"sync" "sync"
"time" "time"
@ -280,25 +282,56 @@ func (pxy *XtcpProxy) InWorkConn(conn frpNet.Conn) {
return return
} }
pxy.Trace("get natHoleRespMsg, sid [%s], client address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr) pxy.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
// Send sid to visitor udp address. // Send detect message
time.Sleep(time.Second) array := strings.Split(natHoleRespMsg.VisitorAddr, ":")
if len(array) <= 1 {
pxy.Error("get NatHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
}
laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String()) laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String())
daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.VisitorAddr) /*
for i := 1000; i < 65000; i++ {
pxy.sendDetectMsg(array[0], int64(i), laddr, "a")
}
*/
port, err := strconv.ParseInt(array[1], 10, 64)
if err != nil { if err != nil {
pxy.Error("resolve visitor udp address error: %v", err) pxy.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
return return
} }
pxy.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid))
pxy.Trace("send all detect msg done")
lConn, err := net.DialUDP("udp", laddr, daddr) msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{})
// Listen for clientConn's address and wait for visitor connection
lConn, err := net.ListenUDP("udp", laddr)
if err != nil { if err != nil {
pxy.Error("dial visitor udp address error: %v", err) pxy.Error("listen on visitorConn's local adress error: %v", err)
return return
} }
lConn.Write([]byte(natHoleRespMsg.Sid)) defer lConn.Close()
kcpConn, err := frpNet.NewKcpConnFromUdp(lConn, true, natHoleRespMsg.VisitorAddr) lConn.SetReadDeadline(time.Now().Add(8 * time.Second))
sidBuf := pool.GetBuf(1024)
var uAddr *net.UDPAddr
n, uAddr, err = lConn.ReadFromUDP(sidBuf)
if err != nil {
pxy.Warn("get sid from visitor error: %v", err)
return
}
lConn.SetReadDeadline(time.Time{})
if string(sidBuf[:n]) != natHoleRespMsg.Sid {
pxy.Warn("incorrect sid from visitor")
return
}
pool.PutBuf(sidBuf)
pxy.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid)
lConn.WriteToUDP(sidBuf[:n], uAddr)
kcpConn, err := frpNet.NewKcpConnFromUdp(lConn, false, natHoleRespMsg.VisitorAddr)
if err != nil { if err != nil {
pxy.Error("create kcp connection from udp connection error: %v", err) pxy.Error("create kcp connection from udp connection error: %v", err)
return return
@ -323,6 +356,25 @@ func (pxy *XtcpProxy) InWorkConn(conn frpNet.Conn) {
frpNet.WrapConn(muxConn), []byte(pxy.cfg.Sk)) frpNet.WrapConn(muxConn), []byte(pxy.cfg.Sk))
} }
func (pxy *XtcpProxy) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) {
daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port))
if err != nil {
return err
}
tConn, err := net.DialUDP("udp", laddr, daddr)
if err != nil {
return err
}
//uConn := ipv4.NewConn(tConn)
//uConn.SetTTL(3)
tConn.Write(content)
tConn.Close()
return nil
}
// UDP // UDP
type UdpProxy struct { type UdpProxy struct {
*BaseProxy *BaseProxy

View File

@ -15,6 +15,7 @@
package client package client
import ( import (
"crypto/tls"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"runtime" "runtime"
@ -151,8 +152,14 @@ func (svr *Service) keepControllerWorking() {
// conn: control connection // conn: control connection
// session: if it's not nil, using tcp mux // session: if it's not nil, using tcp mux
func (svr *Service) login() (conn frpNet.Conn, session *fmux.Session, err error) { func (svr *Service) login() (conn frpNet.Conn, session *fmux.Session, err error) {
conn, err = frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, var tlsConfig *tls.Config
fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort)) if g.GlbClientCfg.TLSEnable {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
conn, err = frpNet.ConnectServerByProxyWithTLS(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort), tlsConfig)
if err != nil { if err != nil {
return return
} }

View File

@ -20,13 +20,9 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"strconv"
"strings"
"sync" "sync"
"time" "time"
"golang.org/x/net/ipv4"
"github.com/fatedier/frp/g" "github.com/fatedier/frp/g"
"github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/config"
"github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/models/msg"
@ -251,42 +247,31 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
return return
} }
sv.Trace("get natHoleRespMsg, sid [%s], client address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr) sv.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
// Close visitorConn, so we can use it's local address. // Close visitorConn, so we can use it's local address.
visitorConn.Close() visitorConn.Close()
// Send detect message. // send sid message to client
array := strings.Split(natHoleRespMsg.ClientAddr, ":")
if len(array) <= 1 {
sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr)
return
}
laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String()) laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String())
/* daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr)
for i := 1000; i < 65000; i++ {
sv.sendDetectMsg(array[0], int64(i), laddr, "a")
}
*/
port, err := strconv.ParseInt(array[1], 10, 64)
if err != nil { if err != nil {
sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr) sv.Error("resolve client udp address error: %v", err)
return return
} }
sv.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid)) lConn, err := net.DialUDP("udp", laddr, daddr)
sv.Trace("send all detect msg done")
// Listen for visitorConn's address and wait for client connection.
lConn, err := net.ListenUDP("udp", laddr)
if err != nil { if err != nil {
sv.Error("listen on visitorConn's local adress error: %v", err) sv.Error("dial client udp address error: %v", err)
return return
} }
defer lConn.Close() defer lConn.Close()
lConn.SetReadDeadline(time.Now().Add(5 * time.Second)) lConn.Write([]byte(natHoleRespMsg.Sid))
// read ack sid from client
sidBuf := pool.GetBuf(1024) sidBuf := pool.GetBuf(1024)
n, _, err = lConn.ReadFromUDP(sidBuf) lConn.SetReadDeadline(time.Now().Add(8 * time.Second))
n, err = lConn.Read(sidBuf)
if err != nil { if err != nil {
sv.Warn("get sid from client error: %v", err) sv.Warn("get sid from client error: %v", err)
return return
@ -296,11 +281,13 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
sv.Warn("incorrect sid from client") sv.Warn("incorrect sid from client")
return return
} }
sv.Info("nat hole connection make success, sid [%s]", string(sidBuf[:n]))
pool.PutBuf(sidBuf) pool.PutBuf(sidBuf)
sv.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid)
// wrap kcp connection
var remote io.ReadWriteCloser var remote io.ReadWriteCloser
remote, err = frpNet.NewKcpConnFromUdp(lConn, false, natHoleRespMsg.ClientAddr) remote, err = frpNet.NewKcpConnFromUdp(lConn, true, natHoleRespMsg.ClientAddr)
if err != nil { if err != nil {
sv.Error("create kcp connection from udp connection error: %v", err) sv.Error("create kcp connection from udp connection error: %v", err)
return return
@ -336,22 +323,3 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
frpIo.Join(userConn, muxConn) frpIo.Join(userConn, muxConn)
sv.Debug("join connections closed") sv.Debug("join connections closed")
} }
func (sv *XtcpVisitor) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) {
daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port))
if err != nil {
return err
}
tConn, err := net.DialUDP("udp", laddr, daddr)
if err != nil {
return err
}
uConn := ipv4.NewConn(tConn)
uConn.SetTTL(3)
tConn.Write(content)
tConn.Close()
return nil
}

View File

@ -44,6 +44,9 @@ login_fail_exit = true
# now it supports tcp and kcp and websocket, default is tcp # now it supports tcp and kcp and websocket, default is tcp
protocol = tcp protocol = tcp
# if tls_enable is true, frpc will connect frps by tls
tls_enable = true
# specify a dns server, so frpc will use this instead of default one # specify a dns server, so frpc will use this instead of default one
# dns_server = 8.8.8.8 # dns_server = 8.8.8.8

View File

@ -44,6 +44,7 @@ type ClientCommonConf struct {
LoginFailExit bool `json:"login_fail_exit"` LoginFailExit bool `json:"login_fail_exit"`
Start map[string]struct{} `json:"start"` Start map[string]struct{} `json:"start"`
Protocol string `json:"protocol"` Protocol string `json:"protocol"`
TLSEnable bool `json:"tls_enable"`
HeartBeatInterval int64 `json:"heartbeat_interval"` HeartBeatInterval int64 `json:"heartbeat_interval"`
HeartBeatTimeout int64 `json:"heartbeat_timeout"` HeartBeatTimeout int64 `json:"heartbeat_timeout"`
} }
@ -69,6 +70,7 @@ func GetDefaultClientConf() *ClientCommonConf {
LoginFailExit: true, LoginFailExit: true,
Start: make(map[string]struct{}), Start: make(map[string]struct{}),
Protocol: "tcp", Protocol: "tcp",
TLSEnable: false,
HeartBeatInterval: 30, HeartBeatInterval: 30,
HeartBeatTimeout: 90, HeartBeatTimeout: 90,
} }
@ -194,6 +196,12 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c
cfg.Protocol = tmpStr cfg.Protocol = tmpStr
} }
if tmpStr, ok = conf.Get("common", "tls_enable"); ok && tmpStr == "true" {
cfg.TLSEnable = true
} else {
cfg.TLSEnable = false
}
if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok { if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok {
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout") err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout")

View File

@ -17,44 +17,46 @@ package msg
import "net" import "net"
const ( const (
TypeLogin = 'o' TypeLogin = 'o'
TypeLoginResp = '1' TypeLoginResp = '1'
TypeNewProxy = 'p' TypeNewProxy = 'p'
TypeNewProxyResp = '2' TypeNewProxyResp = '2'
TypeCloseProxy = 'c' TypeCloseProxy = 'c'
TypeNewWorkConn = 'w' TypeNewWorkConn = 'w'
TypeReqWorkConn = 'r' TypeReqWorkConn = 'r'
TypeStartWorkConn = 's' TypeStartWorkConn = 's'
TypeNewVisitorConn = 'v' TypeNewVisitorConn = 'v'
TypeNewVisitorConnResp = '3' TypeNewVisitorConnResp = '3'
TypePing = 'h' TypePing = 'h'
TypePong = '4' TypePong = '4'
TypeUdpPacket = 'u' TypeUdpPacket = 'u'
TypeNatHoleVisitor = 'i' TypeNatHoleVisitor = 'i'
TypeNatHoleClient = 'n' TypeNatHoleClient = 'n'
TypeNatHoleResp = 'm' TypeNatHoleResp = 'm'
TypeNatHoleSid = '5' TypeNatHoleClientDetectOK = 'd'
TypeNatHoleSid = '5'
) )
var ( var (
msgTypeMap = map[byte]interface{}{ msgTypeMap = map[byte]interface{}{
TypeLogin: Login{}, TypeLogin: Login{},
TypeLoginResp: LoginResp{}, TypeLoginResp: LoginResp{},
TypeNewProxy: NewProxy{}, TypeNewProxy: NewProxy{},
TypeNewProxyResp: NewProxyResp{}, TypeNewProxyResp: NewProxyResp{},
TypeCloseProxy: CloseProxy{}, TypeCloseProxy: CloseProxy{},
TypeNewWorkConn: NewWorkConn{}, TypeNewWorkConn: NewWorkConn{},
TypeReqWorkConn: ReqWorkConn{}, TypeReqWorkConn: ReqWorkConn{},
TypeStartWorkConn: StartWorkConn{}, TypeStartWorkConn: StartWorkConn{},
TypeNewVisitorConn: NewVisitorConn{}, TypeNewVisitorConn: NewVisitorConn{},
TypeNewVisitorConnResp: NewVisitorConnResp{}, TypeNewVisitorConnResp: NewVisitorConnResp{},
TypePing: Ping{}, TypePing: Ping{},
TypePong: Pong{}, TypePong: Pong{},
TypeUdpPacket: UdpPacket{}, TypeUdpPacket: UdpPacket{},
TypeNatHoleVisitor: NatHoleVisitor{}, TypeNatHoleVisitor: NatHoleVisitor{},
TypeNatHoleClient: NatHoleClient{}, TypeNatHoleClient: NatHoleClient{},
TypeNatHoleResp: NatHoleResp{}, TypeNatHoleResp: NatHoleResp{},
TypeNatHoleSid: NatHoleSid{}, TypeNatHoleClientDetectOK: NatHoleClientDetectOK{},
TypeNatHoleSid: NatHoleSid{},
} }
) )
@ -169,6 +171,9 @@ type NatHoleResp struct {
Error string `json:"error"` Error string `json:"error"`
} }
type NatHoleClientDetectOK struct {
}
type NatHoleSid struct { type NatHoleSid struct {
Sid string `json:"sid"` Sid string `json:"sid"`
} }

View File

@ -18,6 +18,11 @@ import (
// Timeout seconds. // Timeout seconds.
var NatHoleTimeout int64 = 10 var NatHoleTimeout int64 = 10
type SidRequest struct {
Sid string
NotifyCh chan struct{}
}
type NatHoleController struct { type NatHoleController struct {
listener *net.UDPConn listener *net.UDPConn
@ -44,11 +49,11 @@ func NewNatHoleController(udpBindAddr string) (nc *NatHoleController, err error)
return nc, nil return nc, nil
} }
func (nc *NatHoleController) ListenClient(name string, sk string) (sidCh chan string) { func (nc *NatHoleController) ListenClient(name string, sk string) (sidCh chan *SidRequest) {
clientCfg := &NatHoleClientCfg{ clientCfg := &NatHoleClientCfg{
Name: name, Name: name,
Sk: sk, Sk: sk,
SidCh: make(chan string), SidCh: make(chan *SidRequest),
} }
nc.mu.Lock() nc.mu.Lock()
nc.clientCfgs[name] = clientCfg nc.clientCfgs[name] = clientCfg
@ -132,7 +137,10 @@ func (nc *NatHoleController) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDP
}() }()
err := errors.PanicToError(func() { err := errors.PanicToError(func() {
clientCfg.SidCh <- sid clientCfg.SidCh <- &SidRequest{
Sid: sid,
NotifyCh: session.NotifyCh,
}
}) })
if err != nil { if err != nil {
return return
@ -158,7 +166,6 @@ func (nc *NatHoleController) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAd
} }
log.Trace("handle client message, sid [%s]", session.Sid) log.Trace("handle client message, sid [%s]", session.Sid)
session.ClientAddr = raddr session.ClientAddr = raddr
session.NotifyCh <- struct{}{}
resp := nc.GenNatHoleResponse(session, "") resp := nc.GenNatHoleResponse(session, "")
log.Trace("send nat hole response to client") log.Trace("send nat hole response to client")
@ -201,5 +208,5 @@ type NatHoleSession struct {
type NatHoleClientCfg struct { type NatHoleClientCfg struct {
Name string Name string
Sk string Sk string
SidCh chan string SidCh chan *SidRequest
} }

View File

@ -42,18 +42,40 @@ func (pxy *XtcpProxy) Run() (remoteAddr string, err error) {
select { select {
case <-pxy.closeCh: case <-pxy.closeCh:
break break
case sid := <-sidCh: case sidRequest := <-sidCh:
sr := sidRequest
workConn, errRet := pxy.GetWorkConnFromPool() workConn, errRet := pxy.GetWorkConnFromPool()
if errRet != nil { if errRet != nil {
continue continue
} }
m := &msg.NatHoleSid{ m := &msg.NatHoleSid{
Sid: sid, Sid: sr.Sid,
} }
errRet = msg.WriteMsg(workConn, m) errRet = msg.WriteMsg(workConn, m)
if errRet != nil { if errRet != nil {
pxy.Warn("write nat hole sid package error, %v", errRet) pxy.Warn("write nat hole sid package error, %v", errRet)
workConn.Close()
break
} }
go func() {
raw, errRet := msg.ReadMsg(workConn)
if errRet != nil {
pxy.Warn("read nat hole client ok package error: %v", errRet)
workConn.Close()
return
}
if _, ok := raw.(*msg.NatHoleClientDetectOK); !ok {
pxy.Warn("read nat hole client ok package format error")
workConn.Close()
return
}
select {
case sr.NotifyCh <- struct{}{}:
default:
}
}()
} }
} }
}() }()

View File

@ -16,8 +16,14 @@ package server
import ( import (
"bytes" "bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/big"
"net" "net"
"net/http" "net/http"
"time" "time"
@ -61,6 +67,9 @@ type Service struct {
// Accept connections using websocket // Accept connections using websocket
websocketListener frpNet.Listener websocketListener frpNet.Listener
// Accept frp tls connections
tlsListener frpNet.Listener
// Manage all controllers // Manage all controllers
ctlManager *ControlManager ctlManager *ControlManager
@ -72,6 +81,8 @@ type Service struct {
// stats collector to store server and proxies stats info // stats collector to store server and proxies stats info
statsCollector stats.Collector statsCollector stats.Collector
tlsConfig *tls.Config
} }
func NewService() (svr *Service, err error) { func NewService() (svr *Service, err error) {
@ -84,6 +95,7 @@ func NewService() (svr *Service, err error) {
TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
}, },
tlsConfig: generateTLSConfig(),
} }
// Init group controller // Init group controller
@ -187,6 +199,12 @@ func NewService() (svr *Service, err error) {
log.Info("https service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort) log.Info("https service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort)
} }
// frp tls listener
tlsListener := svr.muxer.Listen(1, 1, func(data []byte) bool {
return int(data[0]) == frpNet.FRP_TLS_HEAD_BYTE
})
svr.tlsListener = frpNet.WrapLogListener(tlsListener)
// Create nat hole controller. // Create nat hole controller.
if cfg.BindUdpPort > 0 { if cfg.BindUdpPort > 0 {
var nc *nathole.NatHoleController var nc *nathole.NatHoleController
@ -225,6 +243,7 @@ func (svr *Service) Run() {
} }
go svr.HandleListener(svr.websocketListener) go svr.HandleListener(svr.websocketListener)
go svr.HandleListener(svr.tlsListener)
svr.HandleListener(svr.listener) svr.HandleListener(svr.listener)
} }
@ -237,6 +256,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) {
log.Warn("Listener for incoming connections from client closed") log.Warn("Listener for incoming connections from client closed")
return return
} }
c = frpNet.CheckAndEnableTLSServerConn(c, svr.tlsConfig)
// Start a new goroutine for dealing connections. // Start a new goroutine for dealing connections.
go func(frpConn frpNet.Conn) { go func(frpConn frpNet.Conn) {
@ -373,3 +393,24 @@ func (svr *Service) RegisterVisitorConn(visitorConn frpNet.Conn, newMsg *msg.New
return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey, return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
newMsg.UseEncryption, newMsg.UseCompression) newMsg.UseEncryption, newMsg.UseCompression)
} }
// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
}

188
tests/ci/tls_test.go Normal file
View File

@ -0,0 +1,188 @@
package ci
import (
"os"
"testing"
"time"
"github.com/fatedier/frp/tests/config"
"github.com/fatedier/frp/tests/consts"
"github.com/fatedier/frp/tests/util"
"github.com/stretchr/testify/assert"
)
const FRPS_TLS_TCP_CONF = `
[common]
bind_addr = 0.0.0.0
bind_port = 20000
log_file = console
log_level = debug
token = 123456
`
const FRPC_TLS_TCP_CONF = `
[common]
server_addr = 127.0.0.1
server_port = 20000
log_file = console
log_level = debug
token = 123456
protocol = tcp
tls_enable = true
[tcp]
type = tcp
local_port = 10701
remote_port = 20801
`
func TestTlsOverTCP(t *testing.T) {
assert := assert.New(t)
frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_TCP_CONF)
if assert.NoError(err) {
defer os.Remove(frpsCfgPath)
}
frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_TCP_CONF)
if assert.NoError(err) {
defer os.Remove(frpcCfgPath)
}
frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath})
err = frpsProcess.Start()
if assert.NoError(err) {
defer frpsProcess.Stop()
}
time.Sleep(100 * time.Millisecond)
frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath})
err = frpcProcess.Start()
if assert.NoError(err) {
defer frpcProcess.Stop()
}
time.Sleep(250 * time.Millisecond)
// test tcp
res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR)
assert.NoError(err)
assert.Equal(consts.TEST_TCP_ECHO_STR, res)
}
const FRPS_TLS_KCP_CONF = `
[common]
bind_addr = 0.0.0.0
bind_port = 20000
kcp_bind_port = 20000
log_file = console
log_level = debug
token = 123456
`
const FRPC_TLS_KCP_CONF = `
[common]
server_addr = 127.0.0.1
server_port = 20000
log_file = console
log_level = debug
token = 123456
protocol = kcp
tls_enable = true
[tcp]
type = tcp
local_port = 10701
remote_port = 20801
`
func TestTLSOverKCP(t *testing.T) {
assert := assert.New(t)
frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_KCP_CONF)
if assert.NoError(err) {
defer os.Remove(frpsCfgPath)
}
frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_KCP_CONF)
if assert.NoError(err) {
defer os.Remove(frpcCfgPath)
}
frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath})
err = frpsProcess.Start()
if assert.NoError(err) {
defer frpsProcess.Stop()
}
time.Sleep(200 * time.Millisecond)
frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath})
err = frpcProcess.Start()
if assert.NoError(err) {
defer frpcProcess.Stop()
}
time.Sleep(500 * time.Millisecond)
// test tcp
res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR)
assert.NoError(err)
assert.Equal(consts.TEST_TCP_ECHO_STR, res)
}
const FRPS_TLS_WS_CONF = `
[common]
bind_addr = 0.0.0.0
bind_port = 20000
log_file = console
log_level = debug
token = 123456
`
const FRPC_TLS_WS_CONF = `
[common]
server_addr = 127.0.0.1
server_port = 20000
log_file = console
log_level = debug
token = 123456
protocol = websocket
tls_enable = true
[tcp]
type = tcp
local_port = 10701
remote_port = 20801
`
func TestTLSOverWebsocket(t *testing.T) {
assert := assert.New(t)
frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_WS_CONF)
if assert.NoError(err) {
defer os.Remove(frpsCfgPath)
}
frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_WS_CONF)
if assert.NoError(err) {
defer os.Remove(frpcCfgPath)
}
frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath})
err = frpsProcess.Start()
if assert.NoError(err) {
defer frpsProcess.Stop()
}
time.Sleep(200 * time.Millisecond)
frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath})
err = frpcProcess.Start()
if assert.NoError(err) {
defer frpcProcess.Stop()
}
time.Sleep(500 * time.Millisecond)
// test tcp
res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR)
assert.NoError(err)
assert.Equal(consts.TEST_TCP_ECHO_STR, res)
}

View File

@ -15,6 +15,7 @@
package net package net
import ( import (
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -207,3 +208,13 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
return nil, fmt.Errorf("unsupport protocol: %s", protocol) return nil, fmt.Errorf("unsupport protocol: %s", protocol)
} }
} }
func ConnectServerByProxyWithTLS(proxyUrl string, protocol string, addr string, tlsConfig *tls.Config) (c Conn, err error) {
c, err = ConnectServerByProxy(proxyUrl, protocol, addr)
if tlsConfig == nil {
return
}
c = WrapTLSClientConn(c, tlsConfig)
return
}

44
utils/net/tls.go Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2019 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"crypto/tls"
"net"
gnet "github.com/fatedier/golib/net"
)
var (
FRP_TLS_HEAD_BYTE = 0x17
)
func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out Conn) {
c.Write([]byte{byte(FRP_TLS_HEAD_BYTE)})
out = WrapConn(tls.Client(c, tlsConfig))
return
}
func CheckAndEnableTLSServerConn(c net.Conn, tlsConfig *tls.Config) (out Conn) {
sc, r := gnet.NewSharedConnSize(c, 1)
buf := make([]byte, 1)
n, _ := r.Read(buf)
if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE {
out = WrapConn(tls.Server(c, tlsConfig))
} else {
out = WrapConn(sc)
}
return
}

View File

@ -19,7 +19,7 @@ import (
"strings" "strings"
) )
var version string = "0.24.1" var version string = "0.25.0"
func Full() string { func Full() string {
return version return version

1
vendor/github.com/hashicorp/yamux/go.mod generated vendored Normal file
View File

@ -0,0 +1 @@
module github.com/hashicorp/yamux

View File

@ -3,6 +3,7 @@ package yamux
import ( import (
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"time" "time"
) )
@ -30,8 +31,13 @@ type Config struct {
// window size that we allow for a stream. // window size that we allow for a stream.
MaxStreamWindowSize uint32 MaxStreamWindowSize uint32
// LogOutput is used to control the log destination // LogOutput is used to control the log destination. Either Logger or
// LogOutput can be set, not both.
LogOutput io.Writer LogOutput io.Writer
// Logger is used to pass in the logger to be used. Either Logger or
// LogOutput can be set, not both.
Logger *log.Logger
} }
// DefaultConfig is used to return a default configuration // DefaultConfig is used to return a default configuration
@ -57,6 +63,11 @@ func VerifyConfig(config *Config) error {
if config.MaxStreamWindowSize < initialStreamWindow { if config.MaxStreamWindowSize < initialStreamWindow {
return fmt.Errorf("MaxStreamWindowSize must be larger than %d", initialStreamWindow) return fmt.Errorf("MaxStreamWindowSize must be larger than %d", initialStreamWindow)
} }
if config.LogOutput != nil && config.Logger != nil {
return fmt.Errorf("both Logger and LogOutput may not be set, select one")
} else if config.LogOutput == nil && config.Logger == nil {
return fmt.Errorf("one of Logger or LogOutput must be set, select one")
}
return nil return nil
} }

View File

@ -86,9 +86,14 @@ type sendReady struct {
// newSession is used to construct a new session // newSession is used to construct a new session
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
logger := config.Logger
if logger == nil {
logger = log.New(config.LogOutput, "", log.LstdFlags)
}
s := &Session{ s := &Session{
config: config, config: config,
logger: log.New(config.LogOutput, "", log.LstdFlags), logger: logger,
conn: conn, conn: conn,
bufRead: bufio.NewReader(conn), bufRead: bufio.NewReader(conn),
pings: make(map[uint32]chan struct{}), pings: make(map[uint32]chan struct{}),
@ -309,8 +314,10 @@ func (s *Session) keepalive() {
case <-time.After(s.config.KeepAliveInterval): case <-time.After(s.config.KeepAliveInterval):
_, err := s.Ping() _, err := s.Ping()
if err != nil { if err != nil {
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) if err != ErrSessionShutdown {
s.exitErr(ErrKeepAliveTimeout) s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout)
}
return return
} }
case <-s.shutdownCh: case <-s.shutdownCh:

10
vendor/modules.txt vendored
View File

@ -23,7 +23,7 @@ github.com/gorilla/context
github.com/gorilla/mux github.com/gorilla/mux
# github.com/gorilla/websocket v1.2.0 # github.com/gorilla/websocket v1.2.0
github.com/gorilla/websocket github.com/gorilla/websocket
# github.com/hashicorp/yamux v0.0.0-20180314200745-2658be15c5f0 # github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d
github.com/hashicorp/yamux github.com/hashicorp/yamux
# github.com/inconshreveable/mousetrap v1.0.0 # github.com/inconshreveable/mousetrap v1.0.0
github.com/inconshreveable/mousetrap github.com/inconshreveable/mousetrap
@ -61,11 +61,11 @@ golang.org/x/crypto/twofish
golang.org/x/crypto/xtea golang.org/x/crypto/xtea
golang.org/x/crypto/salsa20/salsa golang.org/x/crypto/salsa20/salsa
# golang.org/x/net v0.0.0-20180524181706-dfa909b99c79 # golang.org/x/net v0.0.0-20180524181706-dfa909b99c79
golang.org/x/net/ipv4
golang.org/x/net/websocket golang.org/x/net/websocket
golang.org/x/net/context
golang.org/x/net/proxy
golang.org/x/net/ipv4
golang.org/x/net/internal/socks
golang.org/x/net/bpf golang.org/x/net/bpf
golang.org/x/net/internal/iana golang.org/x/net/internal/iana
golang.org/x/net/internal/socket golang.org/x/net/internal/socket
golang.org/x/net/context
golang.org/x/net/proxy
golang.org/x/net/internal/socks