// Copyright 2017 fatedier, fatedier@gmail.com // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package udp import ( "encoding/base64" "net" "sync" "time" "github.com/fatedier/frp/models/msg" "github.com/fatedier/golib/errors" "github.com/fatedier/golib/pool" ) func NewUdpPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UdpPacket { return &msg.UdpPacket{ Content: base64.StdEncoding.EncodeToString(buf), LocalAddr: laddr, RemoteAddr: raddr, } } func GetContent(m *msg.UdpPacket) (buf []byte, err error) { buf, err = base64.StdEncoding.DecodeString(m.Content) return } func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UdpPacket, sendCh chan<- *msg.UdpPacket) { // read go func() { for udpMsg := range readCh { buf, err := GetContent(udpMsg) if err != nil { continue } udpConn.WriteToUDP(buf, udpMsg.RemoteAddr) } }() // write buf := pool.GetBuf(1500) defer pool.PutBuf(buf) for { n, remoteAddr, err := udpConn.ReadFromUDP(buf) if err != nil { udpConn.Close() return } // buf[:n] will be encoded to string, so the bytes can be reused udpMsg := NewUdpPacket(buf[:n], nil, remoteAddr) select { case sendCh <- udpMsg: default: } } } func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UdpPacket, sendCh chan<- msg.Message) { var ( mu sync.RWMutex ) udpConnMap := make(map[string]*net.UDPConn) // read from dstAddr and write to sendCh writerFn := func(raddr *net.UDPAddr, udpConn *net.UDPConn) { addr := raddr.String() defer func() { mu.Lock() delete(udpConnMap, addr) mu.Unlock() udpConn.Close() }() buf := pool.GetBuf(1500) for { udpConn.SetReadDeadline(time.Now().Add(30 * time.Second)) n, _, err := udpConn.ReadFromUDP(buf) if err != nil { return } udpMsg := NewUdpPacket(buf[:n], nil, raddr) if err = errors.PanicToError(func() { select { case sendCh <- udpMsg: default: } }); err != nil { return } } } // read from readCh go func() { for udpMsg := range readCh { buf, err := GetContent(udpMsg) if err != nil { continue } mu.Lock() udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()] if !ok { udpConn, err = net.DialUDP("udp", nil, dstAddr) if err != nil { mu.Unlock() continue } udpConnMap[udpMsg.RemoteAddr.String()] = udpConn } mu.Unlock() _, err = udpConn.Write(buf) if err != nil { udpConn.Close() } if !ok { go writerFn(udpMsg.RemoteAddr, udpConn) } } }() }