2016-01-27 21:24:36 +08:00
|
|
|
package conn
|
|
|
|
|
|
|
|
import (
|
2016-02-03 18:46:24 +08:00
|
|
|
"bufio"
|
2016-01-27 21:24:36 +08:00
|
|
|
"fmt"
|
2016-02-03 18:46:24 +08:00
|
|
|
"io"
|
2016-01-27 21:24:36 +08:00
|
|
|
"net"
|
|
|
|
"sync"
|
|
|
|
|
2016-02-18 16:56:55 +08:00
|
|
|
"github.com/fatedier/frp/utils/log"
|
2016-01-27 21:24:36 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
type Listener struct {
|
2016-02-19 17:01:47 +08:00
|
|
|
addr net.Addr
|
|
|
|
l *net.TCPListener
|
|
|
|
conns chan *Conn
|
|
|
|
closeFlag bool
|
2016-01-27 21:24:36 +08:00
|
|
|
}
|
|
|
|
|
2016-02-19 17:01:47 +08:00
|
|
|
func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
|
|
|
|
tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort))
|
|
|
|
listener, err := net.ListenTCP("tcp", tcpAddr)
|
|
|
|
if err != nil {
|
|
|
|
return l, err
|
|
|
|
}
|
|
|
|
|
|
|
|
l = &Listener{
|
|
|
|
addr: listener.Addr(),
|
|
|
|
l: listener,
|
|
|
|
conns: make(chan *Conn),
|
|
|
|
closeFlag: false,
|
|
|
|
}
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
for {
|
|
|
|
conn, err := l.l.AcceptTCP()
|
|
|
|
if err != nil {
|
|
|
|
if l.closeFlag {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
c := &Conn{
|
|
|
|
TcpConn: conn,
|
|
|
|
closeFlag: false,
|
|
|
|
}
|
|
|
|
c.Reader = bufio.NewReader(c.TcpConn)
|
|
|
|
l.conns <- c
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
return l, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// wait util get one new connection or close
|
|
|
|
// if listener is closed, return nil
|
2016-01-27 21:24:36 +08:00
|
|
|
func (l *Listener) GetConn() (conn *Conn) {
|
2016-02-19 17:01:47 +08:00
|
|
|
var ok bool
|
|
|
|
conn, ok = <-l.conns
|
|
|
|
if !ok {
|
|
|
|
return nil
|
|
|
|
}
|
2016-01-27 21:24:36 +08:00
|
|
|
return conn
|
|
|
|
}
|
|
|
|
|
2016-02-19 17:01:47 +08:00
|
|
|
func (l *Listener) Close() {
|
|
|
|
if l.l != nil && l.closeFlag == false {
|
|
|
|
l.closeFlag = true
|
|
|
|
l.l.Close()
|
|
|
|
close(l.conns)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Conn wraps a *net.TCPConn together with a buffered reader over it and a
// flag recording whether the connection is known to be closed.
type Conn struct {
	// TcpConn is the underlying TCP connection.
	TcpConn *net.TCPConn
	// Reader buffers reads from TcpConn; ReadLine reads through it.
	Reader *bufio.Reader
	// closeFlag is set by Close and when ReadLine reaches io.EOF.
	closeFlag bool
}
|
|
|
|
|
2016-02-19 17:01:47 +08:00
|
|
|
func ConnectServer(host string, port int64) (c *Conn, err error) {
|
|
|
|
c = &Conn{}
|
2016-01-27 21:24:36 +08:00
|
|
|
servertAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", host, port))
|
|
|
|
if err != nil {
|
2016-02-19 17:01:47 +08:00
|
|
|
return
|
2016-01-27 21:24:36 +08:00
|
|
|
}
|
|
|
|
conn, err := net.DialTCP("tcp", nil, servertAddr)
|
|
|
|
if err != nil {
|
2016-02-19 17:01:47 +08:00
|
|
|
return
|
2016-01-27 21:24:36 +08:00
|
|
|
}
|
|
|
|
c.TcpConn = conn
|
|
|
|
c.Reader = bufio.NewReader(c.TcpConn)
|
2016-02-19 17:01:47 +08:00
|
|
|
c.closeFlag = false
|
|
|
|
return c, nil
|
2016-01-27 21:24:36 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) GetRemoteAddr() (addr string) {
|
|
|
|
return c.TcpConn.RemoteAddr().String()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) GetLocalAddr() (addr string) {
|
|
|
|
return c.TcpConn.LocalAddr().String()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) ReadLine() (buff string, err error) {
|
|
|
|
buff, err = c.Reader.ReadString('\n')
|
2016-02-19 17:01:47 +08:00
|
|
|
if err == io.EOF {
|
|
|
|
c.closeFlag = true
|
|
|
|
}
|
2016-01-27 21:24:36 +08:00
|
|
|
return buff, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) Write(content string) (err error) {
|
|
|
|
_, err = c.TcpConn.Write([]byte(content))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) Close() {
|
2016-02-18 11:42:31 +08:00
|
|
|
if c.TcpConn != nil {
|
2016-02-19 17:01:47 +08:00
|
|
|
c.closeFlag = true
|
2016-02-05 14:18:26 +08:00
|
|
|
c.TcpConn.Close()
|
|
|
|
}
|
2016-01-27 21:24:36 +08:00
|
|
|
}
|
|
|
|
|
2016-02-19 17:01:47 +08:00
|
|
|
func (c *Conn) IsClosed() bool {
|
|
|
|
return c.closeFlag
|
2016-01-27 21:24:36 +08:00
|
|
|
}
|
|
|
|
|
2016-02-19 17:01:47 +08:00
|
|
|
// will block until connection close
|
2016-01-27 21:24:36 +08:00
|
|
|
func Join(c1 *Conn, c2 *Conn) {
|
|
|
|
var wait sync.WaitGroup
|
|
|
|
pipe := func(to *Conn, from *Conn) {
|
|
|
|
defer to.Close()
|
|
|
|
defer from.Close()
|
|
|
|
defer wait.Done()
|
|
|
|
|
|
|
|
var err error
|
|
|
|
_, err = io.Copy(to.TcpConn, from.TcpConn)
|
|
|
|
if err != nil {
|
|
|
|
log.Warn("join conns error, %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
wait.Add(2)
|
|
|
|
go pipe(c1, c2)
|
|
|
|
go pipe(c2, c1)
|
|
|
|
wait.Wait()
|
|
|
|
return
|
|
|
|
}
|