From 3faae194d0a637187539e4da1e6a993b3494b2ce Mon Sep 17 00:00:00 2001 From: fatedier Date: Thu, 30 Mar 2023 21:49:12 +0800 Subject: [PATCH] feat(nathole): use serverUDPPort in nathole discovery when available (#3382) --- cmd/frpc/sub/nathole.go | 13 ++- pkg/nathole/discovery.go | 222 ++++++++++++++++++++++----------------- pkg/nathole/utils.go | 17 +++ 3 files changed, 153 insertions(+), 99 deletions(-) diff --git a/cmd/frpc/sub/nathole.go b/cmd/frpc/sub/nathole.go index db56d76..9e8a21d 100644 --- a/cmd/frpc/sub/nathole.go +++ b/cmd/frpc/sub/nathole.go @@ -53,8 +53,12 @@ var natholeDiscoveryCmd = &cobra.Command{ os.Exit(1) } + serverAddr := "" + if cfg.ServerUDPPort != 0 { + serverAddr = net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort)) + } addresses, err := nathole.Discover( - net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort)), + serverAddr, []string{cfg.NatHoleSTUNServer}, []byte(cfg.Token), ) @@ -62,6 +66,10 @@ var natholeDiscoveryCmd = &cobra.Command{ fmt.Println("discover error:", err) os.Exit(1) } + if len(addresses) < 2 { + fmt.Printf("discover error: can not get enough addresses, need 2, got: %v\n", addresses) + os.Exit(1) + } natType, behavior, err := nathole.ClassifyNATType(addresses) if err != nil { @@ -79,8 +87,5 @@ func validateForNatHoleDiscovery(cfg config.ClientCommonConf) error { if cfg.NatHoleSTUNServer == "" { return fmt.Errorf("nat_hole_stun_server can not be empty") } - if cfg.ServerUDPPort == 0 { - return fmt.Errorf("server udp port can not be empty") - } return nil } diff --git a/pkg/nathole/discovery.go b/pkg/nathole/discovery.go index 761ed10..4c684b3 100644 --- a/pkg/nathole/discovery.go +++ b/pkg/nathole/discovery.go @@ -26,31 +26,12 @@ import ( var responseTimeout = 3 * time.Second -type Address struct { - IP string - Port int -} - type Message struct { Body []byte Addr string } func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) { - // parse address to net.Address - stunAddresses := make([]net.Addr, 0, len(stunServers)) - for _, stunServer := range stunServers { - addr, err := net.ResolveUDPAddr("udp4", stunServer) - if err != nil { - return nil, err - } - stunAddresses = append(stunAddresses, addr) - } - serverAddr, err := net.ResolveUDPAddr("udp4", serverAddress) - if err != nil { - return nil, err - } - // create a discoverConn and get response from messageChan discoverConn, err := listen() if err != nil { @@ -61,90 +42,29 @@ func Discover(serverAddress string, stunServers []string, key []byte) ([]string, go discoverConn.readLoop() addresses := make([]string, 0, len(stunServers)+1) - // get external address from frp server - externalAddr, err := discoverFromServer(discoverConn, serverAddr, key) - if err != nil { - return nil, err - } - addresses = append(addresses, externalAddr) - - for _, stunAddr := range stunAddresses { - // get external address from stun server - externalAddr, err = discoverFromStunServer(discoverConn, stunAddr) + if serverAddress != "" { + // get external address from frp server + externalAddr, err := discoverConn.discoverFromServer(serverAddress, key) if err != nil { return nil, err } addresses = append(addresses, externalAddr) } + + for _, addr := range stunServers { + // get external address from stun server + externalAddrs, err := discoverConn.discoverFromStunServer(addr) + if err != nil { + return nil, err + } + addresses = append(addresses, externalAddrs...) + } return addresses, nil } -func discoverFromServer(c *discoverConn, addr net.Addr, key []byte) (string, error) { - m := &msg.NatHoleBinding{ - TransactionID: NewTransactionID(), - } - - buf, err := EncodeMessage(m, key) - if err != nil { - return "", err - } - - if _, err := c.conn.WriteTo(buf, addr); err != nil { - return "", err - } - - var respMsg msg.NatHoleBindingResp - select { - case rawMsg := <-c.messageChan: - if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil { - return "", err - } - case <-time.After(responseTimeout): - return "", fmt.Errorf("wait response from frp server timeout") - } - - if respMsg.TransactionID == "" { - return "", fmt.Errorf("error format: no transaction id found") - } - if respMsg.Error != "" { - return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error) - } - return respMsg.Address, nil -} - -func discoverFromStunServer(c *discoverConn, addr net.Addr) (string, error) { - request, err := stun.Build(stun.TransactionID, stun.BindingRequest) - if err != nil { - return "", err - } - - if err = request.NewTransactionID(); err != nil { - return "", err - } - if _, err := c.conn.WriteTo(request.Raw, addr); err != nil { - return "", err - } - - var m stun.Message - select { - case msg := <-c.messageChan: - m.Raw = msg.Body - if err := m.Decode(); err != nil { - return "", err - } - case <-time.After(responseTimeout): - return "", fmt.Errorf("wait response from stun server timeout") - } - - xorAddr := &stun.XORMappedAddress{} - mappedAddr := &stun.MappedAddress{} - if err := xorAddr.GetFrom(&m); err == nil { - return xorAddr.String(), nil - } - if err := mappedAddr.GetFrom(&m); err == nil { - return mappedAddr.String(), nil - } - return "", fmt.Errorf("no address found") +type stunResponse struct { + externalAddr string + otherAddr string } type discoverConn struct { @@ -190,3 +110,115 @@ func (c *discoverConn) readLoop() { } } } + +func (c *discoverConn) doSTUNRequest(addr string) (*stunResponse, error) { + serverAddr, err := net.ResolveUDPAddr("udp4", addr) + if err != nil { + return nil, err + } + request, err := stun.Build(stun.TransactionID, stun.BindingRequest) + if err != nil { + return nil, err + } + + if err = request.NewTransactionID(); err != nil { + return nil, err + } + if _, err := c.conn.WriteTo(request.Raw, serverAddr); err != nil { + return nil, err + } + + var m stun.Message + select { + case msg := <-c.messageChan: + m.Raw = msg.Body + if err := m.Decode(); err != nil { + return nil, err + } + case <-time.After(responseTimeout): + return nil, fmt.Errorf("wait response from stun server timeout") + } + xorAddrGetter := &stun.XORMappedAddress{} + mappedAddrGetter := &stun.MappedAddress{} + changedAddrGetter := ChangedAddress{} + otherAddrGetter := &stun.OtherAddress{} + + resp := &stunResponse{} + if err := mappedAddrGetter.GetFrom(&m); err == nil { + resp.externalAddr = mappedAddrGetter.String() + } + if err := xorAddrGetter.GetFrom(&m); err == nil { + resp.externalAddr = xorAddrGetter.String() + } + if err := changedAddrGetter.GetFrom(&m); err == nil { + resp.otherAddr = changedAddrGetter.String() + } + if err := otherAddrGetter.GetFrom(&m); err == nil { + resp.otherAddr = otherAddrGetter.String() + } + return resp, nil +} + +func (c *discoverConn) discoverFromServer(serverAddress string, key []byte) (string, error) { + addr, err := net.ResolveUDPAddr("udp4", serverAddress) + if err != nil { + return "", err + } + m := &msg.NatHoleBinding{ + TransactionID: NewTransactionID(), + } + + buf, err := EncodeMessage(m, key) + if err != nil { + return "", err + } + + if _, err := c.conn.WriteTo(buf, addr); err != nil { + return "", err + } + + var respMsg msg.NatHoleBindingResp + select { + case rawMsg := <-c.messageChan: + if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil { + return "", err + } + case <-time.After(responseTimeout): + return "", fmt.Errorf("wait response from frp server timeout") + } + + if respMsg.TransactionID == "" { + return "", fmt.Errorf("error format: no transaction id found") + } + if respMsg.Error != "" { + return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error) + } + return respMsg.Address, nil +} + +func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) { + resp, err := c.doSTUNRequest(addr) + if err != nil { + return nil, err + } + if resp.externalAddr == "" { + return nil, fmt.Errorf("no external address found") + } + + externalAddrs := make([]string, 0, 2) + externalAddrs = append(externalAddrs, resp.externalAddr) + + if resp.otherAddr == "" { + return externalAddrs, nil + } + + // find external address from changed address + resp, err = c.doSTUNRequest(resp.otherAddr) + if err != nil { + return nil, err + } + if resp.externalAddr != "" { + externalAddrs = append(externalAddrs, resp.externalAddr) + } + return externalAddrs, nil +} diff --git a/pkg/nathole/utils.go b/pkg/nathole/utils.go index 40fad0e..75eda1a 100644 --- a/pkg/nathole/utils.go +++ b/pkg/nathole/utils.go @@ -16,8 +16,11 @@ package nathole import ( "bytes" + "net" + "strconv" "github.com/fatedier/golib/crypto" + "github.com/pion/stun" "github.com/fatedier/frp/pkg/msg" ) @@ -46,3 +49,17 @@ func DecodeMessageInto(data, key []byte, m msg.Message) error { } return nil } + +type ChangedAddress struct { + IP net.IP + Port int +} + +func (s *ChangedAddress) GetFrom(m *stun.Message) error { + a := (*stun.MappedAddress)(s) + return a.GetFromAs(m, stun.AttrChangedAddress) +} + +func (s *ChangedAddress) String() string { + return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port)) +}