// Copyright 2023 The frp Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nathole import ( "context" "fmt" "math/rand/v2" "net" "slices" "strconv" "strings" "time" "github.com/fatedier/golib/pool" "golang.org/x/net/ipv4" "k8s.io/apimachinery/pkg/util/sets" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/xlog" ) var ( // mode 0: simple detect mode, usually for both EasyNAT or HardNAT & EasyNAT(Public Network) // a. receiver sends detect message with low TTL // b. sender sends normal detect message to receiver // c. receiver receives detect message and sends back a message to sender // // mode 1: For HardNAT & EasyNAT, send detect messages to multiple guessed ports. // Usually applicable to scenarios where port changes are regular. // Most of the steps are the same as mode 0, but EasyNAT is fixed as the receiver and will send detect messages // with low TTL to multiple guessed ports of the sender. // // mode 2: For HardNAT & EasyNAT, ports changes are not regular. // a. HardNAT machine will listen on multiple ports and send detect messages with low TTL to EasyNAT machine // b. EasyNAT machine will send detect messages to random ports of HardNAT machine. // // mode 3: For HardNAT & HardNAT, both changes in the ports are regular. // Most of the steps are the same as mode 1, but the sender also needs to send detect messages to multiple guessed // ports of the receiver. // // mode 4: For HardNAT & HardNAT, one of the changes in the ports is regular. // Regular port changes are usually on the sender side. // a. Receiver listens on multiple ports and sends detect messages with low TTL to the sender's guessed range ports. // b. Sender sends detect messages to random ports of the receiver. SupportedModes = []int{DetectMode0, DetectMode1, DetectMode2, DetectMode3, DetectMode4} SupportedRoles = []string{DetectRoleSender, DetectRoleReceiver} DetectMode0 = 0 DetectMode1 = 1 DetectMode2 = 2 DetectMode3 = 3 DetectMode4 = 4 DetectRoleSender = "sender" DetectRoleReceiver = "receiver" ) type PrepareResult struct { Addrs []string AssistedAddrs []string ListenConn *net.UDPConn NatType string Behavior string } // PreCheck is used to check if the proxy is ready for penetration. // Call this function before calling Prepare to avoid unnecessary preparation work. func PreCheck( ctx context.Context, transporter transport.MessageTransporter, proxyName string, timeout time.Duration, ) error { timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var natHoleRespMsg *msg.NatHoleResp transactionID := NewTransactionID() m, err := transporter.Do(timeoutCtx, &msg.NatHoleVisitor{ TransactionID: transactionID, ProxyName: proxyName, PreCheck: true, }, transactionID, msg.TypeNameNatHoleResp) if err != nil { return fmt.Errorf("get natHoleRespMsg error: %v", err) } mm, ok := m.(*msg.NatHoleResp) if !ok { return fmt.Errorf("get natHoleRespMsg error: invalid message type") } natHoleRespMsg = mm if natHoleRespMsg.Error != "" { return fmt.Errorf("%s", natHoleRespMsg.Error) } return nil } // Prepare is used to do some preparation work before penetration. func Prepare(stunServers []string) (*PrepareResult, error) { // discover for Nat type addrs, localAddr, err := Discover(stunServers, "") if err != nil { return nil, fmt.Errorf("discover error: %v", err) } if len(addrs) < 2 { return nil, fmt.Errorf("discover error: not enough addresses") } localIPs, _ := ListLocalIPsForNatHole(10) natFeature, err := ClassifyNATFeature(addrs, localIPs) if err != nil { return nil, fmt.Errorf("classify nat feature error: %v", err) } laddr, err := net.ResolveUDPAddr("udp4", localAddr.String()) if err != nil { return nil, fmt.Errorf("resolve local udp addr error: %v", err) } listenConn, err := net.ListenUDP("udp4", laddr) if err != nil { return nil, fmt.Errorf("listen local udp addr error: %v", err) } assistedAddrs := make([]string, 0, len(localIPs)) for _, ip := range localIPs { assistedAddrs = append(assistedAddrs, net.JoinHostPort(ip, strconv.Itoa(laddr.Port))) } return &PrepareResult{ Addrs: addrs, AssistedAddrs: assistedAddrs, ListenConn: listenConn, NatType: natFeature.NatType, Behavior: natFeature.Behavior, }, nil } // ExchangeInfo is used to exchange information between client and visitor. // 1. Send input message to server by msgTransporter. // 2. Server will gather information from client and visitor and analyze it. Then send back a NatHoleResp message to them to tell them how to do next. // 3. Receive NatHoleResp message from server. func ExchangeInfo( ctx context.Context, transporter transport.MessageTransporter, laneKey string, m msg.Message, timeout time.Duration, ) (*msg.NatHoleResp, error) { timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var natHoleRespMsg *msg.NatHoleResp m, err := transporter.Do(timeoutCtx, m, laneKey, msg.TypeNameNatHoleResp) if err != nil { return nil, fmt.Errorf("get natHoleRespMsg error: %v", err) } mm, ok := m.(*msg.NatHoleResp) if !ok { return nil, fmt.Errorf("get natHoleRespMsg error: invalid message type") } natHoleRespMsg = mm if natHoleRespMsg.Error != "" { return nil, fmt.Errorf("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) } if len(natHoleRespMsg.CandidateAddrs) == 0 { return nil, fmt.Errorf("natHoleRespMsg get empty candidate addresses") } return natHoleRespMsg, nil } // MakeHole is used to make a NAT hole between client and visitor. func MakeHole(ctx context.Context, listenConn *net.UDPConn, m *msg.NatHoleResp, key []byte) (*net.UDPConn, *net.UDPAddr, error) { xl := xlog.FromContextSafe(ctx) transactionID := NewTransactionID() sendToRangePortsFunc := func(conn *net.UDPConn, addr string) error { return sendSidMessage(ctx, conn, m.Sid, transactionID, addr, key, m.DetectBehavior.TTL) } listenConns := []*net.UDPConn{listenConn} var detectAddrs []string if m.DetectBehavior.Role == DetectRoleSender { // sender if m.DetectBehavior.SendDelayMs > 0 { time.Sleep(time.Duration(m.DetectBehavior.SendDelayMs) * time.Millisecond) } detectAddrs = m.AssistedAddrs detectAddrs = append(detectAddrs, m.CandidateAddrs...) } else { // receiver if len(m.DetectBehavior.CandidatePorts) == 0 { detectAddrs = m.CandidateAddrs } if m.DetectBehavior.ListenRandomPorts > 0 { for i := 0; i < m.DetectBehavior.ListenRandomPorts; i++ { tmpConn, err := net.ListenUDP("udp4", nil) if err != nil { xl.Warnf("listen random udp addr error: %v", err) continue } listenConns = append(listenConns, tmpConn) } } } detectAddrs = slices.Compact(detectAddrs) for _, detectAddr := range detectAddrs { for _, conn := range listenConns { if err := sendSidMessage(ctx, conn, m.Sid, transactionID, detectAddr, key, m.DetectBehavior.TTL); err != nil { xl.Tracef("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err) } } } if len(m.DetectBehavior.CandidatePorts) > 0 { for _, conn := range listenConns { sendSidMessageToRangePorts(ctx, conn, m.CandidateAddrs, m.DetectBehavior.CandidatePorts, sendToRangePortsFunc) } } if m.DetectBehavior.SendRandomPorts > 0 { ctx, cancel := context.WithCancel(ctx) defer cancel() for i := range listenConns { go sendSidMessageToRandomPorts(ctx, listenConns[i], m.CandidateAddrs, m.DetectBehavior.SendRandomPorts, sendToRangePortsFunc) } } timeout := 5 * time.Second if m.DetectBehavior.ReadTimeoutMs > 0 { timeout = time.Duration(m.DetectBehavior.ReadTimeoutMs) * time.Millisecond } if len(listenConns) == 1 { raddr, err := waitDetectMessage(ctx, listenConns[0], m.Sid, key, timeout, m.DetectBehavior.Role) if err != nil { return nil, nil, fmt.Errorf("wait detect message error: %v", err) } return listenConns[0], raddr, nil } type result struct { lConn *net.UDPConn raddr *net.UDPAddr } resultCh := make(chan result) for _, conn := range listenConns { go func(lConn *net.UDPConn) { addr, err := waitDetectMessage(ctx, lConn, m.Sid, key, timeout, m.DetectBehavior.Role) if err != nil { lConn.Close() return } select { case resultCh <- result{lConn: lConn, raddr: addr}: default: lConn.Close() } }(conn) } select { case result := <-resultCh: return result.lConn, result.raddr, nil case <-time.After(timeout): return nil, nil, fmt.Errorf("wait detect message timeout") case <-ctx.Done(): return nil, nil, fmt.Errorf("wait detect message canceled") } } func waitDetectMessage( ctx context.Context, conn *net.UDPConn, sid string, key []byte, timeout time.Duration, role string, ) (*net.UDPAddr, error) { xl := xlog.FromContextSafe(ctx) for { buf := pool.GetBuf(1024) _ = conn.SetReadDeadline(time.Now().Add(timeout)) n, raddr, err := conn.ReadFromUDP(buf) _ = conn.SetReadDeadline(time.Time{}) if err != nil { return nil, err } xl.Debugf("get udp message local %s, from %s", conn.LocalAddr(), raddr) var m msg.NatHoleSid if err := DecodeMessageInto(buf[:n], key, &m); err != nil { xl.Warnf("decode sid message error: %v", err) continue } pool.PutBuf(buf) if m.Sid != sid { xl.Warnf("get sid message with wrong sid: %s, expect: %s", m.Sid, sid) continue } if !m.Response { // only wait for response messages if we are a sender if role == DetectRoleSender { continue } m.Response = true buf2, err := EncodeMessage(&m, key) if err != nil { xl.Warnf("encode sid message error: %v", err) continue } _, _ = conn.WriteToUDP(buf2, raddr) } return raddr, nil } } func sendSidMessage( ctx context.Context, conn *net.UDPConn, sid string, transactionID string, addr string, key []byte, ttl int, ) error { xl := xlog.FromContextSafe(ctx) ttlStr := "" if ttl > 0 { ttlStr = fmt.Sprintf(" with ttl %d", ttl) } xl.Tracef("send sid message from %s to %s%s", conn.LocalAddr(), addr, ttlStr) raddr, err := net.ResolveUDPAddr("udp4", addr) if err != nil { return err } if transactionID == "" { transactionID = NewTransactionID() } m := &msg.NatHoleSid{ TransactionID: transactionID, Sid: sid, Response: false, Nonce: strings.Repeat("0", rand.IntN(20)), } buf, err := EncodeMessage(m, key) if err != nil { return err } if ttl > 0 { uConn := ipv4.NewConn(conn) original, err := uConn.TTL() if err != nil { xl.Tracef("get ttl error %v", err) return err } xl.Tracef("original ttl %d", original) err = uConn.SetTTL(ttl) if err != nil { xl.Tracef("set ttl error %v", err) } else { defer func() { _ = uConn.SetTTL(original) }() } } if _, err := conn.WriteToUDP(buf, raddr); err != nil { return err } return nil } func sendSidMessageToRangePorts( ctx context.Context, conn *net.UDPConn, addrs []string, ports []msg.PortsRange, sendFunc func(*net.UDPConn, string) error, ) { xl := xlog.FromContextSafe(ctx) for _, ip := range slices.Compact(parseIPs(addrs)) { for _, portsRange := range ports { for i := portsRange.From; i <= portsRange.To; i++ { detectAddr := net.JoinHostPort(ip, strconv.Itoa(i)) if err := sendFunc(conn, detectAddr); err != nil { xl.Tracef("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err) } time.Sleep(2 * time.Millisecond) } } } } func sendSidMessageToRandomPorts( ctx context.Context, conn *net.UDPConn, addrs []string, count int, sendFunc func(*net.UDPConn, string) error, ) { xl := xlog.FromContextSafe(ctx) used := sets.New[int]() getUnusedPort := func() int { for i := 0; i < 10; i++ { port := rand.IntN(65535-1024) + 1024 if !used.Has(port) { used.Insert(port) return port } } return 0 } for i := 0; i < count; i++ { select { case <-ctx.Done(): return default: } port := getUnusedPort() if port == 0 { continue } for _, ip := range slices.Compact(parseIPs(addrs)) { detectAddr := net.JoinHostPort(ip, strconv.Itoa(port)) if err := sendFunc(conn, detectAddr); err != nil { xl.Tracef("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err) } time.Sleep(time.Millisecond * 15) } } } func parseIPs(addrs []string) []string { var ips []string for _, addr := range addrs { if ip, _, err := net.SplitHostPort(addr); err == nil { ips = append(ips, ip) } } return ips }