mirror of
https://gitee.com/IrisVega/frp.git
synced 2024-11-01 22:31:29 +08:00
Merge pull request #746 from fatedier/mux
http port and https port can be same with frps bind_port
This commit is contained in:
commit
178efd67f1
@ -16,6 +16,7 @@ kcp_bind_port = 7000
|
||||
# proxy_bind_addr = 127.0.0.1
|
||||
|
||||
# if you want to support virtual host, you must set the http port for listening (optional)
|
||||
# Note: http port and https port can be same with bind_port
|
||||
vhost_http_port = 80
|
||||
vhost_https_port = 443
|
||||
|
||||
|
@ -66,14 +66,21 @@ func (hp *HttpProxy) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) {
|
||||
wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
|
||||
|
||||
sc, rd := frpNet.NewShareConn(wrapConn)
|
||||
request, err := http.ReadRequest(bufio.NewReader(rd))
|
||||
firstBytes := make([]byte, 7)
|
||||
_, err := rd.Read(firstBytes)
|
||||
if err != nil {
|
||||
wrapConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if request.Method == http.MethodConnect {
|
||||
hp.handleConnectReq(request, frpIo.WrapReadWriteCloser(rd, wrapConn, nil))
|
||||
if strings.ToUpper(string(firstBytes)) == "CONNECT" {
|
||||
bufRd := bufio.NewReader(sc)
|
||||
request, err := http.ReadRequest(bufRd)
|
||||
if err != nil {
|
||||
wrapConn.Close()
|
||||
return
|
||||
}
|
||||
hp.handleConnectReq(request, frpIo.WrapReadWriteCloser(bufRd, wrapConn, wrapConn.Close))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
"github.com/fatedier/frp/models/msg"
|
||||
"github.com/fatedier/frp/utils/log"
|
||||
frpNet "github.com/fatedier/frp/utils/net"
|
||||
"github.com/fatedier/frp/utils/net/mux"
|
||||
"github.com/fatedier/frp/utils/util"
|
||||
"github.com/fatedier/frp/utils/version"
|
||||
"github.com/fatedier/frp/utils/vhost"
|
||||
@ -41,6 +42,9 @@ var ServerService *Service
|
||||
|
||||
// Server service.
|
||||
type Service struct {
|
||||
// Dispatch connections to different handlers listen on same port.
|
||||
muxer *mux.Mux
|
||||
|
||||
// Accept connections from client.
|
||||
listener frpNet.Listener
|
||||
|
||||
@ -88,12 +92,33 @@ func NewService() (svr *Service, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
httpMuxOn bool
|
||||
httpsMuxOn bool
|
||||
)
|
||||
if cfg.BindAddr == cfg.ProxyBindAddr {
|
||||
if cfg.BindPort == cfg.VhostHttpPort {
|
||||
httpMuxOn = true
|
||||
}
|
||||
if cfg.BindPort == cfg.VhostHttpsPort {
|
||||
httpsMuxOn = true
|
||||
}
|
||||
if httpMuxOn || httpsMuxOn {
|
||||
svr.muxer = mux.NewMux()
|
||||
}
|
||||
}
|
||||
|
||||
// Listen for accepting connections from client.
|
||||
svr.listener, err = frpNet.ListenTcp(cfg.BindAddr, cfg.BindPort)
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.BindPort))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Create server listener error, %v", err)
|
||||
return
|
||||
}
|
||||
if svr.muxer != nil {
|
||||
go svr.muxer.Serve(ln)
|
||||
ln = svr.muxer.DefaultListener()
|
||||
}
|
||||
svr.listener = frpNet.WrapLogListener(ln)
|
||||
log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort)
|
||||
|
||||
// Listen for accepting connections from client using kcp protocol.
|
||||
@ -117,24 +142,33 @@ func NewService() (svr *Service, err error) {
|
||||
Handler: rp,
|
||||
}
|
||||
var l net.Listener
|
||||
if httpMuxOn {
|
||||
l = svr.muxer.ListenHttp(0)
|
||||
} else {
|
||||
l, err = net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Create vhost http listener error, %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
go server.Serve(l)
|
||||
log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
|
||||
}
|
||||
|
||||
// Create https vhost muxer.
|
||||
if cfg.VhostHttpsPort > 0 {
|
||||
var l frpNet.Listener
|
||||
l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpsPort)
|
||||
var l net.Listener
|
||||
if httpsMuxOn {
|
||||
l = svr.muxer.ListenHttps(0)
|
||||
} else {
|
||||
l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Create vhost https listener error, %v", err)
|
||||
err = fmt.Errorf("Create server listener error, %v", err)
|
||||
return
|
||||
}
|
||||
svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(l, 30*time.Second)
|
||||
}
|
||||
|
||||
svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(frpNet.WrapLogListener(l), 30*time.Second)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Create vhost httpsMuxer error, %v", err)
|
||||
return
|
||||
|
@ -20,7 +20,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -136,7 +135,6 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
|
||||
|
||||
type SharedConn struct {
|
||||
Conn
|
||||
sync.Mutex
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
@ -149,22 +147,24 @@ func NewShareConn(conn Conn) (*SharedConn, io.Reader) {
|
||||
return sc, io.TeeReader(conn, sc.buf)
|
||||
}
|
||||
|
||||
func NewShareConnSize(conn Conn, bufSize int) (*SharedConn, io.Reader) {
|
||||
sc := &SharedConn{
|
||||
Conn: conn,
|
||||
buf: bytes.NewBuffer(make([]byte, 0, bufSize)),
|
||||
}
|
||||
return sc, io.TeeReader(conn, sc.buf)
|
||||
}
|
||||
|
||||
// Not thread safety.
|
||||
func (sc *SharedConn) Read(p []byte) (n int, err error) {
|
||||
sc.Lock()
|
||||
if sc.buf == nil {
|
||||
sc.Unlock()
|
||||
return sc.Conn.Read(p)
|
||||
}
|
||||
sc.Unlock()
|
||||
n, err = sc.buf.Read(p)
|
||||
|
||||
if err == io.EOF {
|
||||
sc.Lock()
|
||||
sc.buf = nil
|
||||
sc.Unlock()
|
||||
var n2 int
|
||||
n2, err = sc.Conn.Read(p[n:])
|
||||
|
||||
n += n2
|
||||
}
|
||||
return
|
||||
|
210
utils/net/mux/mux.go
Normal file
210
utils/net/mux/mux.go
Normal file
@ -0,0 +1,210 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/utils/errors"
|
||||
frpNet "github.com/fatedier/frp/utils/net"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultTimeout is the default length of time to wait for bytes we need.
|
||||
DefaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Mux struct {
|
||||
ln net.Listener
|
||||
|
||||
defaultLn *listener
|
||||
lns []*listener
|
||||
maxNeedBytesNum uint32
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMux() (mux *Mux) {
|
||||
mux = &Mux{
|
||||
lns: make([]*listener, 0),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
|
||||
ln := &listener{
|
||||
c: make(chan net.Conn),
|
||||
mux: mux,
|
||||
needBytesNum: needBytesNum,
|
||||
matchFn: fn,
|
||||
}
|
||||
|
||||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
if needBytesNum > mux.maxNeedBytesNum {
|
||||
mux.maxNeedBytesNum = needBytesNum
|
||||
}
|
||||
|
||||
newlns := append(mux.copyLns(), ln)
|
||||
sort.Slice(newlns, func(i, j int) bool {
|
||||
return newlns[i].needBytesNum < newlns[j].needBytesNum
|
||||
})
|
||||
mux.lns = newlns
|
||||
return ln
|
||||
}
|
||||
|
||||
func (mux *Mux) ListenHttp(priority int) net.Listener {
|
||||
return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
|
||||
}
|
||||
|
||||
func (mux *Mux) ListenHttps(priority int) net.Listener {
|
||||
return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
|
||||
}
|
||||
|
||||
func (mux *Mux) DefaultListener() net.Listener {
|
||||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
if mux.defaultLn == nil {
|
||||
mux.defaultLn = &listener{
|
||||
c: make(chan net.Conn),
|
||||
mux: mux,
|
||||
}
|
||||
}
|
||||
return mux.defaultLn
|
||||
}
|
||||
|
||||
func (mux *Mux) release(ln *listener) bool {
|
||||
result := false
|
||||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
lns := mux.copyLns()
|
||||
|
||||
for i, l := range lns {
|
||||
if l == ln {
|
||||
lns = append(lns[:i], lns[i+1:]...)
|
||||
result = true
|
||||
}
|
||||
}
|
||||
mux.lns = lns
|
||||
return result
|
||||
}
|
||||
|
||||
func (mux *Mux) copyLns() []*listener {
|
||||
lns := make([]*listener, 0, len(mux.lns))
|
||||
for _, l := range mux.lns {
|
||||
lns = append(lns, l)
|
||||
}
|
||||
return lns
|
||||
}
|
||||
|
||||
// Serve handles connections from ln and multiplexes then across registered listeners.
|
||||
func (mux *Mux) Serve(ln net.Listener) error {
|
||||
mux.mu.Lock()
|
||||
mux.ln = ln
|
||||
mux.mu.Unlock()
|
||||
for {
|
||||
// Wait for the next connection.
|
||||
// If it returns a temporary error then simply retry.
|
||||
// If it returns any other error then exit immediately.
|
||||
conn, err := ln.Accept()
|
||||
if err, ok := err.(interface {
|
||||
Temporary() bool
|
||||
}); ok && err.Temporary() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go mux.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (mux *Mux) handleConn(conn net.Conn) {
|
||||
mux.mu.RLock()
|
||||
maxNeedBytesNum := mux.maxNeedBytesNum
|
||||
lns := mux.lns
|
||||
defaultLn := mux.defaultLn
|
||||
mux.mu.RUnlock()
|
||||
|
||||
shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum))
|
||||
data := make([]byte, maxNeedBytesNum)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
|
||||
_, err := io.ReadFull(rd, data)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
for _, ln := range lns {
|
||||
if match := ln.matchFn(data); match {
|
||||
err = errors.PanicToError(func() {
|
||||
ln.c <- shareConn
|
||||
})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// No match listeners
|
||||
if defaultLn != nil {
|
||||
err = errors.PanicToError(func() {
|
||||
defaultLn.c <- shareConn
|
||||
})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No listeners for this connection, close it.
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
mux *Mux
|
||||
|
||||
needBytesNum uint32
|
||||
matchFn MatchFunc
|
||||
|
||||
c chan net.Conn
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (ln *listener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-ln.c
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("network connection closed")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close removes this listener from the parent mux and closes the channel.
|
||||
func (ln *listener) Close() error {
|
||||
if ok := ln.mux.release(ln); ok {
|
||||
// Close done to signal to any RLock holders to release their lock.
|
||||
close(ln.c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *listener) Addr() net.Addr {
|
||||
if ln.mux == nil {
|
||||
return nil
|
||||
}
|
||||
ln.mux.mu.RLock()
|
||||
defer ln.mux.mu.RUnlock()
|
||||
if ln.mux.ln == nil {
|
||||
return nil
|
||||
}
|
||||
return ln.mux.ln.Addr()
|
||||
}
|
95
utils/net/mux/mux_test.go
Normal file
95
utils/net/mux/mux_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func runHttpSvr(ln net.Listener) *httptest.Server {
|
||||
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("http service"))
|
||||
}))
|
||||
svr.Listener = ln
|
||||
svr.Start()
|
||||
return svr
|
||||
}
|
||||
|
||||
func runHttpsSvr(ln net.Listener) *httptest.Server {
|
||||
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("https service"))
|
||||
}))
|
||||
svr.Listener = ln
|
||||
svr.StartTLS()
|
||||
return svr
|
||||
}
|
||||
|
||||
func runEchoSvr(ln net.Listener) {
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rd := bufio.NewReader(conn)
|
||||
data, err := rd.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.Write([]byte(data))
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestMux(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
assert.NoError(err)
|
||||
|
||||
mux := NewMux()
|
||||
httpLn := mux.ListenHttp(0)
|
||||
httpsLn := mux.ListenHttps(0)
|
||||
defaultLn := mux.DefaultListener()
|
||||
go mux.Serve(ln)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
httpSvr := runHttpSvr(httpLn)
|
||||
defer httpSvr.Close()
|
||||
httpsSvr := runHttpsSvr(httpsLn)
|
||||
defer httpsSvr.Close()
|
||||
runEchoSvr(defaultLn)
|
||||
defer ln.Close()
|
||||
|
||||
// test http service
|
||||
resp, err := http.Get(httpSvr.URL)
|
||||
assert.NoError(err)
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(err)
|
||||
assert.Equal("http service", string(data))
|
||||
|
||||
// test https service
|
||||
client := httpsSvr.Client()
|
||||
resp, err = client.Get(httpsSvr.URL)
|
||||
assert.NoError(err)
|
||||
data, err = ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(err)
|
||||
assert.Equal("https service", string(data))
|
||||
|
||||
// test echo service
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(err)
|
||||
_, err = conn.Write([]byte("test echo\n"))
|
||||
assert.NoError(err)
|
||||
data = make([]byte, 1024)
|
||||
n, err := conn.Read(data)
|
||||
assert.NoError(err)
|
||||
assert.Equal("test echo\n", string(data[:n]))
|
||||
}
|
55
utils/net/mux/rule.go
Normal file
55
utils/net/mux/rule.go
Normal file
@ -0,0 +1,55 @@
|
||||
package mux
|
||||
|
||||
type MatchFunc func(data []byte) (match bool)
|
||||
|
||||
var (
|
||||
HttpsNeedBytesNum uint32 = 1
|
||||
HttpNeedBytesNum uint32 = 3
|
||||
YamuxNeedBytesNum uint32 = 2
|
||||
)
|
||||
|
||||
var HttpsMatchFunc MatchFunc = func(data []byte) bool {
|
||||
if len(data) < int(HttpsNeedBytesNum) {
|
||||
return false
|
||||
}
|
||||
|
||||
if data[0] == 0x16 {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
|
||||
var httpHeadBytes = map[string]struct{}{
|
||||
"GET": struct{}{},
|
||||
"HEA": struct{}{},
|
||||
"POS": struct{}{},
|
||||
"PUT": struct{}{},
|
||||
"DEL": struct{}{},
|
||||
"CON": struct{}{},
|
||||
"OPT": struct{}{},
|
||||
"TRA": struct{}{},
|
||||
"PAT": struct{}{},
|
||||
}
|
||||
|
||||
var HttpMatchFunc MatchFunc = func(data []byte) bool {
|
||||
if len(data) < int(HttpNeedBytesNum) {
|
||||
return false
|
||||
}
|
||||
|
||||
_, ok := httpHeadBytes[string(data[:3])]
|
||||
return ok
|
||||
}
|
||||
|
||||
// From https://github.com/hashicorp/yamux/blob/master/spec.md
|
||||
var YamuxMatchFunc MatchFunc = func(data []byte) bool {
|
||||
if len(data) < int(YamuxNeedBytesNum) {
|
||||
return false
|
||||
}
|
||||
|
||||
if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
@ -55,14 +55,17 @@ func readHandshake(rd io.Reader) (host string, err error) {
|
||||
data := pool.GetBuf(1024)
|
||||
origin := data
|
||||
defer pool.PutBuf(origin)
|
||||
length, err := rd.Read(data)
|
||||
|
||||
_, err = io.ReadFull(rd, data[:47])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
length, err := rd.Read(data[47:])
|
||||
if err != nil {
|
||||
return
|
||||
} else {
|
||||
if length < 47 {
|
||||
err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
|
||||
return
|
||||
}
|
||||
length += 47
|
||||
}
|
||||
data = data[:length]
|
||||
if uint8(data[5]) != typeClientHello {
|
||||
|
Loading…
Reference in New Issue
Block a user