type http/tcpmux proxy support route_by_http_user, tcpmux support passthourgh mode (#2932)

This commit is contained in:
fatedier 2022-05-26 23:57:30 +08:00 committed by GitHub
parent bd89eaba2f
commit 4af85da0c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 606 additions and 283 deletions

View File

@ -1,7 +1 @@
### New
* Added new parameter `config_dir` in frpc to run multiple client instances in one process.
### Fix
* Equal sign in environment variables causes parsing error.

View File

@ -216,6 +216,8 @@ subdomain = web01
custom_domains = web01.yourdomain.com custom_domains = web01.yourdomain.com
# locations is only available for http type # locations is only available for http type
locations = /,/pic locations = /,/pic
# route requests to this service if http basic auto user is abc
# route_by_http_user = abc
host_header_rewrite = example.com host_header_rewrite = example.com
# params with prefix "header_" will be used to update http request headers # params with prefix "header_" will be used to update http request headers
header_X-From-Where = frp header_X-From-Where = frp
@ -348,3 +350,4 @@ multiplexer = httpconnect
local_ip = 127.0.0.1 local_ip = 127.0.0.1
local_port = 10701 local_port = 10701
custom_domains = tunnel1 custom_domains = tunnel1
# route_by_http_user = user1

View File

@ -30,6 +30,9 @@ vhost_https_port = 443
# HTTP CONNECT requests. By default, this value is 0. # HTTP CONNECT requests. By default, this value is 0.
# tcpmux_httpconnect_port = 1337 # tcpmux_httpconnect_port = 1337
# If tcpmux_passthrough is true, frps won't do any update on traffic.
# tcpmux_passthrough = false
# set dashboard_addr and dashboard_port to view dashboard of frps # set dashboard_addr and dashboard_port to view dashboard of frps
# dashboard_addr's default value is same with bind_addr # dashboard_addr's default value is same with bind_addr
# dashboard is available only if dashboard_port is set # dashboard is available only if dashboard_port is set

View File

@ -162,6 +162,7 @@ type HTTPProxyConf struct {
HTTPPwd string `ini:"http_pwd" json:"http_pwd"` HTTPPwd string `ini:"http_pwd" json:"http_pwd"`
HostHeaderRewrite string `ini:"host_header_rewrite" json:"host_header_rewrite"` HostHeaderRewrite string `ini:"host_header_rewrite" json:"host_header_rewrite"`
Headers map[string]string `ini:"-" json:"headers"` Headers map[string]string `ini:"-" json:"headers"`
RouteByHTTPUser string `ini:"route_by_http_user" json:"route_by_http_user"`
} }
// HTTPS // HTTPS
@ -178,8 +179,9 @@ type TCPProxyConf struct {
// TCPMux // TCPMux
type TCPMuxProxyConf struct { type TCPMuxProxyConf struct {
BaseProxyConf `ini:",extends"` BaseProxyConf `ini:",extends"`
DomainConf `ini:",extends"` DomainConf `ini:",extends"`
RouteByHTTPUser string `ini:"route_by_http_user" json:"route_by_http_user"`
Multiplexer string `ini:"multiplexer"` Multiplexer string `ini:"multiplexer"`
} }
@ -576,7 +578,7 @@ func (cfg *TCPMuxProxyConf) Compare(cmp ProxyConf) bool {
return false return false
} }
if cfg.Multiplexer != cmpConf.Multiplexer { if cfg.Multiplexer != cmpConf.Multiplexer || cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser {
return false return false
} }
@ -601,6 +603,7 @@ func (cfg *TCPMuxProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) {
cfg.CustomDomains = pMsg.CustomDomains cfg.CustomDomains = pMsg.CustomDomains
cfg.SubDomain = pMsg.SubDomain cfg.SubDomain = pMsg.SubDomain
cfg.Multiplexer = pMsg.Multiplexer cfg.Multiplexer = pMsg.Multiplexer
cfg.RouteByHTTPUser = pMsg.RouteByHTTPUser
} }
func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
@ -610,6 +613,7 @@ func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.CustomDomains = cfg.CustomDomains pMsg.CustomDomains = cfg.CustomDomains
pMsg.SubDomain = cfg.SubDomain pMsg.SubDomain = cfg.SubDomain
pMsg.Multiplexer = cfg.Multiplexer pMsg.Multiplexer = cfg.Multiplexer
pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser
} }
func (cfg *TCPMuxProxyConf) CheckForCli() (err error) { func (cfg *TCPMuxProxyConf) CheckForCli() (err error) {
@ -724,6 +728,7 @@ func (cfg *HTTPProxyConf) Compare(cmp ProxyConf) bool {
cfg.HTTPUser != cmpConf.HTTPUser || cfg.HTTPUser != cmpConf.HTTPUser ||
cfg.HTTPPwd != cmpConf.HTTPPwd || cfg.HTTPPwd != cmpConf.HTTPPwd ||
cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite || cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite ||
cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser ||
!reflect.DeepEqual(cfg.Headers, cmpConf.Headers) { !reflect.DeepEqual(cfg.Headers, cmpConf.Headers) {
return false return false
} }
@ -754,6 +759,7 @@ func (cfg *HTTPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) {
cfg.HTTPUser = pMsg.HTTPUser cfg.HTTPUser = pMsg.HTTPUser
cfg.HTTPPwd = pMsg.HTTPPwd cfg.HTTPPwd = pMsg.HTTPPwd
cfg.Headers = pMsg.Headers cfg.Headers = pMsg.Headers
cfg.RouteByHTTPUser = pMsg.RouteByHTTPUser
} }
func (cfg *HTTPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { func (cfg *HTTPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
@ -767,6 +773,7 @@ func (cfg *HTTPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.HTTPUser = cfg.HTTPUser pMsg.HTTPUser = cfg.HTTPUser
pMsg.HTTPPwd = cfg.HTTPPwd pMsg.HTTPPwd = cfg.HTTPPwd
pMsg.Headers = cfg.Headers pMsg.Headers = cfg.Headers
pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser
} }
func (cfg *HTTPProxyConf) CheckForCli() (err error) { func (cfg *HTTPProxyConf) CheckForCli() (err error) {

View File

@ -62,6 +62,8 @@ type ServerCommonConf struct {
// requests on one single port. If it's not - it will listen on this value for // requests on one single port. If it's not - it will listen on this value for
// HTTP CONNECT requests. By default, this value is 0. // HTTP CONNECT requests. By default, this value is 0.
TCPMuxHTTPConnectPort int `ini:"tcpmux_httpconnect_port" json:"tcpmux_httpconnect_port" validate:"gte=0,lte=65535"` TCPMuxHTTPConnectPort int `ini:"tcpmux_httpconnect_port" json:"tcpmux_httpconnect_port" validate:"gte=0,lte=65535"`
// If TCPMuxPassthrough is true, frps won't do any update on traffic.
TCPMuxPassthrough bool `ini:"tcpmux_passthrough" json:"tcpmux_passthrough"`
// VhostHTTPTimeout specifies the response header timeout for the Vhost // VhostHTTPTimeout specifies the response header timeout for the Vhost
// HTTP server, in seconds. By default, this value is 60. // HTTP server, in seconds. By default, this value is 60.
VhostHTTPTimeout int64 `ini:"vhost_http_timeout" json:"vhost_http_timeout"` VhostHTTPTimeout int64 `ini:"vhost_http_timeout" json:"vhost_http_timeout"`
@ -188,6 +190,7 @@ func GetDefaultServerConf() ServerCommonConf {
VhostHTTPPort: 0, VhostHTTPPort: 0,
VhostHTTPSPort: 0, VhostHTTPSPort: 0,
TCPMuxHTTPConnectPort: 0, TCPMuxHTTPConnectPort: 0,
TCPMuxPassthrough: false,
VhostHTTPTimeout: 60, VhostHTTPTimeout: 60,
DashboardAddr: "0.0.0.0", DashboardAddr: "0.0.0.0",
DashboardPort: 0, DashboardPort: 0,

View File

@ -62,133 +62,134 @@ var (
// When frpc start, client send this message to login to server. // When frpc start, client send this message to login to server.
type Login struct { type Login struct {
Version string `json:"version"` Version string `json:"version,omitempty"`
Hostname string `json:"hostname"` Hostname string `json:"hostname,omitempty"`
Os string `json:"os"` Os string `json:"os,omitempty"`
Arch string `json:"arch"` Arch string `json:"arch,omitempty"`
User string `json:"user"` User string `json:"user,omitempty"`
PrivilegeKey string `json:"privilege_key"` PrivilegeKey string `json:"privilege_key,omitempty"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp,omitempty"`
RunID string `json:"run_id"` RunID string `json:"run_id,omitempty"`
Metas map[string]string `json:"metas"` Metas map[string]string `json:"metas,omitempty"`
// Some global configures. // Some global configures.
PoolCount int `json:"pool_count"` PoolCount int `json:"pool_count,omitempty"`
} }
type LoginResp struct { type LoginResp struct {
Version string `json:"version"` Version string `json:"version,omitempty"`
RunID string `json:"run_id"` RunID string `json:"run_id,omitempty"`
ServerUDPPort int `json:"server_udp_port"` ServerUDPPort int `json:"server_udp_port,omitempty"`
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
// When frpc login success, send this message to frps for running a new proxy. // When frpc login success, send this message to frps for running a new proxy.
type NewProxy struct { type NewProxy struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
ProxyType string `json:"proxy_type"` ProxyType string `json:"proxy_type,omitempty"`
UseEncryption bool `json:"use_encryption"` UseEncryption bool `json:"use_encryption,omitempty"`
UseCompression bool `json:"use_compression"` UseCompression bool `json:"use_compression,omitempty"`
Group string `json:"group"` Group string `json:"group,omitempty"`
GroupKey string `json:"group_key"` GroupKey string `json:"group_key,omitempty"`
Metas map[string]string `json:"metas"` Metas map[string]string `json:"metas,omitempty"`
// tcp and udp only // tcp and udp only
RemotePort int `json:"remote_port"` RemotePort int `json:"remote_port,omitempty"`
// http and https only // http and https only
CustomDomains []string `json:"custom_domains"` CustomDomains []string `json:"custom_domains,omitempty"`
SubDomain string `json:"subdomain"` SubDomain string `json:"subdomain,omitempty"`
Locations []string `json:"locations"` Locations []string `json:"locations,omitempty"`
HTTPUser string `json:"http_user"` HTTPUser string `json:"http_user,omitempty"`
HTTPPwd string `json:"http_pwd"` HTTPPwd string `json:"http_pwd,omitempty"`
HostHeaderRewrite string `json:"host_header_rewrite"` HostHeaderRewrite string `json:"host_header_rewrite,omitempty"`
Headers map[string]string `json:"headers"` Headers map[string]string `json:"headers,omitempty"`
RouteByHTTPUser string `json:"route_by_http_user,omitempty"`
// stcp // stcp
Sk string `json:"sk"` Sk string `json:"sk,omitempty"`
// tcpmux // tcpmux
Multiplexer string `json:"multiplexer"` Multiplexer string `json:"multiplexer,omitempty"`
} }
type NewProxyResp struct { type NewProxyResp struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
RemoteAddr string `json:"remote_addr"` RemoteAddr string `json:"remote_addr,omitempty"`
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
type CloseProxy struct { type CloseProxy struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
} }
type NewWorkConn struct { type NewWorkConn struct {
RunID string `json:"run_id"` RunID string `json:"run_id,omitempty"`
PrivilegeKey string `json:"privilege_key"` PrivilegeKey string `json:"privilege_key,omitempty"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp,omitempty"`
} }
type ReqWorkConn struct { type ReqWorkConn struct {
} }
type StartWorkConn struct { type StartWorkConn struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
SrcAddr string `json:"src_addr"` SrcAddr string `json:"src_addr,omitempty"`
DstAddr string `json:"dst_addr"` DstAddr string `json:"dst_addr,omitempty"`
SrcPort uint16 `json:"src_port"` SrcPort uint16 `json:"src_port,omitempty"`
DstPort uint16 `json:"dst_port"` DstPort uint16 `json:"dst_port,omitempty"`
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
type NewVisitorConn struct { type NewVisitorConn struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
SignKey string `json:"sign_key"` SignKey string `json:"sign_key,omitempty"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp,omitempty"`
UseEncryption bool `json:"use_encryption"` UseEncryption bool `json:"use_encryption,omitempty"`
UseCompression bool `json:"use_compression"` UseCompression bool `json:"use_compression,omitempty"`
} }
type NewVisitorConnResp struct { type NewVisitorConnResp struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
type Ping struct { type Ping struct {
PrivilegeKey string `json:"privilege_key"` PrivilegeKey string `json:"privilege_key,omitempty"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp,omitempty"`
} }
type Pong struct { type Pong struct {
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
type UDPPacket struct { type UDPPacket struct {
Content string `json:"c"` Content string `json:"c,omitempty"`
LocalAddr *net.UDPAddr `json:"l"` LocalAddr *net.UDPAddr `json:"l,omitempty"`
RemoteAddr *net.UDPAddr `json:"r"` RemoteAddr *net.UDPAddr `json:"r,omitempty"`
} }
type NatHoleVisitor struct { type NatHoleVisitor struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
SignKey string `json:"sign_key"` SignKey string `json:"sign_key,omitempty"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp,omitempty"`
} }
type NatHoleClient struct { type NatHoleClient struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name,omitempty"`
Sid string `json:"sid"` Sid string `json:"sid,omitempty"`
} }
type NatHoleResp struct { type NatHoleResp struct {
Sid string `json:"sid"` Sid string `json:"sid,omitempty"`
VisitorAddr string `json:"visitor_addr"` VisitorAddr string `json:"visitor_addr,omitempty"`
ClientAddr string `json:"client_addr"` ClientAddr string `json:"client_addr,omitempty"`
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
type NatHoleClientDetectOK struct { type NatHoleClientDetectOK struct {
} }
type NatHoleSid struct { type NatHoleSid struct {
Sid string `json:"sid"` Sid string `json:"sid,omitempty"`
} }

View File

@ -24,18 +24,24 @@ import (
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
gnet "github.com/fatedier/golib/net"
) )
type HTTPConnectTCPMuxer struct { type HTTPConnectTCPMuxer struct {
*vhost.Muxer *vhost.Muxer
passthrough bool
authRequired bool // Not supported until we really need this.
} }
func NewHTTPConnectTCPMuxer(listener net.Listener, timeout time.Duration) (*HTTPConnectTCPMuxer, error) { func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
mux, err := vhost.NewMuxer(listener, getHostFromHTTPConnect, nil, sendHTTPOk, nil, timeout) ret := &HTTPConnectTCPMuxer{passthrough: passthrough, authRequired: false}
return &HTTPConnectTCPMuxer{mux}, err mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, nil, ret.sendConnectResponse, nil, timeout)
ret.Muxer = mux
return ret, err
} }
func readHTTPConnectRequest(rd io.Reader) (host string, err error) { func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host string, httpUser string, err error) {
bufioReader := bufio.NewReader(rd) bufioReader := bufio.NewReader(rd)
req, err := http.ReadRequest(bufioReader) req, err := http.ReadRequest(bufioReader)
@ -49,20 +55,40 @@ func readHTTPConnectRequest(rd io.Reader) (host string, err error) {
} }
host, _ = util.CanonicalHost(req.Host) host, _ = util.CanonicalHost(req.Host)
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
httpUser, _, _ = util.ParseBasicAuth(proxyAuth)
}
return return
} }
func sendHTTPOk(c net.Conn) error { func (muxer *HTTPConnectTCPMuxer) sendConnectResponse(c net.Conn, reqInfo map[string]string) error {
if muxer.passthrough {
return nil
}
return util.OkResponse().Write(c) return util.OkResponse().Write(c)
} }
func getHostFromHTTPConnect(c net.Conn) (_ net.Conn, _ map[string]string, err error) { func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) {
reqInfoMap := make(map[string]string, 0) reqInfoMap := make(map[string]string, 0)
host, err := readHTTPConnectRequest(c) sc, rd := gnet.NewSharedConn(c)
host, httpUser, err := muxer.readHTTPConnectRequest(rd)
if err != nil { if err != nil {
return nil, reqInfoMap, err return nil, reqInfoMap, err
} }
reqInfoMap["Host"] = host reqInfoMap["Host"] = host
reqInfoMap["Scheme"] = "tcp" reqInfoMap["Scheme"] = "tcp"
return c, reqInfoMap, nil reqInfoMap["HTTPUser"] = httpUser
var outConn net.Conn = c
if muxer.passthrough {
outConn = sc
if muxer.authRequired && httpUser == "" {
util.ProxyUnauthorizedResponse().Write(c)
outConn = c
}
}
return outConn, reqInfoMap, nil
} }

View File

@ -15,6 +15,7 @@
package util package util
import ( import (
"encoding/base64"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -34,6 +35,20 @@ func OkResponse() *http.Response {
return res return res
} }
func ProxyUnauthorizedResponse() *http.Response {
header := make(http.Header)
header.Set("Proxy-Authenticate", `Basic realm="Restricted"`)
res := &http.Response{
Status: "Proxy Authentication Required",
StatusCode: 407,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: header,
}
return res
}
// canonicalHost strips port from host if present and returns the canonicalized // canonicalHost strips port from host if present and returns the canonicalized
// host name. // host name.
func CanonicalHost(host string) (string, error) { func CanonicalHost(host string) (string, error) {
@ -64,3 +79,21 @@ func hasPort(host string) bool {
} }
return host[0] == '[' && strings.Contains(host, "]:") return host[0] == '[' && strings.Contains(host, "]:")
} }
func ParseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}

View File

@ -23,17 +23,19 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
frpLog "github.com/fatedier/frp/pkg/util/log" frpLog "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
frpIo "github.com/fatedier/golib/io"
"github.com/fatedier/golib/pool" "github.com/fatedier/golib/pool"
) )
var ( var (
ErrNoDomain = errors.New("no such domain") ErrNoRouteFound = errors.New("no route found")
) )
type HTTPReverseProxyOptions struct { type HTTPReverseProxyOptions struct {
@ -56,17 +58,22 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
vhostRouter: vhostRouter, vhostRouter: vhostRouter,
} }
proxy := &ReverseProxy{ proxy := &ReverseProxy{
// Modify incoming requests by route policies.
Director: func(req *http.Request) { Director: func(req *http.Request) {
req.URL.Scheme = "http" req.URL.Scheme = "http"
url := req.Context().Value(RouteInfoURL).(string) url := req.Context().Value(RouteInfoURL).(string)
routeByHTTPUser := req.Context().Value(RouteInfoHTTPUser).(string)
oldHost, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string)) oldHost, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string))
rc := rp.GetRouteConfig(oldHost, url) rc := rp.GetRouteConfig(oldHost, url, routeByHTTPUser)
if rc != nil { if rc != nil {
if rc.RewriteHost != "" { if rc.RewriteHost != "" {
req.Host = rc.RewriteHost req.Host = rc.RewriteHost
} }
// Set {domain}.{location} as URL host here to let http transport reuse connections. // Set {domain}.{location}.{routeByHTTPUser} as URL host here to let http transport reuse connections.
req.URL.Host = rc.Domain + "." + base64.StdEncoding.EncodeToString([]byte(rc.Location)) // TODO(fatedier): use proxy name instead?
req.URL.Host = rc.Domain + "." +
base64.StdEncoding.EncodeToString([]byte(rc.Location)) + "." +
base64.StdEncoding.EncodeToString([]byte(rc.RouteByHTTPUser))
for k, v := range rc.Headers { for k, v := range rc.Headers {
req.Header.Set(k, v) req.Header.Set(k, v)
@ -76,14 +83,30 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
} }
}, },
// Create a connection to one proxy routed by route policy.
Transport: &http.Transport{ Transport: &http.Transport{
ResponseHeaderTimeout: rp.responseHeaderTimeout, ResponseHeaderTimeout: rp.responseHeaderTimeout,
IdleConnTimeout: 60 * time.Second, IdleConnTimeout: 60 * time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
url := ctx.Value(RouteInfoURL).(string) url := ctx.Value(RouteInfoURL).(string)
host, _ := util.CanonicalHost(ctx.Value(RouteInfoHost).(string)) host, _ := util.CanonicalHost(ctx.Value(RouteInfoHost).(string))
routerByHTTPUser := ctx.Value(RouteInfoHTTPUser).(string)
remote := ctx.Value(RouteInfoRemote).(string) remote := ctx.Value(RouteInfoRemote).(string)
return rp.CreateConnection(host, url, remote) return rp.CreateConnection(host, url, routerByHTTPUser, remote)
},
Proxy: func(req *http.Request) (*url.URL, error) {
// Use proxy mode if there is host in HTTP first request line.
// GET http://example.com/ HTTP/1.1
// Host: example.com
//
// Normal:
// GET / HTTP/1.1
// Host: example.com
urlHost := req.Context().Value(RouteInfoURLHost).(string)
if urlHost != "" {
return req.URL, nil
}
return nil, nil
}, },
}, },
BufferPool: newWrapPool(), BufferPool: newWrapPool(),
@ -101,7 +124,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
// Register register the route config to reverse proxy // Register register the route config to reverse proxy
// reverse proxy will use CreateConnFn from routeCfg to create a connection to the remote service // reverse proxy will use CreateConnFn from routeCfg to create a connection to the remote service
func (rp *HTTPReverseProxy) Register(routeCfg RouteConfig) error { func (rp *HTTPReverseProxy) Register(routeCfg RouteConfig) error {
err := rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg) err := rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, routeCfg.RouteByHTTPUser, &routeCfg)
if err != nil { if err != nil {
return err return err
} }
@ -109,28 +132,29 @@ func (rp *HTTPReverseProxy) Register(routeCfg RouteConfig) error {
} }
// UnRegister unregister route config by domain and location // UnRegister unregister route config by domain and location
func (rp *HTTPReverseProxy) UnRegister(domain string, location string) { func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) {
rp.vhostRouter.Del(domain, location) rp.vhostRouter.Del(routeCfg.Domain, routeCfg.Location, routeCfg.RouteByHTTPUser)
} }
func (rp *HTTPReverseProxy) GetRouteConfig(domain string, location string) *RouteConfig { func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
vr, ok := rp.getVhost(domain, location) vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
if ok { if ok {
frpLog.Debug("get new HTTP request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
return vr.payload.(*RouteConfig) return vr.payload.(*RouteConfig)
} }
return nil return nil
} }
func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) { func (rp *HTTPReverseProxy) GetRealHost(domain, location, routeByHTTPUser string) (host string) {
vr, ok := rp.getVhost(domain, location) vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
if ok { if ok {
host = vr.payload.(*RouteConfig).RewriteHost host = vr.payload.(*RouteConfig).RewriteHost
} }
return return
} }
func (rp *HTTPReverseProxy) GetHeaders(domain string, location string) (headers map[string]string) { func (rp *HTTPReverseProxy) GetHeaders(domain, location, routeByHTTPUser string) (headers map[string]string) {
vr, ok := rp.getVhost(domain, location) vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
if ok { if ok {
headers = vr.payload.(*RouteConfig).Headers headers = vr.payload.(*RouteConfig).Headers
} }
@ -138,19 +162,19 @@ func (rp *HTTPReverseProxy) GetHeaders(domain string, location string) (headers
} }
// CreateConnection create a new connection by route config // CreateConnection create a new connection by route config
func (rp *HTTPReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) { func (rp *HTTPReverseProxy) CreateConnection(domain, location, routeByHTTPUser string, remoteAddr string) (net.Conn, error) {
vr, ok := rp.getVhost(domain, location) vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
if ok { if ok {
fn := vr.payload.(*RouteConfig).CreateConnFn fn := vr.payload.(*RouteConfig).CreateConnFn
if fn != nil { if fn != nil {
return fn(remoteAddr) return fn(remoteAddr)
} }
} }
return nil, fmt.Errorf("%v: %s %s", ErrNoDomain, domain, location) return nil, fmt.Errorf("%v: %s %s %s", ErrNoRouteFound, domain, location, routeByHTTPUser)
} }
func (rp *HTTPReverseProxy) CheckAuth(domain, location, user, passwd string) bool { func (rp *HTTPReverseProxy) CheckAuth(domain, location, routeByHTTPUser, user, passwd string) bool {
vr, ok := rp.getVhost(domain, location) vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
if ok { if ok {
checkUser := vr.payload.(*RouteConfig).Username checkUser := vr.payload.(*RouteConfig).Username
checkPasswd := vr.payload.(*RouteConfig).Password checkPasswd := vr.payload.(*RouteConfig).Password
@ -161,45 +185,120 @@ func (rp *HTTPReverseProxy) CheckAuth(domain, location, user, passwd string) boo
return true return true
} }
// getVhost get vhost router by domain and location // getVhost trys to get vhost router by route policy.
func (rp *HTTPReverseProxy) getVhost(domain string, location string) (vr *Router, ok bool) { func (rp *HTTPReverseProxy) getVhost(domain, location, routeByHTTPUser string) (*Router, bool) {
// first we check the full hostname findRouter := func(inDomain, inLocation, inRouteByHTTPUser string) (*Router, bool) {
// if not exist, then check the wildcard_domain such as *.example.com vr, ok := rp.vhostRouter.Get(inDomain, inLocation, inRouteByHTTPUser)
vr, ok = rp.vhostRouter.Get(domain, location) if ok {
if ok { return vr, ok
return }
} // Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
vr, ok = rp.vhostRouter.Get(inDomain, inLocation, "")
domainSplit := strings.Split(domain, ".") if ok {
if len(domainSplit) < 3 { return vr, ok
}
return nil, false return nil, false
} }
// First we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com
vr, ok := findRouter(domain, location, routeByHTTPUser)
if ok {
return vr, ok
}
// e.g. domain = test.example.com, try to match wildcard domains.
// *.example.com
// *.com
domainSplit := strings.Split(domain, ".")
for { for {
if len(domainSplit) < 3 { if len(domainSplit) < 3 {
return nil, false break
} }
domainSplit[0] = "*" domainSplit[0] = "*"
domain = strings.Join(domainSplit, ".") domain = strings.Join(domainSplit, ".")
vr, ok = rp.vhostRouter.Get(domain, location) vr, ok = findRouter(domain, location, routeByHTTPUser)
if ok { if ok {
return vr, true return vr, true
} }
domainSplit = domainSplit[1:] domainSplit = domainSplit[1:]
} }
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
vr, ok = findRouter("*", location, routeByHTTPUser)
if ok {
return vr, true
}
return nil, false
}
func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Request) {
hj, ok := rw.(http.Hijacker)
if !ok {
rw.WriteHeader(http.StatusInternalServerError)
return
}
client, _, err := hj.Hijack()
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}
url := req.Context().Value(RouteInfoURL).(string)
routeByHTTPUser := req.Context().Value(RouteInfoHTTPUser).(string)
domain, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string))
remoteAddr := req.Context().Value(RouteInfoRemote).(string)
remote, err := rp.CreateConnection(domain, url, routeByHTTPUser, remoteAddr)
if err != nil {
http.Error(rw, "Failed", http.StatusBadRequest)
client.Close()
return
}
req.Write(remote)
go frpIo.Join(remote, client)
}
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
newctx := req.Context()
newctx = context.WithValue(newctx, RouteInfoURL, req.URL.Path)
newctx = context.WithValue(newctx, RouteInfoHost, req.Host)
newctx = context.WithValue(newctx, RouteInfoURLHost, req.URL.Host)
user := ""
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
if req.URL.Host != "" {
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
user, _, _ = parseBasicAuth(proxyAuth)
}
}
if user == "" {
user, _, _ = req.BasicAuth()
}
newctx = context.WithValue(newctx, RouteInfoHTTPUser, user)
newctx = context.WithValue(newctx, RouteInfoRemote, req.RemoteAddr)
return req.Clone(newctx)
} }
func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
domain, _ := util.CanonicalHost(req.Host) domain, _ := util.CanonicalHost(req.Host)
location := req.URL.Path location := req.URL.Path
user, passwd, _ := req.BasicAuth() user, passwd, _ := req.BasicAuth()
if !rp.CheckAuth(domain, location, user, passwd) { if !rp.CheckAuth(domain, location, user, user, passwd) {
rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return return
} }
rp.proxy.ServeHTTP(rw, req)
newreq := rp.injectRequestInfoToCtx(req)
if req.Method == http.MethodConnect {
rp.connectHandler(rw, newreq)
} else {
rp.proxy.ServeHTTP(rw, newreq)
}
} }
type wrapPool struct{} type wrapPool struct{}

View File

@ -8,6 +8,7 @@ package vhost
import ( import (
"context" "context"
"encoding/base64"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -209,6 +210,24 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
return true return true
} }
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport transport := p.Transport
if transport == nil { if transport == nil {
@ -238,13 +257,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
} }
// =============================
// Modified for frp
outreq = outreq.Clone(context.WithValue(outreq.Context(), RouteInfoURL, req.URL.Path))
outreq = outreq.Clone(context.WithValue(outreq.Context(), RouteInfoHost, req.Host))
outreq = outreq.Clone(context.WithValue(outreq.Context(), RouteInfoRemote, req.RemoteAddr))
// =============================
p.Director(outreq) p.Director(outreq)
outreq.Close = false outreq.Close = false

View File

@ -11,33 +11,42 @@ var (
ErrRouterConfigConflict = errors.New("router config conflict") ErrRouterConfigConflict = errors.New("router config conflict")
) )
type routerByHTTPUser map[string][]*Router
type Routers struct { type Routers struct {
RouterByDomain map[string][]*Router indexByDomain map[string]routerByHTTPUser
mutex sync.RWMutex
mutex sync.RWMutex
} }
type Router struct { type Router struct {
domain string domain string
location string location string
httpUser string
// store any object here
payload interface{} payload interface{}
} }
func NewRouters() *Routers { func NewRouters() *Routers {
return &Routers{ return &Routers{
RouterByDomain: make(map[string][]*Router), indexByDomain: make(map[string]routerByHTTPUser),
} }
} }
func (r *Routers) Add(domain, location string, payload interface{}) error { func (r *Routers) Add(domain, location, httpUser string, payload interface{}) error {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if _, exist := r.exist(domain, location); exist { if _, exist := r.exist(domain, location, httpUser); exist {
return ErrRouterConfigConflict return ErrRouterConfigConflict
} }
vrs, found := r.RouterByDomain[domain] routersByHTTPUser, found := r.indexByDomain[domain]
if !found {
routersByHTTPUser = make(map[string][]*Router)
}
vrs, found := routersByHTTPUser[httpUser]
if !found { if !found {
vrs = make([]*Router, 0, 1) vrs = make([]*Router, 0, 1)
} }
@ -45,20 +54,27 @@ func (r *Routers) Add(domain, location string, payload interface{}) error {
vr := &Router{ vr := &Router{
domain: domain, domain: domain,
location: location, location: location,
httpUser: httpUser,
payload: payload, payload: payload,
} }
vrs = append(vrs, vr) vrs = append(vrs, vr)
sort.Sort(sort.Reverse(ByLocation(vrs))) sort.Sort(sort.Reverse(ByLocation(vrs)))
r.RouterByDomain[domain] = vrs
routersByHTTPUser[httpUser] = vrs
r.indexByDomain[domain] = routersByHTTPUser
return nil return nil
} }
func (r *Routers) Del(domain, location string) { func (r *Routers) Del(domain, location, httpUser string) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
vrs, found := r.RouterByDomain[domain] routersByHTTPUser, found := r.indexByDomain[domain]
if !found {
return
}
vrs, found := routersByHTTPUser[httpUser]
if !found { if !found {
return return
} }
@ -68,40 +84,46 @@ func (r *Routers) Del(domain, location string) {
newVrs = append(newVrs, vr) newVrs = append(newVrs, vr)
} }
} }
r.RouterByDomain[domain] = newVrs routersByHTTPUser[httpUser] = newVrs
} }
func (r *Routers) Get(host, path string) (vr *Router, exist bool) { func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
r.mutex.RLock() r.mutex.RLock()
defer r.mutex.RUnlock() defer r.mutex.RUnlock()
vrs, found := r.RouterByDomain[host] routersByHTTPUser, found := r.indexByDomain[host]
if !found {
return
}
vrs, found := routersByHTTPUser[httpUser]
if !found { if !found {
return return
} }
// can't support load balance, will to do
for _, vr = range vrs { for _, vr = range vrs {
if strings.HasPrefix(path, vr.location) { if strings.HasPrefix(path, vr.location) {
return vr, true return vr, true
} }
} }
return return
} }
func (r *Routers) exist(host, path string) (vr *Router, exist bool) { func (r *Routers) exist(host, path, httpUser string) (route *Router, exist bool) {
vrs, found := r.RouterByDomain[host] routersByHTTPUser, found := r.indexByDomain[host]
if !found {
return
}
routers, found := routersByHTTPUser[httpUser]
if !found { if !found {
return return
} }
for _, vr = range vrs { for _, route = range routers {
if path == vr.location { if path == route.location {
return vr, true return route, true
} }
} }
return return
} }

View File

@ -29,16 +29,19 @@ import (
type RouteInfo string type RouteInfo string
const ( const (
RouteInfoURL RouteInfo = "url" RouteInfoURL RouteInfo = "url"
RouteInfoHost RouteInfo = "host" RouteInfoHost RouteInfo = "host"
RouteInfoRemote RouteInfo = "remote" RouteInfoHTTPUser RouteInfo = "httpUser"
RouteInfoRemote RouteInfo = "remote"
RouteInfoURLHost RouteInfo = "urlHost"
) )
type muxFunc func(net.Conn) (net.Conn, map[string]string, error) type muxFunc func(net.Conn) (net.Conn, map[string]string, error)
type httpAuthFunc func(net.Conn, string, string, string) (bool, error) type httpAuthFunc func(net.Conn, string, string, string) (bool, error)
type hostRewriteFunc func(net.Conn, string) (net.Conn, error) type hostRewriteFunc func(net.Conn, string) (net.Conn, error)
type successFunc func(net.Conn) error type successFunc func(net.Conn, map[string]string) error
// Muxer is only used for https and tcpmux proxy.
type Muxer struct { type Muxer struct {
listener net.Listener listener net.Listener
timeout time.Duration timeout time.Duration
@ -49,7 +52,15 @@ type Muxer struct {
registryRouter *Routers registryRouter *Routers
} }
func NewMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, successFunc successFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *Muxer, err error) { func NewMuxer(
listener net.Listener,
vhostFunc muxFunc,
authFunc httpAuthFunc,
successFunc successFunc,
rewriteFunc hostRewriteFunc,
timeout time.Duration,
) (mux *Muxer, err error) {
mux = &Muxer{ mux = &Muxer{
listener: listener, listener: listener,
timeout: timeout, timeout: timeout,
@ -67,12 +78,13 @@ type CreateConnFunc func(remoteAddr string) (net.Conn, error)
// RouteConfig is the params used to match HTTP requests // RouteConfig is the params used to match HTTP requests
type RouteConfig struct { type RouteConfig struct {
Domain string Domain string
Location string Location string
RewriteHost string RewriteHost string
Username string Username string
Password string Password string
Headers map[string]string Headers map[string]string
RouteByHTTPUser string
CreateConnFn CreateConnFunc CreateConnFn CreateConnFunc
} }
@ -81,49 +93,66 @@ type RouteConfig struct {
// then rewrite the host header to rewriteHost // then rewrite the host header to rewriteHost
func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err error) { func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err error) {
l = &Listener{ l = &Listener{
name: cfg.Domain, name: cfg.Domain,
location: cfg.Location, location: cfg.Location,
rewriteHost: cfg.RewriteHost, routeByHTTPUser: cfg.RouteByHTTPUser,
userName: cfg.Username, rewriteHost: cfg.RewriteHost,
passWord: cfg.Password, userName: cfg.Username,
mux: v, passWord: cfg.Password,
accept: make(chan net.Conn), mux: v,
ctx: ctx, accept: make(chan net.Conn),
ctx: ctx,
} }
err = v.registryRouter.Add(cfg.Domain, cfg.Location, l) err = v.registryRouter.Add(cfg.Domain, cfg.Location, cfg.RouteByHTTPUser, l)
if err != nil { if err != nil {
return return
} }
return l, nil return l, nil
} }
func (v *Muxer) getListener(name, path string) (l *Listener, exist bool) { func (v *Muxer) getListener(name, path, httpUser string) (*Listener, bool) {
findRouter := func(inName, inPath, inHTTPUser string) (*Listener, bool) {
vr, ok := v.registryRouter.Get(inName, inPath, httpUser)
if ok {
return vr.payload.(*Listener), true
}
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
vr, ok = v.registryRouter.Get(inName, inPath, "")
if ok {
return vr.payload.(*Listener), true
}
return nil, false
}
// first we check the full hostname // first we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com // if not exist, then check the wildcard_domain such as *.example.com
vr, found := v.registryRouter.Get(name, path) l, ok := findRouter(name, path, httpUser)
if found { if ok {
return vr.payload.(*Listener), true return l, true
} }
domainSplit := strings.Split(name, ".") domainSplit := strings.Split(name, ".")
if len(domainSplit) < 3 {
return
}
for { for {
if len(domainSplit) < 3 { if len(domainSplit) < 3 {
return break
} }
domainSplit[0] = "*" domainSplit[0] = "*"
name = strings.Join(domainSplit, ".") name = strings.Join(domainSplit, ".")
vr, found = v.registryRouter.Get(name, path) l, ok = findRouter(name, path, httpUser)
if found { if ok {
return vr.payload.(*Listener), true return l, true
} }
domainSplit = domainSplit[1:] domainSplit = domainSplit[1:]
} }
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
l, ok = findRouter("*", path, httpUser)
if ok {
return l, true
}
return nil, false
} }
func (v *Muxer) run() { func (v *Muxer) run() {
@ -151,25 +180,26 @@ func (v *Muxer) handle(c net.Conn) {
name := strings.ToLower(reqInfoMap["Host"]) name := strings.ToLower(reqInfoMap["Host"])
path := strings.ToLower(reqInfoMap["Path"]) path := strings.ToLower(reqInfoMap["Path"])
l, ok := v.getListener(name, path) httpUser := reqInfoMap["HTTPUser"]
l, ok := v.getListener(name, path, httpUser)
if !ok { if !ok {
res := notFoundResponse() res := notFoundResponse()
res.Write(c) res.Write(c)
log.Debug("http request for host [%s] path [%s] not found", name, path) log.Debug("http request for host [%s] path [%s] httpUser [%s] not found", name, path, httpUser)
c.Close() c.Close()
return return
} }
xl := xlog.FromContextSafe(l.ctx) xl := xlog.FromContextSafe(l.ctx)
if v.successFunc != nil { if v.successFunc != nil {
if err := v.successFunc(c); err != nil { if err := v.successFunc(c, reqInfoMap); err != nil {
xl.Info("success func failure on vhost connection: %v", err) xl.Info("success func failure on vhost connection: %v", err)
c.Close() c.Close()
return return
} }
} }
// if authFunc is exist and userName/password is set // if authFunc is exist and username/password is set
// then verify user access // then verify user access
if l.mux.authFunc != nil && l.userName != "" && l.passWord != "" { if l.mux.authFunc != nil && l.userName != "" && l.passWord != "" {
bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"]) bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"])
@ -188,7 +218,7 @@ func (v *Muxer) handle(c net.Conn) {
} }
c = sConn c = sConn
xl.Debug("get new http request host [%s] path [%s]", name, path) xl.Debug("new request host [%s] path [%s] httpUser [%s]", name, path, httpUser)
err = errors.PanicToError(func() { err = errors.PanicToError(func() {
l.accept <- c l.accept <- c
}) })
@ -198,14 +228,15 @@ func (v *Muxer) handle(c net.Conn) {
} }
type Listener struct { type Listener struct {
name string name string
location string location string
rewriteHost string routeByHTTPUser string
userName string rewriteHost string
passWord string userName string
mux *Muxer // for closing Muxer passWord string
accept chan net.Conn mux *Muxer // for closing Muxer
ctx context.Context accept chan net.Conn
ctx context.Context
} }
func (l *Listener) Accept() (net.Conn, error) { func (l *Listener) Accept() (net.Conn, error) {
@ -231,7 +262,7 @@ func (l *Listener) Accept() (net.Conn, error) {
} }
func (l *Listener) Close() error { func (l *Listener) Close() error {
l.mux.registryRouter.Del(l.name, l.location) l.mux.registryRouter.Del(l.name, l.location, l.routeByHTTPUser)
close(l.accept) close(l.accept)
return nil return nil
} }

View File

@ -458,11 +458,11 @@ func (ctl *Control) manager() {
ProxyName: m.ProxyName, ProxyName: m.ProxyName,
} }
if err != nil { if err != nil {
xl.Warn("new proxy [%s] error: %v", m.ProxyName, err) xl.Warn("new proxy [%s] type [%s] error: %v", m.ProxyName, m.ProxyType, err)
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", m.ProxyName), err, ctl.serverCfg.DetailedErrorsToClient) resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", m.ProxyName), err, ctl.serverCfg.DetailedErrorsToClient)
} else { } else {
resp.RemoteAddr = remoteAddr resp.RemoteAddr = remoteAddr
xl.Info("new proxy [%s] success", m.ProxyName) xl.Info("new proxy [%s] type [%s] success", m.ProxyName, m.ProxyType)
metrics.Server.NewProxy(m.ProxyName, m.ProxyType) metrics.Server.NewProxy(m.ProxyName, m.ProxyType)
} }
ctl.sendCh <- resp ctl.sendCh <- resp

View File

@ -10,8 +10,11 @@ import (
) )
type HTTPGroupController struct { type HTTPGroupController struct {
// groups by indexKey
groups map[string]*HTTPGroup groups map[string]*HTTPGroup
// register createConn for each group to vhostRouter.
// createConn will get a connection from one proxy of the group
vhostRouter *vhost.Routers vhostRouter *vhost.Routers
mu sync.Mutex mu sync.Mutex
@ -24,10 +27,12 @@ func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController {
} }
} }
func (ctl *HTTPGroupController) Register(proxyName, group, groupKey string, func (ctl *HTTPGroupController) Register(
routeConfig vhost.RouteConfig) (err error) { proxyName, group, groupKey string,
routeConfig vhost.RouteConfig,
) (err error) {
indexKey := httpGroupIndex(group, routeConfig.Domain, routeConfig.Location) indexKey := group
ctl.mu.Lock() ctl.mu.Lock()
g, ok := ctl.groups[indexKey] g, ok := ctl.groups[indexKey]
if !ok { if !ok {
@ -39,8 +44,8 @@ func (ctl *HTTPGroupController) Register(proxyName, group, groupKey string,
return g.Register(proxyName, group, groupKey, routeConfig) return g.Register(proxyName, group, groupKey, routeConfig)
} }
func (ctl *HTTPGroupController) UnRegister(proxyName, group, domain, location string) { func (ctl *HTTPGroupController) UnRegister(proxyName, group string, routeConfig vhost.RouteConfig) {
indexKey := httpGroupIndex(group, domain, location) indexKey := group
ctl.mu.Lock() ctl.mu.Lock()
defer ctl.mu.Unlock() defer ctl.mu.Unlock()
g, ok := ctl.groups[indexKey] g, ok := ctl.groups[indexKey]
@ -55,11 +60,13 @@ func (ctl *HTTPGroupController) UnRegister(proxyName, group, domain, location st
} }
type HTTPGroup struct { type HTTPGroup struct {
group string group string
groupKey string groupKey string
domain string domain string
location string location string
routeByHTTPUser string
// CreateConnFuncs indexed by echo proxy name
createFuncs map[string]vhost.CreateConnFunc createFuncs map[string]vhost.CreateConnFunc
pxyNames []string pxyNames []string
index uint64 index uint64
@ -75,8 +82,10 @@ func NewHTTPGroup(ctl *HTTPGroupController) *HTTPGroup {
} }
} }
func (g *HTTPGroup) Register(proxyName, group, groupKey string, func (g *HTTPGroup) Register(
routeConfig vhost.RouteConfig) (err error) { proxyName, group, groupKey string,
routeConfig vhost.RouteConfig,
) (err error) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
@ -84,7 +93,7 @@ func (g *HTTPGroup) Register(proxyName, group, groupKey string,
// the first proxy in this group // the first proxy in this group
tmp := routeConfig // copy object tmp := routeConfig // copy object
tmp.CreateConnFn = g.createConn tmp.CreateConnFn = g.createConn
err = g.ctl.vhostRouter.Add(routeConfig.Domain, routeConfig.Location, &tmp) err = g.ctl.vhostRouter.Add(routeConfig.Domain, routeConfig.Location, routeConfig.RouteByHTTPUser, &tmp)
if err != nil { if err != nil {
return return
} }
@ -93,8 +102,10 @@ func (g *HTTPGroup) Register(proxyName, group, groupKey string,
g.groupKey = groupKey g.groupKey = groupKey
g.domain = routeConfig.Domain g.domain = routeConfig.Domain
g.location = routeConfig.Location g.location = routeConfig.Location
g.routeByHTTPUser = routeConfig.RouteByHTTPUser
} else { } else {
if g.group != group || g.domain != routeConfig.Domain || g.location != routeConfig.Location { if g.group != group || g.domain != routeConfig.Domain ||
g.location != routeConfig.Location || g.routeByHTTPUser != routeConfig.RouteByHTTPUser {
err = ErrGroupParamsInvalid err = ErrGroupParamsInvalid
return return
} }
@ -125,7 +136,7 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
if len(g.createFuncs) == 0 { if len(g.createFuncs) == 0 {
isEmpty = true isEmpty = true
g.ctl.vhostRouter.Del(g.domain, g.location) g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser)
} }
return return
} }
@ -138,6 +149,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
group := g.group group := g.group
domain := g.domain domain := g.domain
location := g.location location := g.location
routeByHTTPUser := g.routeByHTTPUser
if len(g.pxyNames) > 0 { if len(g.pxyNames) > 0 {
name := g.pxyNames[int(newIndex)%len(g.pxyNames)] name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
f, _ = g.createFuncs[name] f, _ = g.createFuncs[name]
@ -145,12 +157,9 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
g.mu.RUnlock() g.mu.RUnlock()
if f == nil { if f == nil {
return nil, fmt.Errorf("no CreateConnFunc for http group [%s], domain [%s], location [%s]", group, domain, location) return nil, fmt.Errorf("no CreateConnFunc for http group [%s], domain [%s], location [%s], routeByHTTPUser [%s]",
group, domain, location, routeByHTTPUser)
} }
return f(remoteAddr) return f(remoteAddr)
} }
func httpGroupIndex(group, domain, location string) string {
return fmt.Sprintf("%s_%s_%s", group, domain, location)
}

View File

@ -46,8 +46,11 @@ func NewTCPMuxGroupCtl(tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer) *TCPM
// Listen is the wrapper for TCPMuxGroup's Listen // Listen is the wrapper for TCPMuxGroup's Listen
// If there are no group, we will create one here // If there are no group, we will create one here
func (tmgc *TCPMuxGroupCtl) Listen(ctx context.Context, multiplexer string, group string, groupKey string, func (tmgc *TCPMuxGroupCtl) Listen(
domain string) (l net.Listener, err error) { ctx context.Context,
multiplexer, group, groupKey string,
routeConfig vhost.RouteConfig,
) (l net.Listener, err error) {
tmgc.mu.Lock() tmgc.mu.Lock()
tcpMuxGroup, ok := tmgc.groups[group] tcpMuxGroup, ok := tmgc.groups[group]
@ -59,7 +62,7 @@ func (tmgc *TCPMuxGroupCtl) Listen(ctx context.Context, multiplexer string, grou
switch multiplexer { switch multiplexer {
case consts.HTTPConnectTCPMultiplexer: case consts.HTTPConnectTCPMultiplexer:
return tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, domain) return tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
default: default:
err = fmt.Errorf("unknown multiplexer [%s]", multiplexer) err = fmt.Errorf("unknown multiplexer [%s]", multiplexer)
return return
@ -75,9 +78,10 @@ func (tmgc *TCPMuxGroupCtl) RemoveGroup(group string) {
// TCPMuxGroup route connections to different proxies // TCPMuxGroup route connections to different proxies
type TCPMuxGroup struct { type TCPMuxGroup struct {
group string group string
groupKey string groupKey string
domain string domain string
routeByHTTPUser string
acceptCh chan net.Conn acceptCh chan net.Conn
index uint64 index uint64
@ -99,15 +103,17 @@ func NewTCPMuxGroup(ctl *TCPMuxGroupCtl) *TCPMuxGroup {
// Listen will return a new TCPMuxGroupListener // Listen will return a new TCPMuxGroupListener
// if TCPMuxGroup already has a listener, just add a new TCPMuxGroupListener to the queues // if TCPMuxGroup already has a listener, just add a new TCPMuxGroupListener to the queues
// otherwise, listen on the real address // otherwise, listen on the real address
func (tmg *TCPMuxGroup) HTTPConnectListen(ctx context.Context, group string, groupKey string, domain string) (ln *TCPMuxGroupListener, err error) { func (tmg *TCPMuxGroup) HTTPConnectListen(
ctx context.Context,
group, groupKey string,
routeConfig vhost.RouteConfig,
) (ln *TCPMuxGroupListener, err error) {
tmg.mu.Lock() tmg.mu.Lock()
defer tmg.mu.Unlock() defer tmg.mu.Unlock()
if len(tmg.lns) == 0 { if len(tmg.lns) == 0 {
// the first listener, listen on the real address // the first listener, listen on the real address
routeConfig := &vhost.RouteConfig{ tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, &routeConfig)
Domain: domain,
}
tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, routeConfig)
if errRet != nil { if errRet != nil {
return nil, errRet return nil, errRet
} }
@ -115,7 +121,8 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(ctx context.Context, group string, gro
tmg.group = group tmg.group = group
tmg.groupKey = groupKey tmg.groupKey = groupKey
tmg.domain = domain tmg.domain = routeConfig.Domain
tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser
tmg.tcpMuxLn = tcpMuxLn tmg.tcpMuxLn = tcpMuxLn
tmg.lns = append(tmg.lns, ln) tmg.lns = append(tmg.lns, ln)
if tmg.acceptCh == nil { if tmg.acceptCh == nil {
@ -123,8 +130,8 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(ctx context.Context, group string, gro
} }
go tmg.worker() go tmg.worker()
} else { } else {
// domain in the same group must be equal // route config in the same group must be equal
if tmg.group != group || tmg.domain != domain { if tmg.group != group || tmg.domain != routeConfig.Domain || tmg.routeByHTTPUser != routeConfig.RouteByHTTPUser {
return nil, ErrGroupParamsInvalid return nil, ErrGroupParamsInvalid
} }
if tmg.groupKey != groupKey { if tmg.groupKey != groupKey {

View File

@ -38,11 +38,12 @@ type HTTPProxy struct {
func (pxy *HTTPProxy) Run() (remoteAddr string, err error) { func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
xl := pxy.xl xl := pxy.xl
routeConfig := vhost.RouteConfig{ routeConfig := vhost.RouteConfig{
RewriteHost: pxy.cfg.HostHeaderRewrite, RewriteHost: pxy.cfg.HostHeaderRewrite,
Headers: pxy.cfg.Headers, RouteByHTTPUser: pxy.cfg.RouteByHTTPUser,
Username: pxy.cfg.HTTPUser, Headers: pxy.cfg.Headers,
Password: pxy.cfg.HTTPPwd, Username: pxy.cfg.HTTPUser,
CreateConnFn: pxy.GetRealConn, Password: pxy.cfg.HTTPPwd,
CreateConnFn: pxy.GetRealConn,
} }
locations := pxy.cfg.Locations locations := pxy.cfg.Locations
@ -65,8 +66,8 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
routeConfig.Domain = domain routeConfig.Domain = domain
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location tmpRouteConfig := routeConfig
// handle group // handle group
if pxy.cfg.Group != "" { if pxy.cfg.Group != "" {
@ -76,7 +77,7 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
} }
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation) pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpRouteConfig)
}) })
} else { } else {
// no group // no group
@ -85,11 +86,12 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
return return
} }
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPReverseProxy.UnRegister(tmpDomain, tmpLocation) pxy.rc.HTTPReverseProxy.UnRegister(tmpRouteConfig)
}) })
} }
addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(pxy.serverCfg.VhostHTTPPort))) addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(pxy.serverCfg.VhostHTTPPort)))
xl.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group) xl.Info("http proxy listen for host [%s] location [%s] group [%s], routeByHTTPUser [%s]",
routeConfig.Domain, routeConfig.Location, pxy.cfg.Group, pxy.cfg.RouteByHTTPUser)
} }
} }
@ -97,8 +99,8 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
routeConfig.Domain = pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost routeConfig.Domain = pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location tmpRouteConfig := routeConfig
// handle group // handle group
if pxy.cfg.Group != "" { if pxy.cfg.Group != "" {
@ -108,7 +110,7 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
} }
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation) pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpRouteConfig)
}) })
} else { } else {
err = pxy.rc.HTTPReverseProxy.Register(routeConfig) err = pxy.rc.HTTPReverseProxy.Register(routeConfig)
@ -116,12 +118,13 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
return return
} }
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPReverseProxy.UnRegister(tmpDomain, tmpLocation) pxy.rc.HTTPReverseProxy.UnRegister(tmpRouteConfig)
}) })
} }
addrs = append(addrs, util.CanonicalAddr(tmpDomain, pxy.serverCfg.VhostHTTPPort)) addrs = append(addrs, util.CanonicalAddr(tmpRouteConfig.Domain, pxy.serverCfg.VhostHTTPPort))
xl.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group) xl.Info("http proxy listen for host [%s] location [%s] group [%s], routeByHTTPUser [%s]",
routeConfig.Domain, routeConfig.Location, pxy.cfg.Group, pxy.cfg.RouteByHTTPUser)
} }
} }
remoteAddr = strings.Join(addrs, ",") remoteAddr = strings.Join(addrs, ",")

View File

@ -30,20 +30,23 @@ type TCPMuxProxy struct {
cfg *config.TCPMuxProxyConf cfg *config.TCPMuxProxyConf
} }
func (pxy *TCPMuxProxy) httpConnectListen(domain string, addrs []string) (_ []string, err error) { func (pxy *TCPMuxProxy) httpConnectListen(domain, routeByHTTPUser string, addrs []string) ([]string, error) {
var l net.Listener var l net.Listener
var err error
routeConfig := &vhost.RouteConfig{
Domain: domain,
RouteByHTTPUser: routeByHTTPUser,
}
if pxy.cfg.Group != "" { if pxy.cfg.Group != "" {
l, err = pxy.rc.TCPMuxGroupCtl.Listen(pxy.ctx, pxy.cfg.Multiplexer, pxy.cfg.Group, pxy.cfg.GroupKey, domain) l, err = pxy.rc.TCPMuxGroupCtl.Listen(pxy.ctx, pxy.cfg.Multiplexer, pxy.cfg.Group, pxy.cfg.GroupKey, *routeConfig)
} else { } else {
routeConfig := &vhost.RouteConfig{
Domain: domain,
}
l, err = pxy.rc.TCPMuxHTTPConnectMuxer.Listen(pxy.ctx, routeConfig) l, err = pxy.rc.TCPMuxHTTPConnectMuxer.Listen(pxy.ctx, routeConfig)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
pxy.xl.Info("tcpmux httpconnect multiplexer listens for host [%s]", domain) pxy.xl.Info("tcpmux httpconnect multiplexer listens for host [%s], group [%s] routeByHTTPUser [%s]",
domain, pxy.cfg.Group, pxy.cfg.RouteByHTTPUser)
pxy.listeners = append(pxy.listeners, l) pxy.listeners = append(pxy.listeners, l)
return append(addrs, util.CanonicalAddr(domain, pxy.serverCfg.TCPMuxHTTPConnectPort)), nil return append(addrs, util.CanonicalAddr(domain, pxy.serverCfg.TCPMuxHTTPConnectPort)), nil
} }
@ -55,14 +58,14 @@ func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) {
continue continue
} }
addrs, err = pxy.httpConnectListen(domain, addrs) addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, addrs)
if err != nil { if err != nil {
return "", err return "", err
} }
} }
if pxy.cfg.SubDomain != "" { if pxy.cfg.SubDomain != "" {
addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost, addrs) addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost, pxy.cfg.RouteByHTTPUser, addrs)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -131,12 +131,12 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
return return
} }
svr.rc.TCPMuxHTTPConnectMuxer, err = tcpmux.NewHTTPConnectTCPMuxer(l, vhostReadWriteTimeout) svr.rc.TCPMuxHTTPConnectMuxer, err = tcpmux.NewHTTPConnectTCPMuxer(l, cfg.TCPMuxPassthrough, vhostReadWriteTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("Create vhost tcpMuxer error, %v", err) err = fmt.Errorf("Create vhost tcpMuxer error, %v", err)
return return
} }
log.Info("tcpmux httpconnect multiplexer listen on %s", address) log.Info("tcpmux httpconnect multiplexer listen on %s, passthough: %v", address, cfg.TCPMuxPassthrough)
} }
// Init all plugins // Init all plugins

View File

@ -10,7 +10,6 @@ import (
"github.com/fatedier/frp/test/e2e/framework/consts" "github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/mock/server/httpserver" "github.com/fatedier/frp/test/e2e/mock/server/httpserver"
"github.com/fatedier/frp/test/e2e/pkg/request" "github.com/fatedier/frp/test/e2e/pkg/request"
"github.com/fatedier/frp/test/e2e/pkg/utils"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -59,28 +58,83 @@ var _ = Describe("[Feature: HTTP]", func() {
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses([]string{serverConf}, []string{clientConf})
// foo path tests := []struct {
framework.NewRequestExpect(f).Explain("foo path").Port(vhostHTTPPort). path string
expectResp string
desc string
}{
{path: "/foo", expectResp: "foo", desc: "foo path"},
{path: "/bar", expectResp: "bar", desc: "bar path"},
{path: "/other", expectResp: "foo", desc: "other path"},
}
for _, test := range tests {
framework.NewRequestExpect(f).Explain(test.desc).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com").HTTPPath(test.path)
}).
ExpectResp([]byte(test.expectResp)).
Ensure()
}
})
It("HTTP route by HTTP user", func() {
vhostHTTPPort := f.AllocPort()
serverConf := getDefaultServerConf(vhostHTTPPort)
fooPort := f.AllocPort()
f.RunServer("", newHTTPServer(fooPort, "foo"))
barPort := f.AllocPort()
f.RunServer("", newHTTPServer(barPort, "bar"))
otherPort := f.AllocPort()
f.RunServer("", newHTTPServer(otherPort, "other"))
clientConf := consts.DefaultClientConfig
clientConf += fmt.Sprintf(`
[foo]
type = http
local_port = %d
custom_domains = normal.example.com
route_by_http_user = user1
[bar]
type = http
local_port = %d
custom_domains = normal.example.com
route_by_http_user = user2
[catchAll]
type = http
local_port = %d
custom_domains = normal.example.com
`, fooPort, barPort, otherPort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
// user1
framework.NewRequestExpect(f).Explain("user1").Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com").HTTPPath("/foo") r.HTTP().HTTPHost("normal.example.com").HTTPAuth("user1", "")
}). }).
ExpectResp([]byte("foo")). ExpectResp([]byte("foo")).
Ensure() Ensure()
// bar path // user2
framework.NewRequestExpect(f).Explain("bar path").Port(vhostHTTPPort). framework.NewRequestExpect(f).Explain("user2").Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com").HTTPPath("/bar") r.HTTP().HTTPHost("normal.example.com").HTTPAuth("user2", "")
}). }).
ExpectResp([]byte("bar")). ExpectResp([]byte("bar")).
Ensure() Ensure()
// other path // other user
framework.NewRequestExpect(f).Explain("other path").Port(vhostHTTPPort). framework.NewRequestExpect(f).Explain("other user").Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com").HTTPPath("/other") r.HTTP().HTTPHost("normal.example.com").HTTPAuth("user3", "")
}). }).
ExpectResp([]byte("foo")). ExpectResp([]byte("other")).
Ensure() Ensure()
}) })
@ -110,18 +164,14 @@ var _ = Describe("[Feature: HTTP]", func() {
// set incorrect auth header // set incorrect auth header
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com").HTTPHeaders(map[string]string{ r.HTTP().HTTPHost("normal.example.com").HTTPAuth("test", "invalid")
"Authorization": utils.BasicAuth("test", "invalid"),
})
}). }).
Ensure(framework.ExpectResponseCode(401)) Ensure(framework.ExpectResponseCode(401))
// set correct auth header // set correct auth header
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com").HTTPHeaders(map[string]string{ r.HTTP().HTTPHost("normal.example.com").HTTPAuth("test", "test")
"Authorization": utils.BasicAuth("test", "test"),
})
}). }).
Ensure() Ensure()
}) })

View File

@ -60,7 +60,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str
ExpectNoError(err) ExpectNoError(err)
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
time.Sleep(500 * time.Millisecond) time.Sleep(time.Second)
return currentServerProcesses, currentClientProcesses return currentServerProcesses, currentClientProcesses
} }

View File

@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/fatedier/frp/test/e2e/pkg/rpc" "github.com/fatedier/frp/test/e2e/pkg/rpc"
"github.com/fatedier/frp/test/e2e/pkg/utils"
libdial "github.com/fatedier/golib/net/dial" libdial "github.com/fatedier/golib/net/dial"
) )
@ -20,10 +21,11 @@ type Request struct {
protocol string protocol string
// for all protocol // for all protocol
addr string addr string
port int port int
body []byte body []byte
timeout time.Duration timeout time.Duration
resolver *net.Resolver
// for http or https // for http or https
method string method string
@ -32,6 +34,8 @@ type Request struct {
headers map[string]string headers map[string]string
tlsConfig *tls.Config tlsConfig *tls.Config
authValue string
proxyURL string proxyURL string
} }
@ -40,8 +44,9 @@ func New() *Request {
protocol: "tcp", protocol: "tcp",
addr: "127.0.0.1", addr: "127.0.0.1",
method: "GET", method: "GET",
path: "/", path: "/",
headers: map[string]string{},
} }
} }
@ -108,6 +113,11 @@ func (r *Request) HTTPHeaders(headers map[string]string) *Request {
return r return r
} }
func (r *Request) HTTPAuth(user, password string) *Request {
r.authValue = utils.BasicAuth(user, password)
return r
}
func (r *Request) TLSConfig(tlsConfig *tls.Config) *Request { func (r *Request) TLSConfig(tlsConfig *tls.Config) *Request {
r.tlsConfig = tlsConfig r.tlsConfig = tlsConfig
return r return r
@ -123,6 +133,11 @@ func (r *Request) Body(content []byte) *Request {
return r return r
} }
func (r *Request) Resolver(resolver *net.Resolver) *Request {
r.resolver = resolver
return r
}
func (r *Request) Do() (*Response, error) { func (r *Request) Do() (*Response, error) {
var ( var (
conn net.Conn conn net.Conn
@ -150,11 +165,12 @@ func (r *Request) Do() (*Response, error) {
return nil, err return nil, err
} }
} else { } else {
dialer := &net.Dialer{Resolver: r.resolver}
switch r.protocol { switch r.protocol {
case "tcp": case "tcp":
conn, err = net.Dial("tcp", addr) conn, err = dialer.Dial("tcp", addr)
case "udp": case "udp":
conn, err = net.Dial("udp", addr) conn, err = dialer.Dial("udp", addr)
default: default:
return nil, fmt.Errorf("invalid protocol") return nil, fmt.Errorf("invalid protocol")
} }
@ -198,11 +214,15 @@ func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers ma
for k, v := range headers { for k, v := range headers {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
if r.authValue != "" {
req.Header.Set("Authorization", r.authValue)
}
tr := &http.Transport{ tr := &http.Transport{
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: time.Second, Timeout: time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
DualStack: true, DualStack: true,
Resolver: r.resolver,
}).DialContext, }).DialContext,
MaxIdleConns: 100, MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,

View File

@ -12,7 +12,6 @@ import (
"github.com/fatedier/frp/test/e2e/pkg/cert" "github.com/fatedier/frp/test/e2e/pkg/cert"
"github.com/fatedier/frp/test/e2e/pkg/port" "github.com/fatedier/frp/test/e2e/pkg/port"
"github.com/fatedier/frp/test/e2e/pkg/request" "github.com/fatedier/frp/test/e2e/pkg/request"
"github.com/fatedier/frp/test/e2e/pkg/utils"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
) )
@ -181,9 +180,7 @@ var _ = Describe("[Feature: Client-Plugins]", func() {
// from http proxy with auth // from http proxy with auth
framework.NewRequestExpect(f).Request( framework.NewRequestExpect(f).Request(
framework.NewHTTPRequest().HTTPHost("other.example.com").HTTPPath("/test_static_file").Port(vhostPort).HTTPHeaders(map[string]string{ framework.NewHTTPRequest().HTTPHost("other.example.com").HTTPPath("/test_static_file").Port(vhostPort).HTTPAuth("abc", "123"),
"Authorization": utils.BasicAuth("abc", "123"),
}),
).ExpectResp([]byte("foo")).Ensure() ).ExpectResp([]byte("foo")).Ensure()
}) })