support more proxy type

This commit is contained in:
fatedier 2017-03-10 02:01:17 +08:00
parent f90028cf96
commit b02e1007fb
14 changed files with 223 additions and 229 deletions

View File

@ -1,3 +1,5 @@
Issue is only used for submiting bug report and documents typo. If there are same issues or answers can be found in documents, we will close it directly.
Please answer these questions before submitting your issue. Thanks!
1. What did you do?
@ -13,3 +15,7 @@ If possible, provide a recipe for reproducing the error.
4. What version of frp are you using (./frpc -v or ./frps -v)?
5. Can you point out what caused this issue (optional)?

View File

@ -208,11 +208,7 @@ func (ctl *Control) reader() {
}()
defer close(ctl.closedCh)
encReader, err := crypto.NewReader(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken))
if err != nil {
ctl.conn.Error("crypto new reader error: %v", err)
return
}
encReader := crypto.NewReader(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken))
for {
if m, err := msg.ReadMsg(encReader); err != nil {
if err == io.EOF {

View File

@ -72,27 +72,41 @@ func (pxy *TcpProxy) Close() {
func (pxy *TcpProxy) InWorkConn(conn net.Conn) {
defer conn.Close()
localConn, err := net.ConnectTcpServer(fmt.Sprintf("%s:%d", pxy.cfg.LocalIp, pxy.cfg.LocalPort))
if err != nil {
conn.Error("connect to local service [%s:%d] error: %v", pxy.cfg.LocalIp, pxy.cfg.LocalPort, err)
return
}
HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, &pxy.cfg.BaseProxyConf, conn)
}
var remote io.ReadWriteCloser
remote = conn
if pxy.cfg.UseEncryption {
remote, err = tcp.WithEncryption(remote, []byte(config.ClientCommonCfg.PrivilegeToken))
if err != nil {
conn.Error("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.UseCompression {
remote = tcp.WithCompression(remote)
}
conn.Debug("join connections")
tcp.Join(localConn, remote)
conn.Debug("join connections closed")
// HTTP
type HttpProxy struct {
cfg *config.HttpProxyConf
ctl *Control
}
func (pxy *HttpProxy) Run() {
}
func (pxy *HttpProxy) Close() {
}
func (pxy *HttpProxy) InWorkConn(conn net.Conn) {
defer conn.Close()
HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, &pxy.cfg.BaseProxyConf, conn)
}
// HTTPS
type HttpsProxy struct {
cfg *config.HttpsProxyConf
ctl *Control
}
func (pxy *HttpsProxy) Run() {
}
func (pxy *HttpsProxy) Close() {
}
func (pxy *HttpsProxy) InWorkConn(conn net.Conn) {
defer conn.Close()
HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, &pxy.cfg.BaseProxyConf, conn)
}
// UDP
@ -111,34 +125,28 @@ func (pxy *UdpProxy) InWorkConn(conn net.Conn) {
defer conn.Close()
}
// HTTP
type HttpProxy struct {
cfg *config.HttpProxyConf
ctl *Control
}
// Common handler for tcp work connections.
func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, baseInfo *config.BaseProxyConf, workConn net.Conn) {
localConn, err := net.ConnectTcpServer(fmt.Sprintf("%s:%d", localInfo.LocalIp, localInfo.LocalPort))
if err != nil {
workConn.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIp, localInfo.LocalPort, err)
return
}
func (pxy *HttpProxy) Run() {
}
func (pxy *HttpProxy) Close() {
}
func (pxy *HttpProxy) InWorkConn(conn net.Conn) {
defer conn.Close()
}
// HTTPS
type HttpsProxy struct {
cfg *config.HttpsProxyConf
ctl *Control
}
func (pxy *HttpsProxy) Run() {
}
func (pxy *HttpsProxy) Close() {
}
func (pxy *HttpsProxy) InWorkConn(conn net.Conn) {
defer conn.Close()
var remote io.ReadWriteCloser
remote = workConn
if baseInfo.UseEncryption {
remote, err = tcp.WithEncryption(remote, []byte(config.ClientCommonCfg.PrivilegeToken))
if err != nil {
workConn.Error("create encryption stream error: %v", err)
return
}
}
if baseInfo.UseCompression {
remote = tcp.WithCompression(remote)
}
workConn.Debug("join connections, localConn(l[%s] r[%s]) workConn(l[%s] r[%s])", localConn.LocalAddr().String(),
localConn.RemoteAddr().String(), workConn.LocalAddr().String(), workConn.RemoteAddr().String())
tcp.Join(localConn, remote)
workConn.Debug("join connections closed")
}

View File

@ -16,6 +16,7 @@ package config
import (
"fmt"
"reflect"
"strconv"
"strings"
@ -25,6 +26,27 @@ import (
ini "github.com/vaughan0/go-ini"
)
var proxyConfTypeMap map[string]reflect.Type
func init() {
proxyConfTypeMap = make(map[string]reflect.Type)
proxyConfTypeMap[consts.TcpProxy] = reflect.TypeOf(TcpProxyConf{})
proxyConfTypeMap[consts.UdpProxy] = reflect.TypeOf(UdpProxyConf{})
proxyConfTypeMap[consts.HttpProxy] = reflect.TypeOf(HttpProxyConf{})
proxyConfTypeMap[consts.HttpsProxy] = reflect.TypeOf(HttpsProxyConf{})
}
// NewConfByType creates a empty ProxyConf object by proxyType.
// If proxyType isn't exist, return nil.
func NewConfByType(proxyType string) ProxyConf {
v, ok := proxyConfTypeMap[proxyType]
if !ok {
return nil
}
cfg := reflect.New(v).Interface().(ProxyConf)
return cfg
}
type ProxyConf interface {
GetName() string
GetBaseInfo() *BaseProxyConf
@ -38,16 +60,9 @@ func NewProxyConf(pMsg *msg.NewProxy) (cfg ProxyConf, err error) {
if pMsg.ProxyType == "" {
pMsg.ProxyType = consts.TcpProxy
}
switch pMsg.ProxyType {
case consts.TcpProxy:
cfg = &TcpProxyConf{}
case consts.UdpProxy:
cfg = &UdpProxyConf{}
case consts.HttpProxy:
cfg = &HttpProxyConf{}
case consts.HttpsProxy:
cfg = &HttpsProxyConf{}
default:
cfg = NewConfByType(pMsg.ProxyType)
if cfg == nil {
err = fmt.Errorf("proxy [%s] type [%s] error", pMsg.ProxyName, pMsg.ProxyType)
return
}
@ -62,16 +77,8 @@ func NewProxyConfFromFile(name string, section ini.Section) (cfg ProxyConf, err
proxyType = consts.TcpProxy
section["type"] = consts.TcpProxy
}
switch proxyType {
case consts.TcpProxy:
cfg = &TcpProxyConf{}
case consts.UdpProxy:
cfg = &UdpProxyConf{}
case consts.HttpProxy:
cfg = &HttpProxyConf{}
case consts.HttpsProxy:
cfg = &HttpsProxyConf{}
default:
cfg = NewConfByType(proxyType)
if cfg == nil {
err = fmt.Errorf("proxy [%s] type [%s] error", name, proxyType)
return
}
@ -223,7 +230,6 @@ func (cfg *DomainConf) check() (err error) {
if strings.Contains(cfg.SubDomain, ".") || strings.Contains(cfg.SubDomain, "*") {
return fmt.Errorf("'.' and '*' is not supported in subdomain")
}
cfg.SubDomain += "." + ServerCommonCfg.SubDomainHost
}
return nil
}

View File

@ -37,25 +37,21 @@ func init() {
TypeMap = make(map[byte]reflect.Type)
TypeStringMap = make(map[reflect.Type]byte)
TypeMap[TypeLogin] = getTypeFn((*Login)(nil))
TypeMap[TypeLoginResp] = getTypeFn((*LoginResp)(nil))
TypeMap[TypeNewProxy] = getTypeFn((*NewProxy)(nil))
TypeMap[TypeNewProxyResp] = getTypeFn((*NewProxyResp)(nil))
TypeMap[TypeNewWorkConn] = getTypeFn((*NewWorkConn)(nil))
TypeMap[TypeReqWorkConn] = getTypeFn((*ReqWorkConn)(nil))
TypeMap[TypeStartWorkConn] = getTypeFn((*StartWorkConn)(nil))
TypeMap[TypePing] = getTypeFn((*Ping)(nil))
TypeMap[TypePong] = getTypeFn((*Pong)(nil))
TypeMap[TypeLogin] = reflect.TypeOf(Login{})
TypeMap[TypeLoginResp] = reflect.TypeOf(LoginResp{})
TypeMap[TypeNewProxy] = reflect.TypeOf(NewProxy{})
TypeMap[TypeNewProxyResp] = reflect.TypeOf(NewProxyResp{})
TypeMap[TypeNewWorkConn] = reflect.TypeOf(NewWorkConn{})
TypeMap[TypeReqWorkConn] = reflect.TypeOf(ReqWorkConn{})
TypeMap[TypeStartWorkConn] = reflect.TypeOf(StartWorkConn{})
TypeMap[TypePing] = reflect.TypeOf(Ping{})
TypeMap[TypePong] = reflect.TypeOf(Pong{})
for k, v := range TypeMap {
TypeStringMap[v] = k
}
}
func getTypeFn(obj interface{}) reflect.Type {
return reflect.TypeOf(obj).Elem()
}
// Message wraps socket packages for communicating between frpc and frps.
type Message interface{}

View File

@ -17,6 +17,7 @@ package msg
import (
"bytes"
"encoding/binary"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@ -66,7 +67,7 @@ func TestUnPack(t *testing.T) {
// correct
msg, err = UnPack(TypePong, []byte("{}"))
assert.NoError(err)
assert.Equal(getTypeFn(msg), getTypeFn((*Pong)(nil)))
assert.Equal(reflect.TypeOf(msg).Elem(), reflect.TypeOf(Pong{}))
}
func TestUnPackInto(t *testing.T) {

View File

@ -19,8 +19,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/fatedier/frp/utils/crypto"
)
func TestJoin(t *testing.T) {
@ -67,63 +65,3 @@ func TestJoin(t *testing.T) {
conn3.Close()
conn4.Close()
}
func TestJoinEncrypt(t *testing.T) {
assert := assert.New(t)
var (
n int
err error
)
text1 := "1234567890"
text2 := "abcdefghij"
key := "authkey"
// Forward enrypted bytes.
pr, pw := io.Pipe()
pr2, pw2 := io.Pipe()
pr3, pw3 := io.Pipe()
pr4, pw4 := io.Pipe()
pr5, pw5 := io.Pipe()
pr6, pw6 := io.Pipe()
conn1 := WrapReadWriteCloser(pr, pw2)
conn2 := WrapReadWriteCloser(pr2, pw)
conn3 := WrapReadWriteCloser(pr3, pw4)
conn4 := WrapReadWriteCloser(pr4, pw3)
conn5 := WrapReadWriteCloser(pr5, pw6)
conn6 := WrapReadWriteCloser(pr6, pw5)
r1, err := crypto.NewReader(conn3, []byte(key))
assert.NoError(err)
w1, err := crypto.NewWriter(conn3, []byte(key))
assert.NoError(err)
r2, err := crypto.NewReader(conn4, []byte(key))
assert.NoError(err)
w2, err := crypto.NewWriter(conn4, []byte(key))
assert.NoError(err)
go Join(conn2, WrapReadWriteCloser(r1, w1))
go Join(WrapReadWriteCloser(r2, w2), conn5)
buf := make([]byte, 128)
conn1.Write([]byte(text1))
conn6.Write([]byte(text2))
n, err = conn6.Read(buf)
assert.NoError(err)
assert.Equal(text1, string(buf[:n]))
n, err = conn1.Read(buf)
assert.NoError(err)
assert.Equal(text2, string(buf[:n]))
conn1.Close()
conn2.Close()
conn3.Close()
conn4.Close()
conn5.Close()
conn6.Close()
}

View File

@ -22,32 +22,16 @@ import (
"github.com/fatedier/frp/utils/crypto"
)
func WithEncryption(rwc io.ReadWriteCloser, key []byte) (res io.ReadWriteCloser, err error) {
var (
r io.Reader
w io.Writer
)
r, err = crypto.NewReader(rwc, key)
func WithEncryption(rwc io.ReadWriteCloser, key []byte) (io.ReadWriteCloser, error) {
w, err := crypto.NewWriter(rwc, key)
if err != nil {
return
return nil, err
}
w, err = crypto.NewWriter(rwc, key)
if err != nil {
return
}
res = WrapReadWriteCloser(r, w)
return
return WrapReadWriteCloser(crypto.NewReader(rwc, key), w), nil
}
func WithCompression(rwc io.ReadWriteCloser) (res io.ReadWriteCloser) {
var (
r io.Reader
w io.Writer
)
r = snappy.NewReader(rwc)
w = snappy.NewWriter(rwc)
res = WrapReadWriteCloser(r, w)
return
func WithCompression(rwc io.ReadWriteCloser) io.ReadWriteCloser {
return WrapReadWriteCloser(snappy.NewReader(rwc), snappy.NewWriter(rwc))
}
func WrapReadWriteCloser(r io.Reader, w io.Writer) io.ReadWriteCloser {

View File

@ -212,11 +212,7 @@ func (ctl *Control) reader() {
defer ctl.allShutdown.Start()
defer ctl.readerShutdown.Done()
encReader, err := crypto.NewReader(ctl.conn, []byte(config.ServerCommonCfg.PrivilegeToken))
if err != nil {
ctl.conn.Error("crypto new reader error: %v", err)
return
}
encReader := crypto.NewReader(ctl.conn, []byte(config.ServerCommonCfg.PrivilegeToken))
for {
if m, err := msg.ReadMsg(encReader); err != nil {
if err == io.EOF {
@ -260,7 +256,7 @@ func (ctl *Control) stoper() {
}
ctl.allShutdown.Done()
ctl.conn.Info("all shutdown success")
ctl.conn.Info("client exit success")
}
func (ctl *Control) manager() {

View File

@ -9,6 +9,7 @@ import (
"github.com/fatedier/frp/models/proto/tcp"
"github.com/fatedier/frp/utils/log"
"github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/vhost"
)
type Proxy interface {
@ -42,6 +43,27 @@ func (pxy *BaseProxy) Close() {
}
}
// startListenHandler start a goroutine handler for each listener.
// p: p will just be passed to handler(Proxy, net.Conn).
// handler: each proxy type can set different handler function to deal with connections accepted from listeners.
func (pxy *BaseProxy) startListenHandler(p Proxy, handler func(Proxy, net.Conn)) {
for _, listener := range pxy.listeners {
go func(l net.Listener) {
for {
// block
// if listener is closed, err returned
c, err := l.Accept()
if err != nil {
pxy.Info("listener is closed")
return
}
pxy.Debug("get a user connection [%s]", c.RemoteAddr().String())
go handler(p, c)
}
}(listener)
}
}
func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) {
basePxy := BaseProxy{
name: pxyConf.GetName(),
@ -83,25 +105,14 @@ type TcpProxy struct {
}
func (pxy *TcpProxy) Run() error {
listener, err := net.ListenTcp(config.ServerCommonCfg.BindAddr, int64(pxy.cfg.RemotePort))
listener, err := net.ListenTcp(config.ServerCommonCfg.BindAddr, pxy.cfg.RemotePort)
if err != nil {
return err
}
pxy.listeners = append(pxy.listeners, listener)
pxy.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort)
go func(l net.Listener) {
for {
// block
// if listener is closed, err returned
c, err := l.Accept()
if err != nil {
pxy.Info("listener is closed")
return
}
pxy.Debug("got one user connection [%s]", c.RemoteAddr().String())
go HandleUserTcpConnection(pxy, c)
}
}(listener)
pxy.startListenHandler(pxy, HandleUserTcpConnection)
return nil
}
@ -119,6 +130,43 @@ type HttpProxy struct {
}
func (pxy *HttpProxy) Run() (err error) {
routeConfig := &vhost.VhostRouteConfig{
RewriteHost: pxy.cfg.HostHeaderRewrite,
Username: pxy.cfg.HttpUser,
Password: pxy.cfg.HttpPwd,
}
locations := pxy.cfg.Locations
if len(locations) == 0 {
locations = []string{""}
}
for _, domain := range pxy.cfg.CustomDomains {
routeConfig.Domain = domain
for _, location := range locations {
routeConfig.Location = location
l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig)
if err != nil {
return err
}
pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
pxy.listeners = append(pxy.listeners, l)
}
}
if pxy.cfg.SubDomain != "" {
routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
for _, location := range locations {
routeConfig.Location = location
l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig)
if err != nil {
return err
}
pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
pxy.listeners = append(pxy.listeners, l)
}
}
pxy.startListenHandler(pxy, HandleUserTcpConnection)
return
}
@ -136,6 +184,29 @@ type HttpsProxy struct {
}
func (pxy *HttpsProxy) Run() (err error) {
routeConfig := &vhost.VhostRouteConfig{}
for _, domain := range pxy.cfg.CustomDomains {
routeConfig.Domain = domain
l, err := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig)
if err != nil {
return err
}
pxy.Info("https proxy listen for host [%s]", routeConfig.Domain)
pxy.listeners = append(pxy.listeners, l)
}
if pxy.cfg.SubDomain != "" {
routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
l, err := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig)
if err != nil {
return err
}
pxy.Info("https proxy listen for host [%s]", routeConfig.Domain)
pxy.listeners = append(pxy.listeners, l)
}
pxy.startListenHandler(pxy, HandleUserTcpConnection)
return
}
@ -180,7 +251,7 @@ func HandleUserTcpConnection(pxy Proxy, userConn net.Conn) {
return
}
defer workConn.Close()
pxy.Info("get one new work connection: %s", workConn.RemoteAddr().String())
pxy.Info("get a new work connection: [%s]", workConn.RemoteAddr().String())
workConn.AddLogPrefix(pxy.GetName())
err := msg.WriteMsg(workConn, &msg.StartWorkConn{
@ -199,12 +270,7 @@ func HandleUserTcpConnection(pxy Proxy, userConn net.Conn) {
return
}
var (
local io.ReadWriteCloser
remote io.ReadWriteCloser
)
local = workConn
remote = userConn
var local io.ReadWriteCloser = workConn
cfg := pxy.GetConf().GetBaseInfo()
if cfg.UseEncryption {
local, err = tcp.WithEncryption(local, []byte(config.ServerCommonCfg.PrivilegeToken))
@ -216,5 +282,8 @@ func HandleUserTcpConnection(pxy Proxy, userConn net.Conn) {
if cfg.UseCompression {
local = tcp.WithCompression(local)
}
tcp.Join(local, remote)
pxy.Debug("join connections, workConn(l[%s] r[%s]) userConn(l[%s] r[%s])", workConn.LocalAddr().String(),
workConn.RemoteAddr().String(), userConn.LocalAddr().String(), userConn.RemoteAddr().String())
tcp.Join(local, userConn)
pxy.Debug("join connections closed")
}

View File

@ -22,15 +22,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestWriter(t *testing.T) {
// Empty key.
assert := assert.New(t)
key := ""
buffer := bytes.NewBuffer(nil)
_, err := NewWriter(buffer, []byte(key))
assert.NoError(err)
}
func TestCrypto(t *testing.T) {
assert := assert.New(t)
@ -40,12 +31,10 @@ func TestCrypto(t *testing.T) {
buffer := bytes.NewBuffer(nil)
encWriter, err := NewWriter(buffer, []byte(key))
assert.NoError(err)
decReader := NewReader(buffer, []byte(key))
encWriter.Write([]byte(text))
decReader, err := NewReader(buffer, []byte(key))
assert.NoError(err)
c := bytes.NewBuffer(nil)
io.Copy(c, decReader)
assert.Equal(text, string(c.Bytes()))

View File

@ -24,13 +24,13 @@ import (
)
// NewReader returns a new Reader that decrypts bytes from r
func NewReader(r io.Reader, key []byte) (*Reader, error) {
func NewReader(r io.Reader, key []byte) *Reader {
key = pbkdf2.Key(key, []byte(salt), 64, aes.BlockSize, sha1.New)
return &Reader{
r: r,
key: key,
}, nil
}
}
// Reader is an io.Reader that can read encrypted bytes.

View File

@ -67,10 +67,6 @@ type Writer struct {
// Write satisfies the io.Writer interface.
func (w *Writer) Write(p []byte) (nRet int, errRet error) {
return w.write(p)
}
func (w *Writer) write(p []byte) (nRet int, errRet error) {
if w.err != nil {
return 0, w.err
}

View File

@ -50,26 +50,35 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut
return mux, nil
}
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
func (v *VhostMuxer) Listen(name, location, rewriteHost, userName, passWord string) (l *Listener, err error) {
type VhostRouteConfig struct {
Domain string
Location string
RewriteHost string
Username string
Password string
}
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil
// then rewrite the host header to rewriteHost
func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) {
v.mutex.Lock()
defer v.mutex.Unlock()
_, ok := v.registryRouter.Exist(name, location)
_, ok := v.registryRouter.Exist(cfg.Domain, cfg.Location)
if ok {
return nil, fmt.Errorf("hostname [%s] location [%s] is already registered", name, location)
return nil, fmt.Errorf("hostname [%s] location [%s] is already registered", cfg.Domain, cfg.Location)
}
l = &Listener{
name: name,
location: location,
rewriteHost: rewriteHost,
userName: userName,
passWord: passWord,
name: cfg.Domain,
location: cfg.Location,
rewriteHost: cfg.RewriteHost,
userName: cfg.Username,
passWord: cfg.Password,
mux: v,
accept: make(chan frpNet.Conn),
}
v.registryRouter.Add(name, location, l)
v.registryRouter.Add(cfg.Domain, cfg.Location, l)
return l, nil
}