support http load balancing

This commit is contained in:
fatedier 2019-07-31 00:41:58 +08:00
parent 5796c27ed5
commit b3ed863021
11 changed files with 329 additions and 110 deletions

View File

@ -29,6 +29,9 @@ type ResourceController struct {
// Tcp Group Controller // Tcp Group Controller
TcpGroupCtl *group.TcpGroupCtl TcpGroupCtl *group.TcpGroupCtl
// HTTP Group Controller
HTTPGroupCtl *group.HTTPGroupController
// Manage all tcp ports // Manage all tcp ports
TcpPortManager *ports.PortManager TcpPortManager *ports.PortManager

View File

@ -23,4 +23,5 @@ var (
ErrGroupParamsInvalid = errors.New("group params invalid") ErrGroupParamsInvalid = errors.New("group params invalid")
ErrListenerClosed = errors.New("group listener closed") ErrListenerClosed = errors.New("group listener closed")
ErrGroupDifferentPort = errors.New("group should have same remote port") ErrGroupDifferentPort = errors.New("group should have same remote port")
ErrProxyRepeated = errors.New("group proxy repeated")
) )

157
server/group/http.go Normal file
View File

@ -0,0 +1,157 @@
package group
import (
"fmt"
"sync"
"sync/atomic"
frpNet "github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/vhost"
)
type HTTPGroupController struct {
groups map[string]*HTTPGroup
vhostRouter *vhost.VhostRouters
mu sync.Mutex
}
func NewHTTPGroupController(vhostRouter *vhost.VhostRouters) *HTTPGroupController {
return &HTTPGroupController{
groups: make(map[string]*HTTPGroup),
vhostRouter: vhostRouter,
}
}
func (ctl *HTTPGroupController) Register(proxyName, group, groupKey string,
routeConfig vhost.VhostRouteConfig) (err error) {
indexKey := httpGroupIndex(group, routeConfig.Domain, routeConfig.Location)
ctl.mu.Lock()
g, ok := ctl.groups[indexKey]
if !ok {
g = NewHTTPGroup(ctl)
ctl.groups[indexKey] = g
}
ctl.mu.Unlock()
return g.Register(proxyName, group, groupKey, routeConfig)
}
func (ctl *HTTPGroupController) UnRegister(proxyName, group, domain, location string) {
indexKey := httpGroupIndex(group, domain, location)
ctl.mu.Lock()
defer ctl.mu.Unlock()
g, ok := ctl.groups[indexKey]
if !ok {
return
}
isEmpty := g.UnRegister(proxyName)
if isEmpty {
delete(ctl.groups, indexKey)
}
}
type HTTPGroup struct {
group string
groupKey string
domain string
location string
createFuncs map[string]vhost.CreateConnFunc
pxyNames []string
index uint64
ctl *HTTPGroupController
mu sync.RWMutex
}
func NewHTTPGroup(ctl *HTTPGroupController) *HTTPGroup {
return &HTTPGroup{
createFuncs: make(map[string]vhost.CreateConnFunc),
pxyNames: make([]string, 0),
ctl: ctl,
}
}
func (g *HTTPGroup) Register(proxyName, group, groupKey string,
routeConfig vhost.VhostRouteConfig) (err error) {
g.mu.Lock()
defer g.mu.Unlock()
if len(g.createFuncs) == 0 {
// the first proxy in this group
tmp := routeConfig // copy object
tmp.CreateConnFn = g.createConn
err = g.ctl.vhostRouter.Add(routeConfig.Domain, routeConfig.Location, &tmp)
if err != nil {
return
}
g.group = group
g.groupKey = groupKey
g.domain = routeConfig.Domain
g.location = routeConfig.Location
} else {
if g.group != group || g.domain != routeConfig.Domain || g.location != routeConfig.Location {
err = ErrGroupParamsInvalid
return
}
if g.groupKey != groupKey {
err = ErrGroupAuthFailed
return
}
}
if _, ok := g.createFuncs[proxyName]; ok {
err = ErrProxyRepeated
return
}
g.createFuncs[proxyName] = routeConfig.CreateConnFn
g.pxyNames = append(g.pxyNames, proxyName)
return nil
}
func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.createFuncs, proxyName)
for i, name := range g.pxyNames {
if name == proxyName {
g.pxyNames = append(g.pxyNames[:i], g.pxyNames[i+1:]...)
break
}
}
if len(g.createFuncs) == 0 {
isEmpty = true
g.ctl.vhostRouter.Del(g.domain, g.location)
}
return
}
func (g *HTTPGroup) createConn(remoteAddr string) (frpNet.Conn, error) {
var f vhost.CreateConnFunc
newIndex := atomic.AddUint64(&g.index, 1)
g.mu.RLock()
group := g.group
domain := g.domain
location := g.location
if len(g.pxyNames) > 0 {
name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
f, _ = g.createFuncs[name]
}
g.mu.RUnlock()
if f == nil {
return nil, fmt.Errorf("no CreateConnFunc for http group [%s], domain [%s], location [%s]", group, domain, location)
}
return f(remoteAddr)
}
func httpGroupIndex(group, domain, location string) string {
return fmt.Sprintf("%s_%s_%s", group, domain, location)
}

View File

@ -24,46 +24,47 @@ import (
gerr "github.com/fatedier/golib/errors" gerr "github.com/fatedier/golib/errors"
) )
type TcpGroupListener struct { // TcpGroupCtl manage all TcpGroups
groupName string type TcpGroupCtl struct {
group *TcpGroup groups map[string]*TcpGroup
addr net.Addr // portManager is used to manage port
closeCh chan struct{} portManager *ports.PortManager
mu sync.Mutex
} }
func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener { // NewTcpGroupCtl return a new TcpGroupCtl
return &TcpGroupListener{ func NewTcpGroupCtl(portManager *ports.PortManager) *TcpGroupCtl {
groupName: name, return &TcpGroupCtl{
group: group, groups: make(map[string]*TcpGroup),
addr: addr, portManager: portManager,
closeCh: make(chan struct{}),
} }
} }
func (ln *TcpGroupListener) Accept() (c net.Conn, err error) { // Listen is the wrapper for TcpGroup's Listen
var ok bool // If there are no group, we will create one here
select { func (tgc *TcpGroupCtl) Listen(proxyName string, group string, groupKey string,
case <-ln.closeCh: addr string, port int) (l net.Listener, realPort int, err error) {
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept(): tgc.mu.Lock()
tcpGroup, ok := tgc.groups[group]
if !ok { if !ok {
return nil, ErrListenerClosed tcpGroup = NewTcpGroup(tgc)
} tgc.groups[group] = tcpGroup
return c, nil
} }
tgc.mu.Unlock()
return tcpGroup.Listen(proxyName, group, groupKey, addr, port)
} }
func (ln *TcpGroupListener) Addr() net.Addr { // RemoveGroup remove TcpGroup from controller
return ln.addr func (tgc *TcpGroupCtl) RemoveGroup(group string) {
} tgc.mu.Lock()
defer tgc.mu.Unlock()
func (ln *TcpGroupListener) Close() (err error) { delete(tgc.groups, group)
close(ln.closeCh)
ln.group.CloseListener(ln)
return
} }
// TcpGroup route connections to different proxies
type TcpGroup struct { type TcpGroup struct {
group string group string
groupKey string groupKey string
@ -79,6 +80,7 @@ type TcpGroup struct {
mu sync.Mutex mu sync.Mutex
} }
// NewTcpGroup return a new TcpGroup
func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup { func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup {
return &TcpGroup{ return &TcpGroup{
lns: make([]*TcpGroupListener, 0), lns: make([]*TcpGroupListener, 0),
@ -87,10 +89,14 @@ func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup {
} }
} }
// Listen will return a new TcpGroupListener
// if TcpGroup already has a listener, just add a new TcpGroupListener to the queues
// otherwise, listen on the real address
func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TcpGroupListener, realPort int, err error) { func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TcpGroupListener, realPort int, err error) {
tg.mu.Lock() tg.mu.Lock()
defer tg.mu.Unlock() defer tg.mu.Unlock()
if len(tg.lns) == 0 { if len(tg.lns) == 0 {
// the first listener, listen on the real address
realPort, err = tg.ctl.portManager.Acquire(proxyName, port) realPort, err = tg.ctl.portManager.Acquire(proxyName, port)
if err != nil { if err != nil {
return return
@ -114,6 +120,7 @@ func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr
} }
go tg.worker() go tg.worker()
} else { } else {
// address and port in the same group must be equal
if tg.group != group || tg.addr != addr { if tg.group != group || tg.addr != addr {
err = ErrGroupParamsInvalid err = ErrGroupParamsInvalid
return return
@ -133,6 +140,7 @@ func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr
return return
} }
// worker is called when the real tcp listener has been created
func (tg *TcpGroup) worker() { func (tg *TcpGroup) worker() {
for { for {
c, err := tg.tcpLn.Accept() c, err := tg.tcpLn.Accept()
@ -152,6 +160,7 @@ func (tg *TcpGroup) Accept() <-chan net.Conn {
return tg.acceptCh return tg.acceptCh
} }
// CloseListener remove the TcpGroupListener from the TcpGroup
func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) { func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) {
tg.mu.Lock() tg.mu.Lock()
defer tg.mu.Unlock() defer tg.mu.Unlock()
@ -169,36 +178,47 @@ func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) {
} }
} }
type TcpGroupCtl struct { // TcpGroupListener
groups map[string]*TcpGroup type TcpGroupListener struct {
groupName string
group *TcpGroup
portManager *ports.PortManager addr net.Addr
mu sync.Mutex closeCh chan struct{}
} }
func NewTcpGroupCtl(portManager *ports.PortManager) *TcpGroupCtl { func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener {
return &TcpGroupCtl{ return &TcpGroupListener{
groups: make(map[string]*TcpGroup), groupName: name,
portManager: portManager, group: group,
addr: addr,
closeCh: make(chan struct{}),
} }
} }
func (tgc *TcpGroupCtl) Listen(proxyNanme string, group string, groupKey string, // Accept will accept connections from TcpGroup
addr string, port int) (l net.Listener, realPort int, err error) { func (ln *TcpGroupListener) Accept() (c net.Conn, err error) {
var ok bool
tgc.mu.Lock() select {
defer tgc.mu.Unlock() case <-ln.closeCh:
if tcpGroup, ok := tgc.groups[group]; ok { return nil, ErrListenerClosed
return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port) case c, ok = <-ln.group.Accept():
} else { if !ok {
tcpGroup = NewTcpGroup(tgc) return nil, ErrListenerClosed
tgc.groups[group] = tcpGroup }
return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port) return c, nil
} }
} }
func (tgc *TcpGroupCtl) RemoveGroup(group string) { func (ln *TcpGroupListener) Addr() net.Addr {
tgc.mu.Lock() return ln.addr
defer tgc.mu.Unlock() }
delete(tgc.groups, group)
// Close close the listener
func (ln *TcpGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from TcpGroup
ln.group.CloseListener(ln)
return
} }

View File

@ -50,6 +50,12 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) {
locations = []string{""} locations = []string{""}
} }
defer func() {
if err != nil {
pxy.Close()
}
}()
addrs := make([]string, 0) addrs := make([]string, 0)
for _, domain := range pxy.cfg.CustomDomains { for _, domain := range pxy.cfg.CustomDomains {
if domain == "" { if domain == "" {
@ -59,17 +65,31 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) {
routeConfig.Domain = domain routeConfig.Domain = domain
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location
// handle group
if pxy.cfg.Group != "" {
err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.Group, pxy.cfg.GroupKey, routeConfig)
if err != nil {
return
}
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation)
})
} else {
// no group
err = pxy.rc.HttpReverseProxy.Register(routeConfig) err = pxy.rc.HttpReverseProxy.Register(routeConfig)
if err != nil { if err != nil {
return return
} }
tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location
addrs = append(addrs, util.CanonicalAddr(tmpDomain, int(g.GlbServerCfg.VhostHttpPort)))
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation) pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation)
}) })
pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) }
addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(g.GlbServerCfg.VhostHttpPort)))
pxy.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group)
} }
} }
@ -77,17 +97,31 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) {
routeConfig.Domain = pxy.cfg.SubDomain + "." + g.GlbServerCfg.SubDomainHost routeConfig.Domain = pxy.cfg.SubDomain + "." + g.GlbServerCfg.SubDomainHost
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location
// handle group
if pxy.cfg.Group != "" {
err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.Group, pxy.cfg.GroupKey, routeConfig)
if err != nil {
return
}
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation)
})
} else {
err = pxy.rc.HttpReverseProxy.Register(routeConfig) err = pxy.rc.HttpReverseProxy.Register(routeConfig)
if err != nil { if err != nil {
return return
} }
tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location
addrs = append(addrs, util.CanonicalAddr(tmpDomain, g.GlbServerCfg.VhostHttpPort))
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation) pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation)
}) })
pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) }
addrs = append(addrs, util.CanonicalAddr(tmpDomain, g.GlbServerCfg.VhostHttpPort))
pxy.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group)
} }
} }
remoteAddr = strings.Join(addrs, ",") remoteAddr = strings.Join(addrs, ",")

View File

@ -31,6 +31,11 @@ type HttpsProxy struct {
func (pxy *HttpsProxy) Run() (remoteAddr string, err error) { func (pxy *HttpsProxy) Run() (remoteAddr string, err error) {
routeConfig := &vhost.VhostRouteConfig{} routeConfig := &vhost.VhostRouteConfig{}
defer func() {
if err != nil {
pxy.Close()
}
}()
addrs := make([]string, 0) addrs := make([]string, 0)
for _, domain := range pxy.cfg.CustomDomains { for _, domain := range pxy.cfg.CustomDomains {
if domain == "" { if domain == "" {

View File

@ -72,6 +72,8 @@ func (pxy *BaseProxy) Close() {
} }
} }
// GetWorkConnFromPool try to get a new work connections from pool
// for quickly response, we immediately send the StartWorkConn message to frpc after take out one from pool
func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn frpNet.Conn, err error) { func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn frpNet.Conn, err error) {
// try all connections from the pool // try all connections from the pool
for i := 0; i < pxy.poolCount+1; i++ { for i := 0; i < pxy.poolCount+1; i++ {

View File

@ -76,6 +76,9 @@ type Service struct {
// Manage all proxies // Manage all proxies
pxyManager *proxy.ProxyManager pxyManager *proxy.ProxyManager
// HTTP vhost router
httpVhostRouter *vhost.VhostRouters
// All resource managers and controllers // All resource managers and controllers
rc *controller.ResourceController rc *controller.ResourceController
@ -95,12 +98,16 @@ func NewService() (svr *Service, err error) {
TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
}, },
httpVhostRouter: vhost.NewVhostRouters(),
tlsConfig: generateTLSConfig(), tlsConfig: generateTLSConfig(),
} }
// Init group controller // Init group controller
svr.rc.TcpGroupCtl = group.NewTcpGroupCtl(svr.rc.TcpPortManager) svr.rc.TcpGroupCtl = group.NewTcpGroupCtl(svr.rc.TcpPortManager)
// Init HTTP group controller
svr.rc.HTTPGroupCtl = group.NewHTTPGroupController(svr.httpVhostRouter)
// Init assets // Init assets
err = assets.Load(cfg.AssetsDir) err = assets.Load(cfg.AssetsDir)
if err != nil { if err != nil {
@ -159,7 +166,7 @@ func NewService() (svr *Service, err error) {
if cfg.VhostHttpPort > 0 { if cfg.VhostHttpPort > 0 {
rp := vhost.NewHttpReverseProxy(vhost.HttpReverseProxyOptions{ rp := vhost.NewHttpReverseProxy(vhost.HttpReverseProxyOptions{
ResponseHeaderTimeoutS: cfg.VhostHttpTimeout, ResponseHeaderTimeoutS: cfg.VhostHttpTimeout,
}) }, svr.httpVhostRouter)
svr.rc.HttpReverseProxy = rp svr.rc.HttpReverseProxy = rp
address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)

View File

@ -23,7 +23,6 @@ import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
"sync"
"time" "time"
frpLog "github.com/fatedier/frp/utils/log" frpLog "github.com/fatedier/frp/utils/log"
@ -32,7 +31,6 @@ import (
) )
var ( var (
ErrRouterConfigConflict = errors.New("router config conflict")
ErrNoDomain = errors.New("no such domain") ErrNoDomain = errors.New("no such domain")
) )
@ -52,20 +50,18 @@ type HttpReverseProxyOptions struct {
type HttpReverseProxy struct { type HttpReverseProxy struct {
proxy *ReverseProxy proxy *ReverseProxy
vhostRouter *VhostRouters vhostRouter *VhostRouters
responseHeaderTimeout time.Duration responseHeaderTimeout time.Duration
cfgMu sync.RWMutex
} }
func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy { func NewHttpReverseProxy(option HttpReverseProxyOptions, vhostRouter *VhostRouters) *HttpReverseProxy {
if option.ResponseHeaderTimeoutS <= 0 { if option.ResponseHeaderTimeoutS <= 0 {
option.ResponseHeaderTimeoutS = 60 option.ResponseHeaderTimeoutS = 60
} }
rp := &HttpReverseProxy{ rp := &HttpReverseProxy{
responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second, responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second,
vhostRouter: NewVhostRouters(), vhostRouter: vhostRouter,
} }
proxy := &ReverseProxy{ proxy := &ReverseProxy{
Director: func(req *http.Request) { Director: func(req *http.Request) {
@ -106,21 +102,18 @@ func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy {
return rp return rp
} }
// Register register the route config to reverse proxy
// reverse proxy will use CreateConnFn from routeCfg to create a connection to the remote service
func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error { func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error {
rp.cfgMu.Lock() err := rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
defer rp.cfgMu.Unlock() if err != nil {
_, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location) return err
if ok {
return ErrRouterConfigConflict
} else {
rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
} }
return nil return nil
} }
// UnRegister unregister route config by domain and location
func (rp *HttpReverseProxy) UnRegister(domain string, location string) { func (rp *HttpReverseProxy) UnRegister(domain string, location string) {
rp.cfgMu.Lock()
defer rp.cfgMu.Unlock()
rp.vhostRouter.Del(domain, location) rp.vhostRouter.Del(domain, location)
} }
@ -140,6 +133,7 @@ func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers
return return
} }
// CreateConnection create a new connection by route config
func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) { func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) {
vr, ok := rp.getVhost(domain, location) vr, ok := rp.getVhost(domain, location)
if ok { if ok {
@ -163,10 +157,8 @@ func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) boo
return true return true
} }
// getVhost get vhost router by domain and location
func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) { func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) {
rp.cfgMu.RLock()
defer rp.cfgMu.RUnlock()
// first we check the full hostname // first we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com // if not exist, then check the wildcard_domain such as *.example.com
vr, ok = rp.vhostRouter.Get(domain, location) vr, ok = rp.vhostRouter.Get(domain, location)

View File

@ -1,11 +1,16 @@
package vhost package vhost
import ( import (
"errors"
"sort" "sort"
"strings" "strings"
"sync" "sync"
) )
var (
ErrRouterConfigConflict = errors.New("router config conflict")
)
type VhostRouters struct { type VhostRouters struct {
RouterByDomain map[string][]*VhostRouter RouterByDomain map[string][]*VhostRouter
mutex sync.RWMutex mutex sync.RWMutex
@ -24,10 +29,14 @@ func NewVhostRouters() *VhostRouters {
} }
} }
func (r *VhostRouters) Add(domain, location string, payload interface{}) { func (r *VhostRouters) Add(domain, location string, payload interface{}) error {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if _, exist := r.exist(domain, location); exist {
return ErrRouterConfigConflict
}
vrs, found := r.RouterByDomain[domain] vrs, found := r.RouterByDomain[domain]
if !found { if !found {
vrs = make([]*VhostRouter, 0, 1) vrs = make([]*VhostRouter, 0, 1)
@ -42,6 +51,7 @@ func (r *VhostRouters) Add(domain, location string, payload interface{}) {
sort.Sort(sort.Reverse(ByLocation(vrs))) sort.Sort(sort.Reverse(ByLocation(vrs)))
r.RouterByDomain[domain] = vrs r.RouterByDomain[domain] = vrs
return nil
} }
func (r *VhostRouters) Del(domain, location string) { func (r *VhostRouters) Del(domain, location string) {
@ -80,10 +90,7 @@ func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) {
return return
} }
func (r *VhostRouters) Exist(host, path string) (vr *VhostRouter, exist bool) { func (r *VhostRouters) exist(host, path string) (vr *VhostRouter, exist bool) {
r.mutex.RLock()
defer r.mutex.RUnlock()
vrs, found := r.RouterByDomain[host] vrs, found := r.RouterByDomain[host]
if !found { if !found {
return return

View File

@ -15,7 +15,6 @@ package vhost
import ( import (
"fmt" "fmt"
"strings" "strings"
"sync"
"time" "time"
"github.com/fatedier/frp/utils/log" "github.com/fatedier/frp/utils/log"
@ -35,7 +34,6 @@ type VhostMuxer struct {
authFunc httpAuthFunc authFunc httpAuthFunc
rewriteFunc hostRewriteFunc rewriteFunc hostRewriteFunc
registryRouter *VhostRouters registryRouter *VhostRouters
mutex sync.RWMutex
} }
func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
@ -53,6 +51,7 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut
type CreateConnFunc func(remoteAddr string) (frpNet.Conn, error) type CreateConnFunc func(remoteAddr string) (frpNet.Conn, error)
// VhostRouteConfig is the params used to match HTTP requests
type VhostRouteConfig struct { type VhostRouteConfig struct {
Domain string Domain string
Location string Location string
@ -67,14 +66,6 @@ type VhostRouteConfig struct {
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil
// then rewrite the host header to rewriteHost // then rewrite the host header to rewriteHost
func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) { func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) {
v.mutex.Lock()
defer v.mutex.Unlock()
_, ok := v.registryRouter.Exist(cfg.Domain, cfg.Location)
if ok {
return nil, fmt.Errorf("hostname [%s] location [%s] is already registered", cfg.Domain, cfg.Location)
}
l = &Listener{ l = &Listener{
name: cfg.Domain, name: cfg.Domain,
location: cfg.Location, location: cfg.Location,
@ -85,14 +76,14 @@ func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) {
accept: make(chan frpNet.Conn), accept: make(chan frpNet.Conn),
Logger: log.NewPrefixLogger(""), Logger: log.NewPrefixLogger(""),
} }
v.registryRouter.Add(cfg.Domain, cfg.Location, l) err = v.registryRouter.Add(cfg.Domain, cfg.Location, l)
if err != nil {
return
}
return l, nil return l, nil
} }
func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) { func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
v.mutex.RLock()
defer v.mutex.RUnlock()
// first we check the full hostname // first we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com // if not exist, then check the wildcard_domain such as *.example.com
vr, found := v.registryRouter.Get(name, path) vr, found := v.registryRouter.Get(name, path)