// Copyright 2017 fatedier, fatedier@gmail.com // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package net import ( "fmt" "io" "net" "sync" "time" "github.com/fatedier/golib/pool" ) type UDPPacket struct { Buf []byte LocalAddr net.Addr RemoteAddr net.Addr } type FakeUDPConn struct { l *UDPListener localAddr net.Addr remoteAddr net.Addr packets chan []byte closeFlag bool lastActive time.Time mu sync.RWMutex } func NewFakeUDPConn(l *UDPListener, laddr, raddr net.Addr) *FakeUDPConn { fc := &FakeUDPConn{ l: l, localAddr: laddr, remoteAddr: raddr, packets: make(chan []byte, 20), } go func() { for { time.Sleep(5 * time.Second) fc.mu.RLock() if time.Now().Sub(fc.lastActive) > 10*time.Second { fc.mu.RUnlock() fc.Close() break } fc.mu.RUnlock() } }() return fc } func (c *FakeUDPConn) putPacket(content []byte) { defer func() { if err := recover(); err != nil { } }() select { case c.packets <- content: default: } } func (c *FakeUDPConn) Read(b []byte) (n int, err error) { content, ok := <-c.packets if !ok { return 0, io.EOF } c.mu.Lock() c.lastActive = time.Now() c.mu.Unlock() if len(b) < len(content) { n = len(b) } else { n = len(content) } copy(b, content) return n, nil } func (c *FakeUDPConn) Write(b []byte) (n int, err error) { c.mu.RLock() if c.closeFlag { c.mu.RUnlock() return 0, io.ErrClosedPipe } c.mu.RUnlock() packet := &UDPPacket{ Buf: b, LocalAddr: c.localAddr, RemoteAddr: c.remoteAddr, } c.l.writeUDPPacket(packet) c.mu.Lock() c.lastActive = time.Now() c.mu.Unlock() return len(b), nil } func (c *FakeUDPConn) Close() error { c.mu.Lock() defer c.mu.Unlock() if !c.closeFlag { c.closeFlag = true close(c.packets) } return nil } func (c *FakeUDPConn) IsClosed() bool { c.mu.RLock() defer c.mu.RUnlock() return c.closeFlag } func (c *FakeUDPConn) LocalAddr() net.Addr { return c.localAddr } func (c *FakeUDPConn) RemoteAddr() net.Addr { return c.remoteAddr } func (c *FakeUDPConn) SetDeadline(t time.Time) error { return nil } func (c *FakeUDPConn) SetReadDeadline(t time.Time) error { return nil } func (c *FakeUDPConn) SetWriteDeadline(t time.Time) error { return nil } type UDPListener struct { addr net.Addr acceptCh chan net.Conn writeCh chan *UDPPacket readConn net.Conn closeFlag bool fakeConns map[string]*FakeUDPConn } func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) { udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) if err != nil { return l, err } readConn, err := net.ListenUDP("udp", udpAddr) l = &UDPListener{ addr: udpAddr, acceptCh: make(chan net.Conn), writeCh: make(chan *UDPPacket, 1000), fakeConns: make(map[string]*FakeUDPConn), } // for reading go func() { for { buf := pool.GetBuf(1450) n, remoteAddr, err := readConn.ReadFromUDP(buf) if err != nil { close(l.acceptCh) close(l.writeCh) return } fakeConn, exist := l.fakeConns[remoteAddr.String()] if !exist || fakeConn.IsClosed() { fakeConn = NewFakeUDPConn(l, l.Addr(), remoteAddr) l.fakeConns[remoteAddr.String()] = fakeConn } fakeConn.putPacket(buf[:n]) l.acceptCh <- fakeConn } }() // for writing go func() { for { packet, ok := <-l.writeCh if !ok { return } if addr, ok := packet.RemoteAddr.(*net.UDPAddr); ok { readConn.WriteToUDP(packet.Buf, addr) } } }() return } func (l *UDPListener) writeUDPPacket(packet *UDPPacket) (err error) { defer func() { if errRet := recover(); errRet != nil { err = fmt.Errorf("udp write closed listener") } }() l.writeCh <- packet return } func (l *UDPListener) WriteMsg(buf []byte, remoteAddr *net.UDPAddr) (err error) { // only set remote addr here packet := &UDPPacket{ Buf: buf, RemoteAddr: remoteAddr, } err = l.writeUDPPacket(packet) return } func (l *UDPListener) Accept() (net.Conn, error) { conn, ok := <-l.acceptCh if !ok { return conn, fmt.Errorf("channel for udp listener closed") } return conn, nil } func (l *UDPListener) Close() error { if !l.closeFlag { l.closeFlag = true l.readConn.Close() } return nil } func (l *UDPListener) Addr() net.Addr { return l.addr }