diff --git a/server/proxy.go b/server/proxy.go index bd6234d..8ce1e2b 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -194,10 +194,11 @@ type HttpProxy struct { } func (pxy *HttpProxy) Run() (err error) { - routeConfig := &vhost.VhostRouteConfig{ - RewriteHost: pxy.cfg.HostHeaderRewrite, - Username: pxy.cfg.HttpUser, - Password: pxy.cfg.HttpPwd, + routeConfig := vhost.VhostRouteConfig{ + RewriteHost: pxy.cfg.HostHeaderRewrite, + Username: pxy.cfg.HttpUser, + Password: pxy.cfg.HttpPwd, + CreateConnFn: pxy.GetRealConn, } locations := pxy.cfg.Locations @@ -208,7 +209,7 @@ func (pxy *HttpProxy) Run() (err error) { routeConfig.Domain = domain for _, location := range locations { routeConfig.Location = location - err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetRealConn) + err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) if err != nil { return err } @@ -225,7 +226,7 @@ func (pxy *HttpProxy) Run() (err error) { routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost for _, location := range locations { routeConfig.Location = location - err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetRealConn) + err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) if err != nil { return err } diff --git a/utils/vhost/newhttp.go b/utils/vhost/newhttp.go index 55ba419..6e20a78 100644 --- a/utils/vhost/newhttp.go +++ b/utils/vhost/newhttp.go @@ -26,7 +26,6 @@ import ( "time" frpLog "github.com/fatedier/frp/utils/log" - frpNet "github.com/fatedier/frp/utils/net" "github.com/fatedier/frp/utils/pool" ) @@ -47,13 +46,6 @@ func getHostFromAddr(addr string) (host string) { return } -type CreateConnFunc func() (frpNet.Conn, error) - -type ProxyOption struct { - RewriteHost string - DialFunc CreateConnFunc -} - type HttpReverseProxy struct { proxy *ReverseProxy tr *http.Transport @@ -94,18 +86,14 @@ func NewHttpReverseProxy() *HttpReverseProxy { return rp } -func (rp *HttpReverseProxy) Register(domain string, location string, rewriteHost string, fn CreateConnFunc) error { +func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error { rp.cfgMu.Lock() defer rp.cfgMu.Unlock() - _, ok := rp.vhostRouter.Exist(domain, location) + _, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location) if ok { return ErrRouterConfigConflict } else { - payload := &ProxyOption{ - RewriteHost: rewriteHost, - DialFunc: fn, - } - rp.vhostRouter.Add(domain, location, payload) + rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg) } return nil } @@ -119,7 +107,7 @@ func (rp *HttpReverseProxy) UnRegister(domain string, location string) { func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host string) { vr, ok := rp.getVhost(domain, location) if ok { - host = vr.payload.(*ProxyOption).RewriteHost + host = vr.payload.(*VhostRouteConfig).RewriteHost } return } @@ -127,7 +115,7 @@ func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host st func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) { vr, ok := rp.getVhost(domain, location) if ok { - fn := vr.payload.(*ProxyOption).DialFunc + fn := vr.payload.(*VhostRouteConfig).CreateConnFn if fn != nil { return fn() } @@ -135,6 +123,18 @@ func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (ne return nil, ErrNoDomain } +func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool { + vr, ok := rp.getVhost(domain, location) + if ok { + checkUser := vr.payload.(*VhostRouteConfig).Username + checkPasswd := vr.payload.(*VhostRouteConfig).Password + if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) { + return false + } + } + return true +} + func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) { rp.cfgMu.RLock() defer rp.cfgMu.RUnlock() @@ -157,6 +157,14 @@ func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostR } func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + domain := getHostFromAddr(req.Host) + location := req.URL.Path + user, passwd, _ := req.BasicAuth() + if !rp.CheckAuth(domain, location, user, passwd) { + rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } rp.proxy.ServeHTTP(rw, req) } diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index 391878b..4920d32 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -50,12 +50,16 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut return mux, nil } +type CreateConnFunc func() (frpNet.Conn, error) + type VhostRouteConfig struct { Domain string Location string RewriteHost string Username string Password string + + CreateConnFn CreateConnFunc } // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil