diff --git a/conf/frps_full.ini b/conf/frps_full.ini index 34a48db..641f388 100644 --- a/conf/frps_full.ini +++ b/conf/frps_full.ini @@ -16,6 +16,7 @@ kcp_bind_port = 7000 # proxy_bind_addr = 127.0.0.1 # if you want to support virtual host, you must set the http port for listening (optional) +# Note: http port and https port can be same with bind_port vhost_http_port = 80 vhost_https_port = 443 diff --git a/server/service.go b/server/service.go index 736502b..496fd55 100644 --- a/server/service.go +++ b/server/service.go @@ -26,6 +26,7 @@ import ( "github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/utils/log" frpNet "github.com/fatedier/frp/utils/net" + "github.com/fatedier/frp/utils/net/mux" "github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/version" "github.com/fatedier/frp/utils/vhost" @@ -41,6 +42,9 @@ var ServerService *Service // Server service. type Service struct { + // Dispatch connections to different handlers listen on same port. + muxer *mux.Mux + // Accept connections from client. listener frpNet.Listener @@ -88,12 +92,33 @@ func NewService() (svr *Service, err error) { return } + var ( + httpMuxOn bool + httpsMuxOn bool + ) + if cfg.BindAddr == cfg.ProxyBindAddr { + if cfg.BindPort == cfg.VhostHttpPort { + httpMuxOn = true + } + if cfg.BindPort == cfg.VhostHttpsPort { + httpsMuxOn = true + } + if httpMuxOn || httpsMuxOn { + svr.muxer = mux.NewMux() + } + } + // Listen for accepting connections from client. - svr.listener, err = frpNet.ListenTcp(cfg.BindAddr, cfg.BindPort) + ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.BindPort)) if err != nil { err = fmt.Errorf("Create server listener error, %v", err) return } + if svr.muxer != nil { + go svr.muxer.Serve(ln) + ln = svr.muxer.DefaultListener() + } + svr.listener = frpNet.WrapLogListener(ln) log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort) // Listen for accepting connections from client using kcp protocol. @@ -117,10 +142,14 @@ func NewService() (svr *Service, err error) { Handler: rp, } var l net.Listener - l, err = net.Listen("tcp", address) - if err != nil { - err = fmt.Errorf("Create vhost http listener error, %v", err) - return + if httpMuxOn { + l = svr.muxer.ListenHttp(0) + } else { + l, err = net.Listen("tcp", address) + if err != nil { + err = fmt.Errorf("Create vhost http listener error, %v", err) + return + } } go server.Serve(l) log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) @@ -128,13 +157,18 @@ func NewService() (svr *Service, err error) { // Create https vhost muxer. if cfg.VhostHttpsPort > 0 { - var l frpNet.Listener - l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpsPort) - if err != nil { - err = fmt.Errorf("Create vhost https listener error, %v", err) - return + var l net.Listener + if httpsMuxOn { + l = svr.muxer.ListenHttps(0) + } else { + l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort)) + if err != nil { + err = fmt.Errorf("Create server listener error, %v", err) + return + } } - svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(l, 30*time.Second) + + svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(frpNet.WrapLogListener(l), 30*time.Second) if err != nil { err = fmt.Errorf("Create vhost httpsMuxer error, %v", err) return diff --git a/utils/net/conn.go b/utils/net/conn.go index e3473a8..e18d9c4 100644 --- a/utils/net/conn.go +++ b/utils/net/conn.go @@ -20,7 +20,6 @@ import ( "fmt" "io" "net" - "sync" "sync/atomic" "time" @@ -136,7 +135,6 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn type SharedConn struct { Conn - sync.Mutex buf *bytes.Buffer } @@ -149,22 +147,24 @@ func NewShareConn(conn Conn) (*SharedConn, io.Reader) { return sc, io.TeeReader(conn, sc.buf) } +func NewShareConnSize(conn Conn, bufSize int) (*SharedConn, io.Reader) { + sc := &SharedConn{ + Conn: conn, + buf: bytes.NewBuffer(make([]byte, 0, bufSize)), + } + return sc, io.TeeReader(conn, sc.buf) +} + +// Not thread safety. func (sc *SharedConn) Read(p []byte) (n int, err error) { - sc.Lock() if sc.buf == nil { - sc.Unlock() return sc.Conn.Read(p) } - sc.Unlock() n, err = sc.buf.Read(p) - if err == io.EOF { - sc.Lock() sc.buf = nil - sc.Unlock() var n2 int n2, err = sc.Conn.Read(p[n:]) - n += n2 } return diff --git a/utils/net/mux/mux.go b/utils/net/mux/mux.go new file mode 100644 index 0000000..4934f19 --- /dev/null +++ b/utils/net/mux/mux.go @@ -0,0 +1,210 @@ +package mux + +import ( + "fmt" + "io" + "net" + "sort" + "sync" + "time" + + "github.com/fatedier/frp/utils/errors" + frpNet "github.com/fatedier/frp/utils/net" +) + +const ( + // DefaultTimeout is the default length of time to wait for bytes we need. + DefaultTimeout = 10 * time.Second +) + +type Mux struct { + ln net.Listener + + defaultLn *listener + lns []*listener + maxNeedBytesNum uint32 + mu sync.RWMutex +} + +func NewMux() (mux *Mux) { + mux = &Mux{ + lns: make([]*listener, 0), + } + return +} + +func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener { + ln := &listener{ + c: make(chan net.Conn), + mux: mux, + needBytesNum: needBytesNum, + matchFn: fn, + } + + mux.mu.Lock() + defer mux.mu.Unlock() + if needBytesNum > mux.maxNeedBytesNum { + mux.maxNeedBytesNum = needBytesNum + } + + newlns := append(mux.copyLns(), ln) + sort.Slice(newlns, func(i, j int) bool { + return newlns[i].needBytesNum < newlns[j].needBytesNum + }) + mux.lns = newlns + return ln +} + +func (mux *Mux) ListenHttp(priority int) net.Listener { + return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc) +} + +func (mux *Mux) ListenHttps(priority int) net.Listener { + return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc) +} + +func (mux *Mux) DefaultListener() net.Listener { + mux.mu.Lock() + defer mux.mu.Unlock() + if mux.defaultLn == nil { + mux.defaultLn = &listener{ + c: make(chan net.Conn), + mux: mux, + } + } + return mux.defaultLn +} + +func (mux *Mux) release(ln *listener) bool { + result := false + mux.mu.Lock() + defer mux.mu.Unlock() + lns := mux.copyLns() + + for i, l := range lns { + if l == ln { + lns = append(lns[:i], lns[i+1:]...) + result = true + } + } + mux.lns = lns + return result +} + +func (mux *Mux) copyLns() []*listener { + lns := make([]*listener, 0, len(mux.lns)) + for _, l := range mux.lns { + lns = append(lns, l) + } + return lns +} + +// Serve handles connections from ln and multiplexes then across registered listeners. +func (mux *Mux) Serve(ln net.Listener) error { + mux.mu.Lock() + mux.ln = ln + mux.mu.Unlock() + for { + // Wait for the next connection. + // If it returns a temporary error then simply retry. + // If it returns any other error then exit immediately. + conn, err := ln.Accept() + if err, ok := err.(interface { + Temporary() bool + }); ok && err.Temporary() { + continue + } + + if err != nil { + return err + } + + go mux.handleConn(conn) + } +} + +func (mux *Mux) handleConn(conn net.Conn) { + mux.mu.RLock() + maxNeedBytesNum := mux.maxNeedBytesNum + lns := mux.lns + defaultLn := mux.defaultLn + mux.mu.RUnlock() + + shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum)) + data := make([]byte, maxNeedBytesNum) + + conn.SetReadDeadline(time.Now().Add(DefaultTimeout)) + _, err := io.ReadFull(rd, data) + if err != nil { + conn.Close() + return + } + conn.SetReadDeadline(time.Time{}) + + for _, ln := range lns { + if match := ln.matchFn(data); match { + err = errors.PanicToError(func() { + ln.c <- shareConn + }) + if err != nil { + conn.Close() + } + return + } + } + + // No match listeners + if defaultLn != nil { + err = errors.PanicToError(func() { + defaultLn.c <- shareConn + }) + if err != nil { + conn.Close() + } + return + } + + // No listeners for this connection, close it. + conn.Close() + return +} + +type listener struct { + mux *Mux + + needBytesNum uint32 + matchFn MatchFunc + + c chan net.Conn + mu sync.RWMutex +} + +// Accept waits for and returns the next connection to the listener. +func (ln *listener) Accept() (net.Conn, error) { + conn, ok := <-ln.c + if !ok { + return nil, fmt.Errorf("network connection closed") + } + return conn, nil +} + +// Close removes this listener from the parent mux and closes the channel. +func (ln *listener) Close() error { + if ok := ln.mux.release(ln); ok { + // Close done to signal to any RLock holders to release their lock. + close(ln.c) + } + return nil +} + +func (ln *listener) Addr() net.Addr { + if ln.mux == nil { + return nil + } + ln.mux.mu.RLock() + defer ln.mux.mu.RUnlock() + if ln.mux.ln == nil { + return nil + } + return ln.mux.ln.Addr() +} diff --git a/utils/net/mux/mux_test.go b/utils/net/mux/mux_test.go new file mode 100644 index 0000000..fd3a9e2 --- /dev/null +++ b/utils/net/mux/mux_test.go @@ -0,0 +1,95 @@ +package mux + +import ( + "bufio" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func runHttpSvr(ln net.Listener) *httptest.Server { + svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("http service")) + })) + svr.Listener = ln + svr.Start() + return svr +} + +func runHttpsSvr(ln net.Listener) *httptest.Server { + svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("https service")) + })) + svr.Listener = ln + svr.StartTLS() + return svr +} + +func runEchoSvr(ln net.Listener) { + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + rd := bufio.NewReader(conn) + data, err := rd.ReadString('\n') + if err != nil { + return + } + conn.Write([]byte(data)) + conn.Close() + } + }() +} + +func TestMux(t *testing.T) { + assert := assert.New(t) + + ln, err := net.Listen("tcp", "127.0.0.1:") + assert.NoError(err) + + mux := NewMux() + httpLn := mux.ListenHttp(0) + httpsLn := mux.ListenHttps(0) + defaultLn := mux.DefaultListener() + go mux.Serve(ln) + time.Sleep(100 * time.Millisecond) + + httpSvr := runHttpSvr(httpLn) + defer httpSvr.Close() + httpsSvr := runHttpsSvr(httpsLn) + defer httpsSvr.Close() + runEchoSvr(defaultLn) + defer ln.Close() + + // test http service + resp, err := http.Get(httpSvr.URL) + assert.NoError(err) + data, err := ioutil.ReadAll(resp.Body) + assert.NoError(err) + assert.Equal("http service", string(data)) + + // test https service + client := httpsSvr.Client() + resp, err = client.Get(httpsSvr.URL) + assert.NoError(err) + data, err = ioutil.ReadAll(resp.Body) + assert.NoError(err) + assert.Equal("https service", string(data)) + + // test echo service + conn, err := net.Dial("tcp", ln.Addr().String()) + assert.NoError(err) + _, err = conn.Write([]byte("test echo\n")) + assert.NoError(err) + data = make([]byte, 1024) + n, err := conn.Read(data) + assert.NoError(err) + assert.Equal("test echo\n", string(data[:n])) +} diff --git a/utils/net/mux/rule.go b/utils/net/mux/rule.go new file mode 100644 index 0000000..f01b058 --- /dev/null +++ b/utils/net/mux/rule.go @@ -0,0 +1,55 @@ +package mux + +type MatchFunc func(data []byte) (match bool) + +var ( + HttpsNeedBytesNum uint32 = 1 + HttpNeedBytesNum uint32 = 3 + YamuxNeedBytesNum uint32 = 2 +) + +var HttpsMatchFunc MatchFunc = func(data []byte) bool { + if len(data) < int(HttpsNeedBytesNum) { + return false + } + + if data[0] == 0x16 { + return true + } else { + return false + } +} + +// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods +var httpHeadBytes = map[string]struct{}{ + "GET": struct{}{}, + "HEA": struct{}{}, + "POS": struct{}{}, + "PUT": struct{}{}, + "DEL": struct{}{}, + "CON": struct{}{}, + "OPT": struct{}{}, + "TRA": struct{}{}, + "PAT": struct{}{}, +} + +var HttpMatchFunc MatchFunc = func(data []byte) bool { + if len(data) < int(HttpNeedBytesNum) { + return false + } + + _, ok := httpHeadBytes[string(data[:3])] + return ok +} + +// From https://github.com/hashicorp/yamux/blob/master/spec.md +var YamuxMatchFunc MatchFunc = func(data []byte) bool { + if len(data) < int(YamuxNeedBytesNum) { + return false + } + + if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 { + return true + } + return false +} diff --git a/utils/vhost/https.go b/utils/vhost/https.go index a6ef55d..a37c2e3 100644 --- a/utils/vhost/https.go +++ b/utils/vhost/https.go @@ -55,14 +55,17 @@ func readHandshake(rd io.Reader) (host string, err error) { data := pool.GetBuf(1024) origin := data defer pool.PutBuf(origin) - length, err := rd.Read(data) + + _, err = io.ReadFull(rd, data[:47]) + if err != nil { + return + } + + length, err := rd.Read(data[47:]) if err != nil { return } else { - if length < 47 { - err = fmt.Errorf("readHandshake: proto length[%d] is too short", length) - return - } + length += 47 } data = data[:length] if uint8(data[5]) != typeClientHello {