tcpmux: support authentication (#3345)

This commit is contained in:
fatedier 2023-03-07 19:53:32 +08:00 committed by GitHub
parent 54eb704650
commit 862b1642ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 324 additions and 71 deletions

View File

@ -187,6 +187,8 @@ type TCPProxyConf struct {
type TCPMuxProxyConf struct {
BaseProxyConf `ini:",extends"`
DomainConf `ini:",extends"`
HTTPUser string `ini:"http_user" json:"http_user,omitempty"`
HTTPPwd string `ini:"http_pwd" json:"http_pwd,omitempty"`
RouteByHTTPUser string `ini:"route_by_http_user" json:"route_by_http_user"`
Multiplexer string `ini:"multiplexer"`
@ -607,7 +609,10 @@ func (cfg *TCPMuxProxyConf) Compare(cmp ProxyConf) bool {
return false
}
if cfg.Multiplexer != cmpConf.Multiplexer || cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser {
if cfg.Multiplexer != cmpConf.Multiplexer ||
cfg.HTTPUser != cmpConf.HTTPUser ||
cfg.HTTPPwd != cmpConf.HTTPPwd ||
cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser {
return false
}
@ -632,6 +637,8 @@ func (cfg *TCPMuxProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) {
cfg.CustomDomains = pMsg.CustomDomains
cfg.SubDomain = pMsg.SubDomain
cfg.Multiplexer = pMsg.Multiplexer
cfg.HTTPUser = pMsg.HTTPUser
cfg.HTTPPwd = pMsg.HTTPPwd
cfg.RouteByHTTPUser = pMsg.RouteByHTTPUser
}
@ -642,6 +649,8 @@ func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.CustomDomains = cfg.CustomDomains
pMsg.SubDomain = cfg.SubDomain
pMsg.Multiplexer = cfg.Multiplexer
pMsg.HTTPUser = cfg.HTTPUser
pMsg.HTTPPwd = cfg.HTTPPwd
pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser
}

View File

@ -31,18 +31,21 @@ import (
type HTTPConnectTCPMuxer struct {
*vhost.Muxer
// If passthrough is set to true, the CONNECT request will be forwarded to the backend service.
// Otherwise, it will return an OK response to the client and forward the remaining content to the backend service.
passthrough bool
authRequired bool // Not supported until we really need this.
}
func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
ret := &HTTPConnectTCPMuxer{passthrough: passthrough, authRequired: false}
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, nil, ret.sendConnectResponse, nil, timeout)
ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
mux.SetCheckAuthFunc(ret.auth).
SetSuccessHookFunc(ret.sendConnectResponse)
ret.Muxer = mux
return ret, err
}
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host string, httpUser string, err error) {
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) {
bufioReader := bufio.NewReader(rd)
req, err := http.ReadRequest(bufioReader)
@ -58,7 +61,7 @@ func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host str
host, _ = util.CanonicalHost(req.Host)
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
httpUser, _, _ = util.ParseBasicAuth(proxyAuth)
httpUser, httpPwd, _ = util.ParseBasicAuth(proxyAuth)
}
return
}
@ -74,11 +77,26 @@ func (muxer *HTTPConnectTCPMuxer) sendConnectResponse(c net.Conn, reqInfo map[st
return res.Write(c)
}
func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, reqInfo map[string]string) (bool, error) {
reqUsername := reqInfo["HTTPUser"]
reqPassword := reqInfo["HTTPPwd"]
if username == reqUsername && password == reqPassword {
return true, nil
}
resp := util.ProxyUnauthorizedResponse()
if resp.Body != nil {
defer resp.Body.Close()
}
_ = resp.Write(c)
return false, nil
}
func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) {
reqInfoMap := make(map[string]string, 0)
sc, rd := gnet.NewSharedConn(c)
host, httpUser, err := muxer.readHTTPConnectRequest(rd)
host, httpUser, httpPwd, err := muxer.readHTTPConnectRequest(rd)
if err != nil {
return nil, reqInfoMap, err
}
@ -86,18 +104,11 @@ func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn,
reqInfoMap["Host"] = host
reqInfoMap["Scheme"] = "tcp"
reqInfoMap["HTTPUser"] = httpUser
reqInfoMap["HTTPPwd"] = httpPwd
outConn := c
if muxer.passthrough {
outConn = sc
if muxer.authRequired && httpUser == "" {
resp := util.ProxyUnauthorizedResponse()
if resp.Body != nil {
defer resp.Body.Close()
}
_ = resp.Write(c)
outConn = c
}
}
return outConn, reqInfoMap, nil
}

View File

@ -28,7 +28,10 @@ type HTTPSMuxer struct {
}
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
mux, err := NewMuxer(listener, GetHTTPSHostname, nil, nil, nil, timeout)
mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
if err != nil {
return nil, err
}
return &HTTPSMuxer{mux}, err
}

View File

@ -85,17 +85,3 @@ func notFoundResponse() *http.Response {
}
return res
}
func noAuthResponse() *http.Response {
header := make(map[string][]string)
header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`}
res := &http.Response{
Status: "401 Not authorized",
StatusCode: 401,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: header,
}
return res
}

View File

@ -43,43 +43,55 @@ type RequestRouteInfo struct {
type (
muxFunc func(net.Conn) (net.Conn, map[string]string, error)
httpAuthFunc func(net.Conn, string, string, string) (bool, error)
authFunc func(conn net.Conn, username, password string, reqInfoMap map[string]string) (bool, error)
hostRewriteFunc func(net.Conn, string) (net.Conn, error)
successFunc func(net.Conn, map[string]string) error
successHookFunc func(net.Conn, map[string]string) error
)
// Muxer is only used for https and tcpmux proxy.
// Muxer is a functional component used for https and tcpmux proxies.
// It accepts connections and extracts vhost information from the beginning of the connection data.
// It then routes the connection to its appropriate listener.
type Muxer struct {
listener net.Listener
timeout time.Duration
vhostFunc muxFunc
authFunc httpAuthFunc
successFunc successFunc
rewriteFunc hostRewriteFunc
checkAuth authFunc
successHook successHookFunc
rewriteHost hostRewriteFunc
registryRouter *Routers
}
func NewMuxer(
listener net.Listener,
vhostFunc muxFunc,
authFunc httpAuthFunc,
successFunc successFunc,
rewriteFunc hostRewriteFunc,
timeout time.Duration,
) (mux *Muxer, err error) {
mux = &Muxer{
listener: listener,
timeout: timeout,
vhostFunc: vhostFunc,
authFunc: authFunc,
successFunc: successFunc,
rewriteFunc: rewriteFunc,
registryRouter: NewRouters(),
}
go mux.run()
return mux, nil
}
func (v *Muxer) SetCheckAuthFunc(f authFunc) *Muxer {
v.checkAuth = f
return v
}
func (v *Muxer) SetSuccessHookFunc(f successHookFunc) *Muxer {
v.successHook = f
return v
}
func (v *Muxer) SetRewriteHostFunc(f hostRewriteFunc) *Muxer {
v.rewriteHost = f
return v
}
type ChooseEndpointFunc func() (string, error)
type CreateConnFunc func(remoteAddr string) (net.Conn, error)
@ -101,7 +113,7 @@ type RouteConfig struct {
CreateConnByEndpointFn CreateConnByEndpointFunc
}
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil
// listen for a new domain name, if rewriteHost is not empty and rewriteHost func is not nil,
// then rewrite the host header to rewriteHost
func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err error) {
l = &Listener{
@ -109,8 +121,8 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err
location: cfg.Location,
routeByHTTPUser: cfg.RouteByHTTPUser,
rewriteHost: cfg.RewriteHost,
userName: cfg.Username,
passWord: cfg.Password,
username: cfg.Username,
password: cfg.Password,
mux: v,
accept: make(chan net.Conn),
ctx: ctx,
@ -205,25 +217,20 @@ func (v *Muxer) handle(c net.Conn) {
}
xl := xlog.FromContextSafe(l.ctx)
if v.successFunc != nil {
if err := v.successFunc(c, reqInfoMap); err != nil {
if v.successHook != nil {
if err := v.successHook(c, reqInfoMap); err != nil {
xl.Info("success func failure on vhost connection: %v", err)
_ = c.Close()
return
}
}
// if authFunc is exist and username/password is set
// if checkAuth func is exist and username/password is set
// then verify user access
if l.mux.authFunc != nil && l.userName != "" && l.passWord != "" {
bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"])
if !bAccess || err != nil {
xl.Debug("check http Authorization failed")
res := noAuthResponse()
if res.Body != nil {
defer res.Body.Close()
}
_ = res.Write(c)
if l.mux.checkAuth != nil && l.username != "" {
ok, err := l.mux.checkAuth(c, l.username, l.password, reqInfoMap)
if !ok || err != nil {
xl.Debug("auth failed for user: %s", l.username)
_ = c.Close()
return
}
@ -249,8 +256,8 @@ type Listener struct {
location string
routeByHTTPUser string
rewriteHost string
userName string
passWord string
username string
password string
mux *Muxer // for closing Muxer
accept chan net.Conn
ctx context.Context
@ -263,11 +270,11 @@ func (l *Listener) Accept() (net.Conn, error) {
return nil, fmt.Errorf("Listener closed")
}
// if rewriteFunc is exist
// if rewriteHost func is exist
// rewrite http requests with a modified host header
// if l.rewriteHost is empty, nothing to do
if l.mux.rewriteFunc != nil {
sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
if l.mux.rewriteHost != nil {
sConn, err := l.mux.rewriteHost(conn, l.rewriteHost)
if err != nil {
xl.Warn("host header rewrite failed: %v", err)
return nil, fmt.Errorf("host header rewrite failed")

View File

@ -81,6 +81,8 @@ type TCPMuxGroup struct {
groupKey string
domain string
routeByHTTPUser string
username string
password string
acceptCh chan net.Conn
tcpMuxLn net.Listener
@ -120,6 +122,8 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(
tmg.groupKey = groupKey
tmg.domain = routeConfig.Domain
tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser
tmg.username = routeConfig.Username
tmg.password = routeConfig.Password
tmg.tcpMuxLn = tcpMuxLn
tmg.lns = append(tmg.lns, ln)
if tmg.acceptCh == nil {
@ -128,7 +132,10 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(
go tmg.worker()
} else {
// route config in the same group must be equal
if tmg.group != group || tmg.domain != routeConfig.Domain || tmg.routeByHTTPUser != routeConfig.RouteByHTTPUser {
if tmg.group != group || tmg.domain != routeConfig.Domain ||
tmg.routeByHTTPUser != routeConfig.RouteByHTTPUser ||
tmg.username != routeConfig.Username ||
tmg.password != routeConfig.Password {
return nil, ErrGroupParamsInvalid
}
if tmg.groupKey != groupKey {

View File

@ -32,12 +32,16 @@ type TCPMuxProxy struct {
cfg *config.TCPMuxProxyConf
}
func (pxy *TCPMuxProxy) httpConnectListen(domain, routeByHTTPUser string, addrs []string) ([]string, error) {
func (pxy *TCPMuxProxy) httpConnectListen(
domain, routeByHTTPUser, httpUser, httpPwd string, addrs []string) ([]string, error,
) {
var l net.Listener
var err error
routeConfig := &vhost.RouteConfig{
Domain: domain,
RouteByHTTPUser: routeByHTTPUser,
Username: httpUser,
Password: httpPwd,
}
if pxy.cfg.Group != "" {
l, err = pxy.rc.TCPMuxGroupCtl.Listen(pxy.ctx, pxy.cfg.Multiplexer, pxy.cfg.Group, pxy.cfg.GroupKey, *routeConfig)
@ -60,14 +64,15 @@ func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) {
continue
}
addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, addrs)
addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPwd, addrs)
if err != nil {
return "", err
}
}
if pxy.cfg.SubDomain != "" {
addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost, pxy.cfg.RouteByHTTPUser, addrs)
addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost,
pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPwd, addrs)
if err != nil {
return "", err
}

218
test/e2e/basic/tcpmux.go Normal file
View File

@ -0,0 +1,218 @@
package basic
import (
"bufio"
"fmt"
"net"
"net/http"
"github.com/onsi/ginkgo/v2"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/test/e2e/framework"
"github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/mock/server/streamserver"
"github.com/fatedier/frp/test/e2e/pkg/request"
"github.com/fatedier/frp/test/e2e/pkg/rpc"
)
var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
f := framework.NewDefaultFramework()
getDefaultServerConf := func(httpconnectPort int) string {
conf := consts.DefaultServerConfig + `
tcpmux_httpconnect_port = %d
`
return fmt.Sprintf(conf, httpconnectPort)
}
newServer := func(port int, respContent string) *streamserver.Server {
return streamserver.New(
streamserver.TCP,
streamserver.WithBindPort(port),
streamserver.WithRespContent([]byte(respContent)),
)
}
proxyURLWithAuth := func(username, password string, port int) string {
if username == "" {
return fmt.Sprintf("http://127.0.0.1:%d", port)
}
return fmt.Sprintf("http://%s:%s@127.0.0.1:%d", username, password, port)
}
ginkgo.It("Route by HTTP user", func() {
vhostPort := f.AllocPort()
serverConf := getDefaultServerConf(vhostPort)
fooPort := f.AllocPort()
f.RunServer("", newServer(fooPort, "foo"))
barPort := f.AllocPort()
f.RunServer("", newServer(barPort, "bar"))
otherPort := f.AllocPort()
f.RunServer("", newServer(otherPort, "other"))
clientConf := consts.DefaultClientConfig
clientConf += fmt.Sprintf(`
[foo]
type = tcpmux
multiplexer = httpconnect
local_port = %d
custom_domains = normal.example.com
route_by_http_user = user1
[bar]
type = tcpmux
multiplexer = httpconnect
local_port = %d
custom_domains = normal.example.com
route_by_http_user = user2
[catchAll]
type = tcpmux
multiplexer = httpconnect
local_port = %d
custom_domains = normal.example.com
`, fooPort, barPort, otherPort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
// user1
framework.NewRequestExpect(f).Explain("user1").
RequestModify(func(r *request.Request) {
r.Addr("normal.example.com").Proxy(proxyURLWithAuth("user1", "", vhostPort))
}).
ExpectResp([]byte("foo")).
Ensure()
// user2
framework.NewRequestExpect(f).Explain("user2").
RequestModify(func(r *request.Request) {
r.Addr("normal.example.com").Proxy(proxyURLWithAuth("user2", "", vhostPort))
}).
ExpectResp([]byte("bar")).
Ensure()
// other user
framework.NewRequestExpect(f).Explain("other user").
RequestModify(func(r *request.Request) {
r.Addr("normal.example.com").Proxy(proxyURLWithAuth("user3", "", vhostPort))
}).
ExpectResp([]byte("other")).
Ensure()
})
ginkgo.It("Proxy auth", func() {
vhostPort := f.AllocPort()
serverConf := getDefaultServerConf(vhostPort)
fooPort := f.AllocPort()
f.RunServer("", newServer(fooPort, "foo"))
clientConf := consts.DefaultClientConfig
clientConf += fmt.Sprintf(`
[test]
type = tcpmux
multiplexer = httpconnect
local_port = %d
custom_domains = normal.example.com
http_user = test
http_pwd = test
`, fooPort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
// not set auth header
framework.NewRequestExpect(f).Explain("no auth").
RequestModify(func(r *request.Request) {
r.Addr("normal.example.com").Proxy(proxyURLWithAuth("", "", vhostPort))
}).
ExpectError(true).
Ensure()
// set incorrect auth header
framework.NewRequestExpect(f).Explain("incorrect auth").
RequestModify(func(r *request.Request) {
r.Addr("normal.example.com").Proxy(proxyURLWithAuth("test", "invalid", vhostPort))
}).
ExpectError(true).
Ensure()
// set correct auth header
framework.NewRequestExpect(f).Explain("correct auth").
RequestModify(func(r *request.Request) {
r.Addr("normal.example.com").Proxy(proxyURLWithAuth("test", "test", vhostPort))
}).
ExpectResp([]byte("foo")).
Ensure()
})
ginkgo.It("TCPMux Passthrough", func() {
vhostPort := f.AllocPort()
serverConf := getDefaultServerConf(vhostPort)
serverConf += `
tcpmux_passthrough = true
`
var (
respErr error
connectRequestHost string
)
newServer := func(port int) *streamserver.Server {
return streamserver.New(
streamserver.TCP,
streamserver.WithBindPort(port),
streamserver.WithCustomHandler(func(conn net.Conn) {
defer conn.Close()
// read HTTP CONNECT request
bufioReader := bufio.NewReader(conn)
req, err := http.ReadRequest(bufioReader)
if err != nil {
respErr = err
return
}
connectRequestHost = req.Host
// return ok response
res := util.OkResponse()
if res.Body != nil {
defer res.Body.Close()
}
_ = res.Write(conn)
buf, err := rpc.ReadBytes(conn)
if err != nil {
respErr = err
return
}
_, _ = rpc.WriteBytes(conn, buf)
}),
)
}
localPort := f.AllocPort()
f.RunServer("", newServer(localPort))
clientConf := consts.DefaultClientConfig
clientConf += fmt.Sprintf(`
[test]
type = tcpmux
multiplexer = httpconnect
local_port = %d
custom_domains = normal.example.com
`, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
framework.NewRequestExpect(f).
RequestModify(func(r *request.Request) {
r.Addr("normal.example.com").Proxy(proxyURLWithAuth("", "", vhostPort)).Body([]byte("frp"))
}).
ExpectResp([]byte("frp")).
Ensure()
framework.ExpectNoError(respErr)
framework.ExpectEqualValues(connectRequestHost, "normal.example.com")
})
})

View File

@ -56,7 +56,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str
ExpectNoError(err)
time.Sleep(500 * time.Millisecond)
}
time.Sleep(2 * time.Second)
time.Sleep(5 * time.Second)
return currentServerProcesses, currentClientProcesses
}

View File

@ -145,7 +145,10 @@ func (r *Request) Do() (*Response, error) {
err error
)
addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port))
addr := r.addr
if r.port > 0 {
addr = net.JoinHostPort(r.addr, strconv.Itoa(r.port))
}
// for protocol http and https
if r.protocol == "http" || r.protocol == "https" {
return r.sendHTTPRequest(r.method, fmt.Sprintf("%s://%s%s", r.protocol, addr, r.path),

View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
)
@ -22,6 +23,9 @@ func ReadBytes(r io.Reader) ([]byte, error) {
if err := binary.Read(r, binary.BigEndian, &length); err != nil {
return nil, err
}
if length < 0 || length > 10*1024*1024 {
return nil, fmt.Errorf("invalid length")
}
buffer := make([]byte, length)
n, err := io.ReadFull(r, buffer)
if err != nil {