// Copyright 2023 The frp Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nathole import ( "context" "crypto/md5" "encoding/hex" "fmt" "net" "strconv" "sync" "time" "github.com/fatedier/golib/errors" "github.com/samber/lo" "golang.org/x/sync/errgroup" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/util" ) // NatHoleTimeout seconds. var NatHoleTimeout int64 = 10 func NewTransactionID() string { id, _ := util.RandID() return fmt.Sprintf("%d%s", time.Now().Unix(), id) } type ClientCfg struct { name string sk string allowUsers []string sidCh chan string } type Session struct { sid string analysisKey string recommandMode int recommandIndex int visitorMsg *msg.NatHoleVisitor visitorTransporter transport.MessageTransporter vResp *msg.NatHoleResp vNatFeature *NatFeature vBehavior RecommandBehavior clientMsg *msg.NatHoleClient clientTransporter transport.MessageTransporter cResp *msg.NatHoleResp cNatFeature *NatFeature cBehavior RecommandBehavior notifyCh chan struct{} } func (s *Session) genAnalysisKey() { hash := md5.New() vIPs := lo.Uniq(parseIPs(s.visitorMsg.MappedAddrs)) if len(vIPs) > 0 { hash.Write([]byte(vIPs[0])) } hash.Write([]byte(s.vNatFeature.NatType)) hash.Write([]byte(s.vNatFeature.Behavior)) hash.Write([]byte(strconv.FormatBool(s.vNatFeature.RegularPortsChange))) cIPs := lo.Uniq(parseIPs(s.clientMsg.MappedAddrs)) if len(cIPs) > 0 { hash.Write([]byte(cIPs[0])) } hash.Write([]byte(s.cNatFeature.NatType)) hash.Write([]byte(s.cNatFeature.Behavior)) hash.Write([]byte(strconv.FormatBool(s.cNatFeature.RegularPortsChange))) s.analysisKey = hex.EncodeToString(hash.Sum(nil)) } type Controller struct { clientCfgs map[string]*ClientCfg sessions map[string]*Session analyzer *Analyzer mu sync.RWMutex } func NewController(analysisDataReserveDuration time.Duration) (*Controller, error) { return &Controller{ clientCfgs: make(map[string]*ClientCfg), sessions: make(map[string]*Session), analyzer: NewAnalyzer(analysisDataReserveDuration), }, nil } func (c *Controller) CleanWorker(ctx context.Context) { ticker := time.NewTicker(time.Hour) defer ticker.Stop() for { select { case <-ticker.C: start := time.Now() count, total := c.analyzer.Clean() log.Trace("clean %d/%d nathole analysis data, cost %v", count, total, time.Since(start)) case <-ctx.Done(): return } } } func (c *Controller) ListenClient(name string, sk string, allowUsers []string) chan string { cfg := &ClientCfg{ name: name, sk: sk, allowUsers: allowUsers, sidCh: make(chan string), } c.mu.Lock() defer c.mu.Unlock() // TODO(fatedier): return error if name already exists c.clientCfgs[name] = cfg return cfg.sidCh } func (c *Controller) CloseClient(name string) { c.mu.Lock() defer c.mu.Unlock() delete(c.clientCfgs, name) } func (c *Controller) GenSid() string { t := time.Now().Unix() id, _ := util.RandID() return fmt.Sprintf("%d%s", t, id) } func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) { if m.PreCheck { cfg, ok := c.clientCfgs[m.ProxyName] if !ok { _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName))) return } if !lo.Contains(cfg.allowUsers, visitorUser) && !lo.Contains(cfg.allowUsers, "*") { _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp visitor user [%s] not allowed for [%s]", visitorUser, m.ProxyName))) return } _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, "")) return } sid := c.GenSid() session := &Session{ sid: sid, visitorMsg: m, visitorTransporter: transporter, notifyCh: make(chan struct{}, 1), } var ( clientCfg *ClientCfg ok bool ) err := func() error { c.mu.Lock() defer c.mu.Unlock() clientCfg, ok = c.clientCfgs[m.ProxyName] if !ok { return fmt.Errorf("xtcp server for [%s] doesn't exist", m.ProxyName) } if !util.ConstantTimeEqString(m.SignKey, util.GetAuthKey(clientCfg.sk, m.Timestamp)) { return fmt.Errorf("xtcp connection of [%s] auth failed", m.ProxyName) } c.sessions[sid] = session return nil }() if err != nil { log.Warn("handle visitorMsg error: %v", err) _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, err.Error())) return } log.Trace("handle visitor message, sid [%s]", sid) defer func() { c.mu.Lock() defer c.mu.Unlock() delete(c.sessions, sid) }() if err := errors.PanicToError(func() { clientCfg.sidCh <- sid }); err != nil { return } // wait for NatHoleClient message select { case <-session.notifyCh: case <-time.After(time.Duration(NatHoleTimeout) * time.Second): log.Debug("wait for NatHoleClient message timeout, sid [%s]", sid) return } // Make hole-punching decisions based on the NAT information of the client and visitor. vResp, cResp, err := c.analysis(session) if err != nil { log.Debug("sid [%s] analysis error: %v", err) vResp = c.GenNatHoleResponse(session.visitorMsg.TransactionID, nil, err.Error()) cResp = c.GenNatHoleResponse(session.clientMsg.TransactionID, nil, err.Error()) } session.cResp = cResp session.vResp = vResp // send response to visitor and client var g errgroup.Group g.Go(func() error { // if it's sender, wait for a while to make sure the client has send the detect messages if vResp.DetectBehavior.Role == "sender" { time.Sleep(1 * time.Second) } _ = session.visitorTransporter.Send(vResp) return nil }) g.Go(func() error { // if it's sender, wait for a while to make sure the client has send the detect messages if cResp.DetectBehavior.Role == "sender" { time.Sleep(1 * time.Second) } _ = session.clientTransporter.Send(cResp) return nil }) _ = g.Wait() time.Sleep(time.Duration(cResp.DetectBehavior.ReadTimeoutMs+30000) * time.Millisecond) } func (c *Controller) HandleClient(m *msg.NatHoleClient, transporter transport.MessageTransporter) { c.mu.RLock() session, ok := c.sessions[m.Sid] c.mu.RUnlock() if !ok { return } log.Trace("handle client message, sid [%s]", session.sid) session.clientMsg = m session.clientTransporter = transporter select { case session.notifyCh <- struct{}{}: default: } } func (c *Controller) HandleReport(m *msg.NatHoleReport) { c.mu.RLock() session, ok := c.sessions[m.Sid] c.mu.RUnlock() if !ok { log.Trace("sid [%s] report make hole success: %v, but session not found", m.Sid, m.Success) return } if m.Success { c.analyzer.ReportSuccess(session.analysisKey, session.recommandMode, session.recommandIndex) } log.Info("sid [%s] report make hole success: %v, mode %v, index %v", m.Sid, m.Success, session.recommandMode, session.recommandIndex) } func (c *Controller) GenNatHoleResponse(transactionID string, session *Session, errInfo string) *msg.NatHoleResp { var sid string if session != nil { sid = session.sid } return &msg.NatHoleResp{ TransactionID: transactionID, Sid: sid, Error: errInfo, } } // analysis analyzes the NAT type and behavior of the visitor and client, then makes hole-punching decisions. // return the response to the visitor and client. func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleResp, error) { cm := session.clientMsg vm := session.visitorMsg cNatFeature, err := ClassifyNATFeature(cm.MappedAddrs, parseIPs(cm.AssistedAddrs)) if err != nil { return nil, nil, fmt.Errorf("classify client nat feature error: %v", err) } vNatFeature, err := ClassifyNATFeature(vm.MappedAddrs, parseIPs(vm.AssistedAddrs)) if err != nil { return nil, nil, fmt.Errorf("classify visitor nat feature error: %v", err) } session.cNatFeature = cNatFeature session.vNatFeature = vNatFeature session.genAnalysisKey() mode, index, cBehavior, vBehavior := c.analyzer.GetRecommandBehaviors(session.analysisKey, cNatFeature, vNatFeature) session.recommandMode = mode session.recommandIndex = index session.cBehavior = cBehavior session.vBehavior = vBehavior timeoutMs := lo.Max([]int{cBehavior.SendDelayMs, vBehavior.SendDelayMs}) + 5000 if cBehavior.ListenRandomPorts > 0 || vBehavior.ListenRandomPorts > 0 { timeoutMs += 30000 } protocol := vm.Protocol vResp := &msg.NatHoleResp{ TransactionID: vm.TransactionID, Sid: session.sid, Protocol: protocol, CandidateAddrs: lo.Uniq(cm.MappedAddrs), AssistedAddrs: lo.Uniq(cm.AssistedAddrs), DetectBehavior: msg.NatHoleDetectBehavior{ Mode: mode, Role: vBehavior.Role, TTL: vBehavior.TTL, SendDelayMs: vBehavior.SendDelayMs, ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs, SendRandomPorts: vBehavior.PortsRandomNumber, ListenRandomPorts: vBehavior.ListenRandomPorts, CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber), }, } cResp := &msg.NatHoleResp{ TransactionID: cm.TransactionID, Sid: session.sid, Protocol: protocol, CandidateAddrs: lo.Uniq(vm.MappedAddrs), AssistedAddrs: lo.Uniq(vm.AssistedAddrs), DetectBehavior: msg.NatHoleDetectBehavior{ Mode: mode, Role: cBehavior.Role, TTL: cBehavior.TTL, SendDelayMs: cBehavior.SendDelayMs, ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs, SendRandomPorts: cBehavior.PortsRandomNumber, ListenRandomPorts: cBehavior.ListenRandomPorts, CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber), }, } log.Debug("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s", session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol) log.Debug("sid [%s] visitor detect behavior: %+v", session.sid, vResp.DetectBehavior) log.Debug("sid [%s] client detect behavior: %+v", session.sid, cResp.DetectBehavior) return vResp, cResp, nil } func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange { if maxNumber <= 0 { return nil } addr, err := lo.Last(addrs) if err != nil { return nil } var ports []msg.PortsRange _, portStr, err := net.SplitHostPort(addr) if err != nil { return nil } port, err := strconv.Atoi(portStr) if err != nil { return nil } ports = append(ports, msg.PortsRange{ From: lo.Max([]int{port - difference - 5, port - maxNumber, 1}), To: lo.Min([]int{port + difference + 5, port + maxNumber, 65535}), }) return ports }