// Copyright 2018 fatedier, fatedier@gmail.com // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package visitor import ( "context" "fmt" "net" "reflect" "sync" "time" "github.com/samber/lo" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/xlog" ) type Manager struct { clientCfg *v1.ClientCommonConfig cfgs map[string]v1.VisitorConfigurer visitors map[string]Visitor helper Helper checkInterval time.Duration keepVisitorsRunningOnce sync.Once mu sync.RWMutex ctx context.Context stopCh chan struct{} } func NewManager( ctx context.Context, runID string, clientCfg *v1.ClientCommonConfig, connectServer func() (net.Conn, error), msgTransporter transport.MessageTransporter, ) *Manager { m := &Manager{ clientCfg: clientCfg, cfgs: make(map[string]v1.VisitorConfigurer), visitors: make(map[string]Visitor), checkInterval: 10 * time.Second, ctx: ctx, stopCh: make(chan struct{}), } m.helper = &visitorHelperImpl{ connectServerFn: connectServer, msgTransporter: msgTransporter, transferConnFn: m.TransferConn, runID: runID, } return m } // keepVisitorsRunning checks all visitors' status periodically, if some visitor is not running, start it. // It will only start after Reload is called and a new visitor is added. func (vm *Manager) keepVisitorsRunning() { xl := xlog.FromContextSafe(vm.ctx) ticker := time.NewTicker(vm.checkInterval) defer ticker.Stop() for { select { case <-vm.stopCh: xl.Tracef("gracefully shutdown visitor manager") return case <-ticker.C: vm.mu.Lock() for _, cfg := range vm.cfgs { name := cfg.GetBaseConfig().Name if _, exist := vm.visitors[name]; !exist { xl.Infof("try to start visitor [%s]", name) _ = vm.startVisitor(cfg) } } vm.mu.Unlock() } } } func (vm *Manager) Close() { vm.mu.Lock() defer vm.mu.Unlock() for _, v := range vm.visitors { v.Close() } select { case <-vm.stopCh: default: close(vm.stopCh) } } // Hold lock before calling this function. func (vm *Manager) startVisitor(cfg v1.VisitorConfigurer) (err error) { xl := xlog.FromContextSafe(vm.ctx) name := cfg.GetBaseConfig().Name visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.helper) err = visitor.Run() if err != nil { xl.Warnf("start error: %v", err) } else { vm.visitors[name] = visitor xl.Infof("start visitor success") } return } func (vm *Manager) UpdateAll(cfgs []v1.VisitorConfigurer) { if len(cfgs) > 0 { // Only start keepVisitorsRunning goroutine once and only when there is at least one visitor. vm.keepVisitorsRunningOnce.Do(func() { go vm.keepVisitorsRunning() }) } xl := xlog.FromContextSafe(vm.ctx) cfgsMap := lo.KeyBy(cfgs, func(c v1.VisitorConfigurer) string { return c.GetBaseConfig().Name }) vm.mu.Lock() defer vm.mu.Unlock() delNames := make([]string, 0) for name, oldCfg := range vm.cfgs { del := false cfg, ok := cfgsMap[name] if !ok || !reflect.DeepEqual(oldCfg, cfg) { del = true } if del { delNames = append(delNames, name) delete(vm.cfgs, name) if visitor, ok := vm.visitors[name]; ok { visitor.Close() } delete(vm.visitors, name) } } if len(delNames) > 0 { xl.Infof("visitor removed: %v", delNames) } addNames := make([]string, 0) for _, cfg := range cfgs { name := cfg.GetBaseConfig().Name if _, ok := vm.cfgs[name]; !ok { vm.cfgs[name] = cfg addNames = append(addNames, name) _ = vm.startVisitor(cfg) } } if len(addNames) > 0 { xl.Infof("visitor added: %v", addNames) } } // TransferConn transfers a connection to a visitor. func (vm *Manager) TransferConn(name string, conn net.Conn) error { vm.mu.RLock() defer vm.mu.RUnlock() v, ok := vm.visitors[name] if !ok { return fmt.Errorf("visitor [%s] not found", name) } return v.AcceptConn(conn) } type visitorHelperImpl struct { connectServerFn func() (net.Conn, error) msgTransporter transport.MessageTransporter transferConnFn func(name string, conn net.Conn) error runID string } func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) { return v.connectServerFn() } func (v *visitorHelperImpl) TransferConn(name string, conn net.Conn) error { return v.transferConnFn(name, conn) } func (v *visitorHelperImpl) MsgTransporter() transport.MessageTransporter { return v.msgTransporter } func (v *visitorHelperImpl) RunID() string { return v.runID }