frps: vhost_http_port and vhost_https_port can be same with frps bind

port
This commit is contained in:
fatedier 2018-05-06 04:27:30 +08:00
parent f45283dbdb
commit 5db605ca02
7 changed files with 423 additions and 25 deletions

View File

@ -16,6 +16,7 @@ kcp_bind_port = 7000
# proxy_bind_addr = 127.0.0.1 # proxy_bind_addr = 127.0.0.1
# if you want to support virtual host, you must set the http port for listening (optional) # if you want to support virtual host, you must set the http port for listening (optional)
# Note: http port and https port can be same with bind_port
vhost_http_port = 80 vhost_http_port = 80
vhost_https_port = 443 vhost_https_port = 443

View File

@ -26,6 +26,7 @@ import (
"github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/models/msg"
"github.com/fatedier/frp/utils/log" "github.com/fatedier/frp/utils/log"
frpNet "github.com/fatedier/frp/utils/net" frpNet "github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/net/mux"
"github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/util"
"github.com/fatedier/frp/utils/version" "github.com/fatedier/frp/utils/version"
"github.com/fatedier/frp/utils/vhost" "github.com/fatedier/frp/utils/vhost"
@ -41,6 +42,9 @@ var ServerService *Service
// Server service. // Server service.
type Service struct { type Service struct {
// Dispatch connections to different handlers listen on same port.
muxer *mux.Mux
// Accept connections from client. // Accept connections from client.
listener frpNet.Listener listener frpNet.Listener
@ -88,12 +92,33 @@ func NewService() (svr *Service, err error) {
return return
} }
var (
httpMuxOn bool
httpsMuxOn bool
)
if cfg.BindAddr == cfg.ProxyBindAddr {
if cfg.BindPort == cfg.VhostHttpPort {
httpMuxOn = true
}
if cfg.BindPort == cfg.VhostHttpsPort {
httpsMuxOn = true
}
if httpMuxOn || httpsMuxOn {
svr.muxer = mux.NewMux()
}
}
// Listen for accepting connections from client. // Listen for accepting connections from client.
svr.listener, err = frpNet.ListenTcp(cfg.BindAddr, cfg.BindPort) ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.BindPort))
if err != nil { if err != nil {
err = fmt.Errorf("Create server listener error, %v", err) err = fmt.Errorf("Create server listener error, %v", err)
return return
} }
if svr.muxer != nil {
go svr.muxer.Serve(ln)
ln = svr.muxer.DefaultListener()
}
svr.listener = frpNet.WrapLogListener(ln)
log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort) log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort)
// Listen for accepting connections from client using kcp protocol. // Listen for accepting connections from client using kcp protocol.
@ -117,24 +142,33 @@ func NewService() (svr *Service, err error) {
Handler: rp, Handler: rp,
} }
var l net.Listener var l net.Listener
if httpMuxOn {
l = svr.muxer.ListenHttp(0)
} else {
l, err = net.Listen("tcp", address) l, err = net.Listen("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("Create vhost http listener error, %v", err) err = fmt.Errorf("Create vhost http listener error, %v", err)
return return
} }
}
go server.Serve(l) go server.Serve(l)
log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
} }
// Create https vhost muxer. // Create https vhost muxer.
if cfg.VhostHttpsPort > 0 { if cfg.VhostHttpsPort > 0 {
var l frpNet.Listener var l net.Listener
l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpsPort) if httpsMuxOn {
l = svr.muxer.ListenHttps(0)
} else {
l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort))
if err != nil { if err != nil {
err = fmt.Errorf("Create vhost https listener error, %v", err) err = fmt.Errorf("Create server listener error, %v", err)
return return
} }
svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(l, 30*time.Second) }
svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(frpNet.WrapLogListener(l), 30*time.Second)
if err != nil { if err != nil {
err = fmt.Errorf("Create vhost httpsMuxer error, %v", err) err = fmt.Errorf("Create vhost httpsMuxer error, %v", err)
return return

View File

@ -20,7 +20,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -136,7 +135,6 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
type SharedConn struct { type SharedConn struct {
Conn Conn
sync.Mutex
buf *bytes.Buffer buf *bytes.Buffer
} }
@ -149,22 +147,24 @@ func NewShareConn(conn Conn) (*SharedConn, io.Reader) {
return sc, io.TeeReader(conn, sc.buf) return sc, io.TeeReader(conn, sc.buf)
} }
func NewShareConnSize(conn Conn, bufSize int) (*SharedConn, io.Reader) {
sc := &SharedConn{
Conn: conn,
buf: bytes.NewBuffer(make([]byte, 0, bufSize)),
}
return sc, io.TeeReader(conn, sc.buf)
}
// Not thread safety.
func (sc *SharedConn) Read(p []byte) (n int, err error) { func (sc *SharedConn) Read(p []byte) (n int, err error) {
sc.Lock()
if sc.buf == nil { if sc.buf == nil {
sc.Unlock()
return sc.Conn.Read(p) return sc.Conn.Read(p)
} }
sc.Unlock()
n, err = sc.buf.Read(p) n, err = sc.buf.Read(p)
if err == io.EOF { if err == io.EOF {
sc.Lock()
sc.buf = nil sc.buf = nil
sc.Unlock()
var n2 int var n2 int
n2, err = sc.Conn.Read(p[n:]) n2, err = sc.Conn.Read(p[n:])
n += n2 n += n2
} }
return return

210
utils/net/mux/mux.go Normal file
View File

@ -0,0 +1,210 @@
package mux
import (
"fmt"
"io"
"net"
"sort"
"sync"
"time"
"github.com/fatedier/frp/utils/errors"
frpNet "github.com/fatedier/frp/utils/net"
)
const (
// DefaultTimeout is the default length of time to wait for bytes we need.
DefaultTimeout = 10 * time.Second
)
type Mux struct {
ln net.Listener
defaultLn *listener
lns []*listener
maxNeedBytesNum uint32
mu sync.RWMutex
}
func NewMux() (mux *Mux) {
mux = &Mux{
lns: make([]*listener, 0),
}
return
}
func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
ln := &listener{
c: make(chan net.Conn),
mux: mux,
needBytesNum: needBytesNum,
matchFn: fn,
}
mux.mu.Lock()
defer mux.mu.Unlock()
if needBytesNum > mux.maxNeedBytesNum {
mux.maxNeedBytesNum = needBytesNum
}
newlns := append(mux.copyLns(), ln)
sort.Slice(newlns, func(i, j int) bool {
return newlns[i].needBytesNum < newlns[j].needBytesNum
})
mux.lns = newlns
return ln
}
func (mux *Mux) ListenHttp(priority int) net.Listener {
return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
}
func (mux *Mux) ListenHttps(priority int) net.Listener {
return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
}
func (mux *Mux) DefaultListener() net.Listener {
mux.mu.Lock()
defer mux.mu.Unlock()
if mux.defaultLn == nil {
mux.defaultLn = &listener{
c: make(chan net.Conn),
mux: mux,
}
}
return mux.defaultLn
}
func (mux *Mux) release(ln *listener) bool {
result := false
mux.mu.Lock()
defer mux.mu.Unlock()
lns := mux.copyLns()
for i, l := range lns {
if l == ln {
lns = append(lns[:i], lns[i+1:]...)
result = true
}
}
mux.lns = lns
return result
}
func (mux *Mux) copyLns() []*listener {
lns := make([]*listener, 0, len(mux.lns))
for _, l := range mux.lns {
lns = append(lns, l)
}
return lns
}
// Serve handles connections from ln and multiplexes then across registered listeners.
func (mux *Mux) Serve(ln net.Listener) error {
mux.mu.Lock()
mux.ln = ln
mux.mu.Unlock()
for {
// Wait for the next connection.
// If it returns a temporary error then simply retry.
// If it returns any other error then exit immediately.
conn, err := ln.Accept()
if err, ok := err.(interface {
Temporary() bool
}); ok && err.Temporary() {
continue
}
if err != nil {
return err
}
go mux.handleConn(conn)
}
}
func (mux *Mux) handleConn(conn net.Conn) {
mux.mu.RLock()
maxNeedBytesNum := mux.maxNeedBytesNum
lns := mux.lns
defaultLn := mux.defaultLn
mux.mu.RUnlock()
shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum))
data := make([]byte, maxNeedBytesNum)
conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
_, err := io.ReadFull(rd, data)
if err != nil {
conn.Close()
return
}
conn.SetReadDeadline(time.Time{})
for _, ln := range lns {
if match := ln.matchFn(data); match {
err = errors.PanicToError(func() {
ln.c <- shareConn
})
if err != nil {
conn.Close()
}
return
}
}
// No match listeners
if defaultLn != nil {
err = errors.PanicToError(func() {
defaultLn.c <- shareConn
})
if err != nil {
conn.Close()
}
return
}
// No listeners for this connection, close it.
conn.Close()
return
}
type listener struct {
mux *Mux
needBytesNum uint32
matchFn MatchFunc
c chan net.Conn
mu sync.RWMutex
}
// Accept waits for and returns the next connection to the listener.
func (ln *listener) Accept() (net.Conn, error) {
conn, ok := <-ln.c
if !ok {
return nil, fmt.Errorf("network connection closed")
}
return conn, nil
}
// Close removes this listener from the parent mux and closes the channel.
func (ln *listener) Close() error {
if ok := ln.mux.release(ln); ok {
// Close done to signal to any RLock holders to release their lock.
close(ln.c)
}
return nil
}
func (ln *listener) Addr() net.Addr {
if ln.mux == nil {
return nil
}
ln.mux.mu.RLock()
defer ln.mux.mu.RUnlock()
if ln.mux.ln == nil {
return nil
}
return ln.mux.ln.Addr()
}

95
utils/net/mux/mux_test.go Normal file
View File

@ -0,0 +1,95 @@
package mux
import (
"bufio"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func runHttpSvr(ln net.Listener) *httptest.Server {
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("http service"))
}))
svr.Listener = ln
svr.Start()
return svr
}
func runHttpsSvr(ln net.Listener) *httptest.Server {
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("https service"))
}))
svr.Listener = ln
svr.StartTLS()
return svr
}
func runEchoSvr(ln net.Listener) {
go func() {
for {
conn, err := ln.Accept()
if err != nil {
return
}
rd := bufio.NewReader(conn)
data, err := rd.ReadString('\n')
if err != nil {
return
}
conn.Write([]byte(data))
conn.Close()
}
}()
}
func TestMux(t *testing.T) {
assert := assert.New(t)
ln, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(err)
mux := NewMux()
httpLn := mux.ListenHttp(0)
httpsLn := mux.ListenHttps(0)
defaultLn := mux.DefaultListener()
go mux.Serve(ln)
time.Sleep(100 * time.Millisecond)
httpSvr := runHttpSvr(httpLn)
defer httpSvr.Close()
httpsSvr := runHttpsSvr(httpsLn)
defer httpsSvr.Close()
runEchoSvr(defaultLn)
defer ln.Close()
// test http service
resp, err := http.Get(httpSvr.URL)
assert.NoError(err)
data, err := ioutil.ReadAll(resp.Body)
assert.NoError(err)
assert.Equal("http service", string(data))
// test https service
client := httpsSvr.Client()
resp, err = client.Get(httpsSvr.URL)
assert.NoError(err)
data, err = ioutil.ReadAll(resp.Body)
assert.NoError(err)
assert.Equal("https service", string(data))
// test echo service
conn, err := net.Dial("tcp", ln.Addr().String())
assert.NoError(err)
_, err = conn.Write([]byte("test echo\n"))
assert.NoError(err)
data = make([]byte, 1024)
n, err := conn.Read(data)
assert.NoError(err)
assert.Equal("test echo\n", string(data[:n]))
}

55
utils/net/mux/rule.go Normal file
View File

@ -0,0 +1,55 @@
package mux
type MatchFunc func(data []byte) (match bool)
var (
HttpsNeedBytesNum uint32 = 1
HttpNeedBytesNum uint32 = 3
YamuxNeedBytesNum uint32 = 2
)
var HttpsMatchFunc MatchFunc = func(data []byte) bool {
if len(data) < int(HttpsNeedBytesNum) {
return false
}
if data[0] == 0x16 {
return true
} else {
return false
}
}
// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
var httpHeadBytes = map[string]struct{}{
"GET": struct{}{},
"HEA": struct{}{},
"POS": struct{}{},
"PUT": struct{}{},
"DEL": struct{}{},
"CON": struct{}{},
"OPT": struct{}{},
"TRA": struct{}{},
"PAT": struct{}{},
}
var HttpMatchFunc MatchFunc = func(data []byte) bool {
if len(data) < int(HttpNeedBytesNum) {
return false
}
_, ok := httpHeadBytes[string(data[:3])]
return ok
}
// From https://github.com/hashicorp/yamux/blob/master/spec.md
var YamuxMatchFunc MatchFunc = func(data []byte) bool {
if len(data) < int(YamuxNeedBytesNum) {
return false
}
if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 {
return true
}
return false
}

View File

@ -55,14 +55,17 @@ func readHandshake(rd io.Reader) (host string, err error) {
data := pool.GetBuf(1024) data := pool.GetBuf(1024)
origin := data origin := data
defer pool.PutBuf(origin) defer pool.PutBuf(origin)
length, err := rd.Read(data)
_, err = io.ReadFull(rd, data[:47])
if err != nil {
return
}
length, err := rd.Read(data[47:])
if err != nil { if err != nil {
return return
} else { } else {
if length < 47 { length += 47
err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
return
}
} }
data = data[:length] data = data[:length]
if uint8(data[5]) != typeClientHello { if uint8(data[5]) != typeClientHello {