diff --git a/models/config/proxy.go b/models/config/proxy.go index 53a2a45..31b26e7 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -100,8 +100,10 @@ type BaseProxyConf struct { ProxyName string `json:"proxy_name"` ProxyType string `json:"proxy_type"` - UseEncryption bool `json:"use_encryption"` - UseCompression bool `json:"use_compression"` + UseEncryption bool `json:"use_encryption"` + UseCompression bool `json:"use_compression"` + Group string `json:"group"` + GroupKey string `json:"group_key"` } func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { @@ -112,7 +114,9 @@ func (cfg *BaseProxyConf) compare(cmp *BaseProxyConf) bool { if cfg.ProxyName != cmp.ProxyName || cfg.ProxyType != cmp.ProxyType || cfg.UseEncryption != cmp.UseEncryption || - cfg.UseCompression != cmp.UseCompression { + cfg.UseCompression != cmp.UseCompression || + cfg.Group != cmp.Group || + cfg.GroupKey != cmp.GroupKey { return false } return true @@ -123,6 +127,8 @@ func (cfg *BaseProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { cfg.ProxyType = pMsg.ProxyType cfg.UseEncryption = pMsg.UseEncryption cfg.UseCompression = pMsg.UseCompression + cfg.Group = pMsg.Group + cfg.GroupKey = pMsg.GroupKey } func (cfg *BaseProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) error { @@ -142,6 +148,9 @@ func (cfg *BaseProxyConf) UnmarshalFromIni(prefix string, name string, section i if ok && tmpStr == "true" { cfg.UseCompression = true } + + cfg.Group = section["group"] + cfg.GroupKey = section["group_key"] return nil } @@ -150,6 +159,8 @@ func (cfg *BaseProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.ProxyType = cfg.ProxyType pMsg.UseEncryption = cfg.UseEncryption pMsg.UseCompression = cfg.UseCompression + pMsg.Group = cfg.Group + pMsg.GroupKey = cfg.GroupKey } // Bind info diff --git a/models/msg/msg.go b/models/msg/msg.go index 9669c6b..e06fa37 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -86,6 +86,8 @@ type NewProxy struct { ProxyType string `json:"proxy_type"` UseEncryption bool `json:"use_encryption"` UseCompression bool `json:"use_compression"` + Group string `json:"group"` + GroupKey string `json:"group_key"` // tcp and udp only RemotePort int `json:"remote_port"` diff --git a/server/group.go b/server/group.go new file mode 100644 index 0000000..24b292c --- /dev/null +++ b/server/group.go @@ -0,0 +1,205 @@ +// Copyright 2018 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "errors" + "fmt" + "net" + "sync" + + gerr "github.com/fatedier/golib/errors" +) + +var ( + ErrGroupAuthFailed = errors.New("group auth failed") + ErrGroupParamsInvalid = errors.New("group params invalid") + ErrListenerClosed = errors.New("group listener closed") +) + +type TcpGroupListener struct { + groupName string + group *TcpGroup + + addr net.Addr + closeCh chan struct{} +} + +func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener { + return &TcpGroupListener{ + groupName: name, + group: group, + addr: addr, + closeCh: make(chan struct{}), + } +} + +func (ln *TcpGroupListener) Accept() (c net.Conn, err error) { + var ok bool + select { + case <-ln.closeCh: + return nil, ErrListenerClosed + case c, ok = <-ln.group.Accept(): + if !ok { + return nil, ErrListenerClosed + } + return c, nil + } +} + +func (ln *TcpGroupListener) Addr() net.Addr { + return ln.addr +} + +func (ln *TcpGroupListener) Close() (err error) { + close(ln.closeCh) + ln.group.CloseListener(ln) + return +} + +type TcpGroup struct { + group string + groupKey string + addr string + port int + realPort int + + acceptCh chan net.Conn + index uint64 + tcpLn net.Listener + lns []*TcpGroupListener + ctl *TcpGroupCtl + mu sync.Mutex +} + +func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup { + return &TcpGroup{ + lns: make([]*TcpGroupListener, 0), + ctl: ctl, + acceptCh: make(chan net.Conn), + } +} + +func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TcpGroupListener, realPort int, err error) { + tg.mu.Lock() + defer tg.mu.Unlock() + if len(tg.lns) == 0 { + realPort, err = tg.ctl.portManager.Acquire(proxyName, port) + if err != nil { + return + } + tcpLn, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, port)) + if errRet != nil { + err = errRet + return + } + ln = newTcpGroupListener(group, tg, tcpLn.Addr()) + + tg.group = group + tg.groupKey = groupKey + tg.addr = addr + tg.port = port + tg.realPort = realPort + tg.tcpLn = tcpLn + tg.lns = append(tg.lns, ln) + if tg.acceptCh == nil { + tg.acceptCh = make(chan net.Conn) + } + go tg.worker() + } else { + if tg.group != group || tg.addr != addr || tg.port != port { + err = ErrGroupParamsInvalid + return + } + if tg.groupKey != groupKey { + err = ErrGroupAuthFailed + return + } + ln = newTcpGroupListener(group, tg, tg.lns[0].Addr()) + realPort = tg.realPort + tg.lns = append(tg.lns, ln) + } + return +} + +func (tg *TcpGroup) worker() { + for { + c, err := tg.tcpLn.Accept() + if err != nil { + return + } + err = gerr.PanicToError(func() { + tg.acceptCh <- c + }) + if err != nil { + return + } + } +} + +func (tg *TcpGroup) Accept() <-chan net.Conn { + return tg.acceptCh +} + +func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) { + tg.mu.Lock() + defer tg.mu.Unlock() + for i, tmpLn := range tg.lns { + if tmpLn == ln { + tg.lns = append(tg.lns[:i], tg.lns[i+1:]...) + break + } + } + if len(tg.lns) == 0 { + close(tg.acceptCh) + tg.tcpLn.Close() + tg.ctl.portManager.Release(tg.realPort) + tg.ctl.RemoveGroup(tg.group) + } +} + +type TcpGroupCtl struct { + groups map[string]*TcpGroup + + portManager *PortManager + mu sync.Mutex +} + +func NewTcpGroupCtl(portManager *PortManager) *TcpGroupCtl { + return &TcpGroupCtl{ + groups: make(map[string]*TcpGroup), + portManager: portManager, + } +} + +func (tgc *TcpGroupCtl) Listen(proxyNanme string, group string, groupKey string, + addr string, port int) (l net.Listener, realPort int, err error) { + + tgc.mu.Lock() + defer tgc.mu.Unlock() + if tcpGroup, ok := tgc.groups[group]; ok { + return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port) + } else { + tcpGroup = NewTcpGroup(tgc) + tgc.groups[group] = tcpGroup + return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port) + } +} + +func (tgc *TcpGroupCtl) RemoveGroup(group string) { + tgc.mu.Lock() + defer tgc.mu.Unlock() + delete(tgc.groups, group) +} diff --git a/server/proxy.go b/server/proxy.go index 4bf6f1f..a443626 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -181,27 +181,44 @@ type TcpProxy struct { } func (pxy *TcpProxy) Run() (remoteAddr string, err error) { - pxy.realPort, err = pxy.ctl.svr.tcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) - if err != nil { - return - } - defer func() { - if err != nil { - pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) + if pxy.cfg.Group != "" { + l, realPort, errRet := pxy.ctl.svr.tcpGroupCtl.Listen(pxy.name, pxy.cfg.Group, pxy.cfg.GroupKey, g.GlbServerCfg.ProxyBindAddr, pxy.cfg.RemotePort) + if errRet != nil { + err = errRet + return } - }() - - remoteAddr = fmt.Sprintf(":%d", pxy.realPort) - pxy.cfg.RemotePort = pxy.realPort - listener, errRet := frpNet.ListenTcp(g.GlbServerCfg.ProxyBindAddr, pxy.realPort) - if errRet != nil { - err = errRet - return + defer func() { + if err != nil { + l.Close() + } + }() + pxy.realPort = realPort + listener := frpNet.WrapLogListener(l) + listener.AddLogPrefix(pxy.name) + pxy.listeners = append(pxy.listeners, listener) + pxy.Info("tcp proxy listen port [%d] in group [%s]", pxy.cfg.RemotePort, pxy.cfg.Group) + } else { + pxy.realPort, err = pxy.ctl.svr.tcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) + if err != nil { + return + } + defer func() { + if err != nil { + pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) + } + }() + listener, errRet := frpNet.ListenTcp(g.GlbServerCfg.ProxyBindAddr, pxy.realPort) + if errRet != nil { + err = errRet + return + } + listener.AddLogPrefix(pxy.name) + pxy.listeners = append(pxy.listeners, listener) + pxy.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort) } - listener.AddLogPrefix(pxy.name) - pxy.listeners = append(pxy.listeners, listener) - pxy.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort) + pxy.cfg.RemotePort = pxy.realPort + remoteAddr = fmt.Sprintf(":%d", pxy.realPort) pxy.startListenHandler(pxy, HandleUserTcpConnection) return } @@ -212,7 +229,9 @@ func (pxy *TcpProxy) GetConf() config.ProxyConf { func (pxy *TcpProxy) Close() { pxy.BaseProxy.Close() - pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) + if pxy.cfg.Group == "" { + pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) + } } type HttpProxy struct { diff --git a/server/service.go b/server/service.go index 3cc1c5a..8038fdd 100644 --- a/server/service.go +++ b/server/service.go @@ -40,38 +40,41 @@ const ( var ServerService *Service -// Server service. +// Server service type Service struct { - // Dispatch connections to different handlers listen on same port. + // Dispatch connections to different handlers listen on same port muxer *mux.Mux - // Accept connections from client. + // Accept connections from client listener frpNet.Listener - // Accept connections using kcp. + // Accept connections using kcp kcpListener frpNet.Listener - // For https proxies, route requests to different clients by hostname and other infomation. + // For https proxies, route requests to different clients by hostname and other infomation VhostHttpsMuxer *vhost.HttpsMuxer httpReverseProxy *vhost.HttpReverseProxy - // Manage all controllers. + // Manage all controllers ctlManager *ControlManager - // Manage all proxies. + // Manage all proxies pxyManager *ProxyManager - // Manage all visitor listeners. + // Manage all visitor listeners visitorManager *VisitorManager - // Manage all tcp ports. + // Manage all tcp ports tcpPortManager *PortManager - // Manage all udp ports. + // Manage all udp ports udpPortManager *PortManager - // Controller for nat hole connections. + // Tcp Group Controller + tcpGroupCtl *TcpGroupCtl + + // Controller for nat hole connections natHoleController *NatHoleController } @@ -84,6 +87,7 @@ func NewService() (svr *Service, err error) { tcpPortManager: NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), udpPortManager: NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), } + svr.tcpGroupCtl = NewTcpGroupCtl(svr.tcpPortManager) // Init assets. err = assets.Load(cfg.AssetsDir) diff --git a/utils/vhost/router.go b/utils/vhost/router.go index 37a34fb..ea5c347 100644 --- a/utils/vhost/router.go +++ b/utils/vhost/router.go @@ -52,16 +52,13 @@ func (r *VhostRouters) Del(domain, location string) { if !found { return } - - for i, vr := range vrs { - if vr.location == location { - if len(vrs) > i+1 { - r.RouterByDomain[domain] = append(vrs[:i], vrs[i+1:]...) - } else { - r.RouterByDomain[domain] = vrs[:i] - } + newVrs := make([]*VhostRouter, 0) + for _, vr := range vrs { + if vr.location != location { + newVrs = append(newVrs, vr) } } + r.RouterByDomain[domain] = newVrs } func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) {