diff --git a/pkg/config/proxy.go b/pkg/config/proxy.go index 19add6e..de4953c 100644 --- a/pkg/config/proxy.go +++ b/pkg/config/proxy.go @@ -187,6 +187,8 @@ type TCPProxyConf struct { type TCPMuxProxyConf struct { BaseProxyConf `ini:",extends"` DomainConf `ini:",extends"` + HTTPUser string `ini:"http_user" json:"http_user,omitempty"` + HTTPPwd string `ini:"http_pwd" json:"http_pwd,omitempty"` RouteByHTTPUser string `ini:"route_by_http_user" json:"route_by_http_user"` Multiplexer string `ini:"multiplexer"` @@ -607,7 +609,10 @@ func (cfg *TCPMuxProxyConf) Compare(cmp ProxyConf) bool { return false } - if cfg.Multiplexer != cmpConf.Multiplexer || cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser { + if cfg.Multiplexer != cmpConf.Multiplexer || + cfg.HTTPUser != cmpConf.HTTPUser || + cfg.HTTPPwd != cmpConf.HTTPPwd || + cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser { return false } @@ -632,6 +637,8 @@ func (cfg *TCPMuxProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { cfg.CustomDomains = pMsg.CustomDomains cfg.SubDomain = pMsg.SubDomain cfg.Multiplexer = pMsg.Multiplexer + cfg.HTTPUser = pMsg.HTTPUser + cfg.HTTPPwd = pMsg.HTTPPwd cfg.RouteByHTTPUser = pMsg.RouteByHTTPUser } @@ -642,6 +649,8 @@ func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.CustomDomains = cfg.CustomDomains pMsg.SubDomain = cfg.SubDomain pMsg.Multiplexer = cfg.Multiplexer + pMsg.HTTPUser = cfg.HTTPUser + pMsg.HTTPPwd = cfg.HTTPPwd pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser } diff --git a/pkg/util/tcpmux/httpconnect.go b/pkg/util/tcpmux/httpconnect.go index e94ff4b..970897a 100644 --- a/pkg/util/tcpmux/httpconnect.go +++ b/pkg/util/tcpmux/httpconnect.go @@ -31,18 +31,21 @@ import ( type HTTPConnectTCPMuxer struct { *vhost.Muxer - passthrough bool - authRequired bool // Not supported until we really need this. + // If passthrough is set to true, the CONNECT request will be forwarded to the backend service. + // Otherwise, it will return an OK response to the client and forward the remaining content to the backend service. + passthrough bool } func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) { - ret := &HTTPConnectTCPMuxer{passthrough: passthrough, authRequired: false} - mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, nil, ret.sendConnectResponse, nil, timeout) + ret := &HTTPConnectTCPMuxer{passthrough: passthrough} + mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout) + mux.SetCheckAuthFunc(ret.auth). + SetSuccessHookFunc(ret.sendConnectResponse) ret.Muxer = mux return ret, err } -func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host string, httpUser string, err error) { +func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) { bufioReader := bufio.NewReader(rd) req, err := http.ReadRequest(bufioReader) @@ -58,7 +61,7 @@ func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host str host, _ = util.CanonicalHost(req.Host) proxyAuth := req.Header.Get("Proxy-Authorization") if proxyAuth != "" { - httpUser, _, _ = util.ParseBasicAuth(proxyAuth) + httpUser, httpPwd, _ = util.ParseBasicAuth(proxyAuth) } return } @@ -74,11 +77,26 @@ func (muxer *HTTPConnectTCPMuxer) sendConnectResponse(c net.Conn, reqInfo map[st return res.Write(c) } +func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, reqInfo map[string]string) (bool, error) { + reqUsername := reqInfo["HTTPUser"] + reqPassword := reqInfo["HTTPPwd"] + if username == reqUsername && password == reqPassword { + return true, nil + } + + resp := util.ProxyUnauthorizedResponse() + if resp.Body != nil { + defer resp.Body.Close() + } + _ = resp.Write(c) + return false, nil +} + func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) { reqInfoMap := make(map[string]string, 0) sc, rd := gnet.NewSharedConn(c) - host, httpUser, err := muxer.readHTTPConnectRequest(rd) + host, httpUser, httpPwd, err := muxer.readHTTPConnectRequest(rd) if err != nil { return nil, reqInfoMap, err } @@ -86,18 +104,11 @@ func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, reqInfoMap["Host"] = host reqInfoMap["Scheme"] = "tcp" reqInfoMap["HTTPUser"] = httpUser + reqInfoMap["HTTPPwd"] = httpPwd outConn := c if muxer.passthrough { outConn = sc - if muxer.authRequired && httpUser == "" { - resp := util.ProxyUnauthorizedResponse() - if resp.Body != nil { - defer resp.Body.Close() - } - _ = resp.Write(c) - outConn = c - } } return outConn, reqInfoMap, nil } diff --git a/pkg/util/vhost/https.go b/pkg/util/vhost/https.go index dd20739..2e94b4a 100644 --- a/pkg/util/vhost/https.go +++ b/pkg/util/vhost/https.go @@ -28,7 +28,10 @@ type HTTPSMuxer struct { } func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) { - mux, err := NewMuxer(listener, GetHTTPSHostname, nil, nil, nil, timeout) + mux, err := NewMuxer(listener, GetHTTPSHostname, timeout) + if err != nil { + return nil, err + } return &HTTPSMuxer{mux}, err } diff --git a/pkg/util/vhost/resource.go b/pkg/util/vhost/resource.go index 65bdbcb..e09edf2 100644 --- a/pkg/util/vhost/resource.go +++ b/pkg/util/vhost/resource.go @@ -85,17 +85,3 @@ func notFoundResponse() *http.Response { } return res } - -func noAuthResponse() *http.Response { - header := make(map[string][]string) - header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`} - res := &http.Response{ - Status: "401 Not authorized", - StatusCode: 401, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: header, - } - return res -} diff --git a/pkg/util/vhost/vhost.go b/pkg/util/vhost/vhost.go index 651d960..acb9cb1 100644 --- a/pkg/util/vhost/vhost.go +++ b/pkg/util/vhost/vhost.go @@ -43,43 +43,55 @@ type RequestRouteInfo struct { type ( muxFunc func(net.Conn) (net.Conn, map[string]string, error) - httpAuthFunc func(net.Conn, string, string, string) (bool, error) + authFunc func(conn net.Conn, username, password string, reqInfoMap map[string]string) (bool, error) hostRewriteFunc func(net.Conn, string) (net.Conn, error) - successFunc func(net.Conn, map[string]string) error + successHookFunc func(net.Conn, map[string]string) error ) -// Muxer is only used for https and tcpmux proxy. +// Muxer is a functional component used for https and tcpmux proxies. +// It accepts connections and extracts vhost information from the beginning of the connection data. +// It then routes the connection to its appropriate listener. type Muxer struct { - listener net.Listener - timeout time.Duration + listener net.Listener + timeout time.Duration + vhostFunc muxFunc - authFunc httpAuthFunc - successFunc successFunc - rewriteFunc hostRewriteFunc + checkAuth authFunc + successHook successHookFunc + rewriteHost hostRewriteFunc registryRouter *Routers } func NewMuxer( listener net.Listener, vhostFunc muxFunc, - authFunc httpAuthFunc, - successFunc successFunc, - rewriteFunc hostRewriteFunc, timeout time.Duration, ) (mux *Muxer, err error) { mux = &Muxer{ listener: listener, timeout: timeout, vhostFunc: vhostFunc, - authFunc: authFunc, - successFunc: successFunc, - rewriteFunc: rewriteFunc, registryRouter: NewRouters(), } go mux.run() return mux, nil } +func (v *Muxer) SetCheckAuthFunc(f authFunc) *Muxer { + v.checkAuth = f + return v +} + +func (v *Muxer) SetSuccessHookFunc(f successHookFunc) *Muxer { + v.successHook = f + return v +} + +func (v *Muxer) SetRewriteHostFunc(f hostRewriteFunc) *Muxer { + v.rewriteHost = f + return v +} + type ChooseEndpointFunc func() (string, error) type CreateConnFunc func(remoteAddr string) (net.Conn, error) @@ -101,7 +113,7 @@ type RouteConfig struct { CreateConnByEndpointFn CreateConnByEndpointFunc } -// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil +// listen for a new domain name, if rewriteHost is not empty and rewriteHost func is not nil, // then rewrite the host header to rewriteHost func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err error) { l = &Listener{ @@ -109,8 +121,8 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err location: cfg.Location, routeByHTTPUser: cfg.RouteByHTTPUser, rewriteHost: cfg.RewriteHost, - userName: cfg.Username, - passWord: cfg.Password, + username: cfg.Username, + password: cfg.Password, mux: v, accept: make(chan net.Conn), ctx: ctx, @@ -205,25 +217,20 @@ func (v *Muxer) handle(c net.Conn) { } xl := xlog.FromContextSafe(l.ctx) - if v.successFunc != nil { - if err := v.successFunc(c, reqInfoMap); err != nil { + if v.successHook != nil { + if err := v.successHook(c, reqInfoMap); err != nil { xl.Info("success func failure on vhost connection: %v", err) _ = c.Close() return } } - // if authFunc is exist and username/password is set + // if checkAuth func is exist and username/password is set // then verify user access - if l.mux.authFunc != nil && l.userName != "" && l.passWord != "" { - bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"]) - if !bAccess || err != nil { - xl.Debug("check http Authorization failed") - res := noAuthResponse() - if res.Body != nil { - defer res.Body.Close() - } - _ = res.Write(c) + if l.mux.checkAuth != nil && l.username != "" { + ok, err := l.mux.checkAuth(c, l.username, l.password, reqInfoMap) + if !ok || err != nil { + xl.Debug("auth failed for user: %s", l.username) _ = c.Close() return } @@ -249,8 +256,8 @@ type Listener struct { location string routeByHTTPUser string rewriteHost string - userName string - passWord string + username string + password string mux *Muxer // for closing Muxer accept chan net.Conn ctx context.Context @@ -263,11 +270,11 @@ func (l *Listener) Accept() (net.Conn, error) { return nil, fmt.Errorf("Listener closed") } - // if rewriteFunc is exist + // if rewriteHost func is exist // rewrite http requests with a modified host header // if l.rewriteHost is empty, nothing to do - if l.mux.rewriteFunc != nil { - sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost) + if l.mux.rewriteHost != nil { + sConn, err := l.mux.rewriteHost(conn, l.rewriteHost) if err != nil { xl.Warn("host header rewrite failed: %v", err) return nil, fmt.Errorf("host header rewrite failed") diff --git a/server/group/tcpmux.go b/server/group/tcpmux.go index 2da85ce..0d9790e 100644 --- a/server/group/tcpmux.go +++ b/server/group/tcpmux.go @@ -81,6 +81,8 @@ type TCPMuxGroup struct { groupKey string domain string routeByHTTPUser string + username string + password string acceptCh chan net.Conn tcpMuxLn net.Listener @@ -120,6 +122,8 @@ func (tmg *TCPMuxGroup) HTTPConnectListen( tmg.groupKey = groupKey tmg.domain = routeConfig.Domain tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser + tmg.username = routeConfig.Username + tmg.password = routeConfig.Password tmg.tcpMuxLn = tcpMuxLn tmg.lns = append(tmg.lns, ln) if tmg.acceptCh == nil { @@ -128,7 +132,10 @@ func (tmg *TCPMuxGroup) HTTPConnectListen( go tmg.worker() } else { // route config in the same group must be equal - if tmg.group != group || tmg.domain != routeConfig.Domain || tmg.routeByHTTPUser != routeConfig.RouteByHTTPUser { + if tmg.group != group || tmg.domain != routeConfig.Domain || + tmg.routeByHTTPUser != routeConfig.RouteByHTTPUser || + tmg.username != routeConfig.Username || + tmg.password != routeConfig.Password { return nil, ErrGroupParamsInvalid } if tmg.groupKey != groupKey { diff --git a/server/proxy/tcpmux.go b/server/proxy/tcpmux.go index 4b413c3..23e833b 100644 --- a/server/proxy/tcpmux.go +++ b/server/proxy/tcpmux.go @@ -32,12 +32,16 @@ type TCPMuxProxy struct { cfg *config.TCPMuxProxyConf } -func (pxy *TCPMuxProxy) httpConnectListen(domain, routeByHTTPUser string, addrs []string) ([]string, error) { +func (pxy *TCPMuxProxy) httpConnectListen( + domain, routeByHTTPUser, httpUser, httpPwd string, addrs []string) ([]string, error, +) { var l net.Listener var err error routeConfig := &vhost.RouteConfig{ Domain: domain, RouteByHTTPUser: routeByHTTPUser, + Username: httpUser, + Password: httpPwd, } if pxy.cfg.Group != "" { l, err = pxy.rc.TCPMuxGroupCtl.Listen(pxy.ctx, pxy.cfg.Multiplexer, pxy.cfg.Group, pxy.cfg.GroupKey, *routeConfig) @@ -60,14 +64,15 @@ func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) { continue } - addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, addrs) + addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPwd, addrs) if err != nil { return "", err } } if pxy.cfg.SubDomain != "" { - addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost, pxy.cfg.RouteByHTTPUser, addrs) + addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost, + pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPwd, addrs) if err != nil { return "", err } diff --git a/test/e2e/basic/tcpmux.go b/test/e2e/basic/tcpmux.go new file mode 100644 index 0000000..a1106b9 --- /dev/null +++ b/test/e2e/basic/tcpmux.go @@ -0,0 +1,218 @@ +package basic + +import ( + "bufio" + "fmt" + "net" + "net/http" + + "github.com/onsi/ginkgo/v2" + + "github.com/fatedier/frp/pkg/util/util" + "github.com/fatedier/frp/test/e2e/framework" + "github.com/fatedier/frp/test/e2e/framework/consts" + "github.com/fatedier/frp/test/e2e/mock/server/streamserver" + "github.com/fatedier/frp/test/e2e/pkg/request" + "github.com/fatedier/frp/test/e2e/pkg/rpc" +) + +var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() { + f := framework.NewDefaultFramework() + + getDefaultServerConf := func(httpconnectPort int) string { + conf := consts.DefaultServerConfig + ` + tcpmux_httpconnect_port = %d + ` + return fmt.Sprintf(conf, httpconnectPort) + } + newServer := func(port int, respContent string) *streamserver.Server { + return streamserver.New( + streamserver.TCP, + streamserver.WithBindPort(port), + streamserver.WithRespContent([]byte(respContent)), + ) + } + + proxyURLWithAuth := func(username, password string, port int) string { + if username == "" { + return fmt.Sprintf("http://127.0.0.1:%d", port) + } + return fmt.Sprintf("http://%s:%s@127.0.0.1:%d", username, password, port) + } + + ginkgo.It("Route by HTTP user", func() { + vhostPort := f.AllocPort() + serverConf := getDefaultServerConf(vhostPort) + + fooPort := f.AllocPort() + f.RunServer("", newServer(fooPort, "foo")) + + barPort := f.AllocPort() + f.RunServer("", newServer(barPort, "bar")) + + otherPort := f.AllocPort() + f.RunServer("", newServer(otherPort, "other")) + + clientConf := consts.DefaultClientConfig + clientConf += fmt.Sprintf(` + [foo] + type = tcpmux + multiplexer = httpconnect + local_port = %d + custom_domains = normal.example.com + route_by_http_user = user1 + + [bar] + type = tcpmux + multiplexer = httpconnect + local_port = %d + custom_domains = normal.example.com + route_by_http_user = user2 + + [catchAll] + type = tcpmux + multiplexer = httpconnect + local_port = %d + custom_domains = normal.example.com + `, fooPort, barPort, otherPort) + + f.RunProcesses([]string{serverConf}, []string{clientConf}) + + // user1 + framework.NewRequestExpect(f).Explain("user1"). + RequestModify(func(r *request.Request) { + r.Addr("normal.example.com").Proxy(proxyURLWithAuth("user1", "", vhostPort)) + }). + ExpectResp([]byte("foo")). + Ensure() + + // user2 + framework.NewRequestExpect(f).Explain("user2"). + RequestModify(func(r *request.Request) { + r.Addr("normal.example.com").Proxy(proxyURLWithAuth("user2", "", vhostPort)) + }). + ExpectResp([]byte("bar")). + Ensure() + + // other user + framework.NewRequestExpect(f).Explain("other user"). + RequestModify(func(r *request.Request) { + r.Addr("normal.example.com").Proxy(proxyURLWithAuth("user3", "", vhostPort)) + }). + ExpectResp([]byte("other")). + Ensure() + }) + + ginkgo.It("Proxy auth", func() { + vhostPort := f.AllocPort() + serverConf := getDefaultServerConf(vhostPort) + + fooPort := f.AllocPort() + f.RunServer("", newServer(fooPort, "foo")) + + clientConf := consts.DefaultClientConfig + clientConf += fmt.Sprintf(` + [test] + type = tcpmux + multiplexer = httpconnect + local_port = %d + custom_domains = normal.example.com + http_user = test + http_pwd = test + `, fooPort) + + f.RunProcesses([]string{serverConf}, []string{clientConf}) + + // not set auth header + framework.NewRequestExpect(f).Explain("no auth"). + RequestModify(func(r *request.Request) { + r.Addr("normal.example.com").Proxy(proxyURLWithAuth("", "", vhostPort)) + }). + ExpectError(true). + Ensure() + + // set incorrect auth header + framework.NewRequestExpect(f).Explain("incorrect auth"). + RequestModify(func(r *request.Request) { + r.Addr("normal.example.com").Proxy(proxyURLWithAuth("test", "invalid", vhostPort)) + }). + ExpectError(true). + Ensure() + + // set correct auth header + framework.NewRequestExpect(f).Explain("correct auth"). + RequestModify(func(r *request.Request) { + r.Addr("normal.example.com").Proxy(proxyURLWithAuth("test", "test", vhostPort)) + }). + ExpectResp([]byte("foo")). + Ensure() + }) + + ginkgo.It("TCPMux Passthrough", func() { + vhostPort := f.AllocPort() + serverConf := getDefaultServerConf(vhostPort) + serverConf += ` + tcpmux_passthrough = true + ` + + var ( + respErr error + connectRequestHost string + ) + newServer := func(port int) *streamserver.Server { + return streamserver.New( + streamserver.TCP, + streamserver.WithBindPort(port), + streamserver.WithCustomHandler(func(conn net.Conn) { + defer conn.Close() + + // read HTTP CONNECT request + bufioReader := bufio.NewReader(conn) + req, err := http.ReadRequest(bufioReader) + if err != nil { + respErr = err + return + } + connectRequestHost = req.Host + + // return ok response + res := util.OkResponse() + if res.Body != nil { + defer res.Body.Close() + } + _ = res.Write(conn) + + buf, err := rpc.ReadBytes(conn) + if err != nil { + respErr = err + return + } + _, _ = rpc.WriteBytes(conn, buf) + }), + ) + } + + localPort := f.AllocPort() + f.RunServer("", newServer(localPort)) + + clientConf := consts.DefaultClientConfig + clientConf += fmt.Sprintf(` + [test] + type = tcpmux + multiplexer = httpconnect + local_port = %d + custom_domains = normal.example.com + `, localPort) + + f.RunProcesses([]string{serverConf}, []string{clientConf}) + + framework.NewRequestExpect(f). + RequestModify(func(r *request.Request) { + r.Addr("normal.example.com").Proxy(proxyURLWithAuth("", "", vhostPort)).Body([]byte("frp")) + }). + ExpectResp([]byte("frp")). + Ensure() + framework.ExpectNoError(respErr) + framework.ExpectEqualValues(connectRequestHost, "normal.example.com") + }) +}) diff --git a/test/e2e/framework/process.go b/test/e2e/framework/process.go index 814ce04..6c0eeea 100644 --- a/test/e2e/framework/process.go +++ b/test/e2e/framework/process.go @@ -56,7 +56,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str ExpectNoError(err) time.Sleep(500 * time.Millisecond) } - time.Sleep(2 * time.Second) + time.Sleep(5 * time.Second) return currentServerProcesses, currentClientProcesses } diff --git a/test/e2e/pkg/request/request.go b/test/e2e/pkg/request/request.go index 96e714f..44bc0d0 100644 --- a/test/e2e/pkg/request/request.go +++ b/test/e2e/pkg/request/request.go @@ -145,7 +145,10 @@ func (r *Request) Do() (*Response, error) { err error ) - addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port)) + addr := r.addr + if r.port > 0 { + addr = net.JoinHostPort(r.addr, strconv.Itoa(r.port)) + } // for protocol http and https if r.protocol == "http" || r.protocol == "https" { return r.sendHTTPRequest(r.method, fmt.Sprintf("%s://%s%s", r.protocol, addr, r.path), diff --git a/test/e2e/pkg/rpc/rpc.go b/test/e2e/pkg/rpc/rpc.go index d602644..48b240d 100644 --- a/test/e2e/pkg/rpc/rpc.go +++ b/test/e2e/pkg/rpc/rpc.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" ) @@ -22,6 +23,9 @@ func ReadBytes(r io.Reader) ([]byte, error) { if err := binary.Read(r, binary.BigEndian, &length); err != nil { return nil, err } + if length < 0 || length > 10*1024*1024 { + return nil, fmt.Errorf("invalid length") + } buffer := make([]byte, length) n, err := io.ReadFull(r, buffer) if err != nil {