frp/pkg/utils/conn/conn.go
2016-02-05 14:18:26 +08:00

118 lines
2.1 KiB
Go

package conn
import (
"bufio"
"fmt"
"io"
"net"
"sync"
"frp/pkg/utils/log"
)
// Listener describes a bound TCP listener: the address it is listening on
// and the channel on which accepted connections are handed out.
type Listener struct {
// Addr is the address reported by the underlying TCP listener.
Addr net.Addr
// Conns carries accepted connections; Listen sends on it and GetConn
// receives from it.
Conns chan *Conn
}
// GetConn blocks until a connection is available on l.Conns and returns it.
func (l *Listener) GetConn() (conn *Conn) {
	return <-l.Conns
}
// Conn pairs a TCP connection with a buffered reader over that same
// connection.
type Conn struct {
// TcpConn is the underlying TCP connection; writes, address lookups and
// Close go directly through it.
TcpConn *net.TCPConn
// Reader is a buffered reader wrapping TcpConn (set by ConnectServer and
// Listen); ReadLine reads through it.
Reader *bufio.Reader
}
// ConnectServer resolves host:port as an IPv4 TCP address, dials it, and on
// success stores the new connection together with a buffered reader over it
// in c. If resolving or dialing fails, c is left untouched and the error is
// returned.
func (c *Conn) ConnectServer(host string, port int64) (err error) {
	target := fmt.Sprintf("%s:%d", host, port)

	var remote *net.TCPAddr
	if remote, err = net.ResolveTCPAddr("tcp4", target); err != nil {
		return err
	}

	var tcp *net.TCPConn
	if tcp, err = net.DialTCP("tcp", nil, remote); err != nil {
		return err
	}

	c.TcpConn, c.Reader = tcp, bufio.NewReader(tcp)
	return nil
}
// GetRemoteAddr returns the remote address of the underlying TCP connection
// in string form.
func (c *Conn) GetRemoteAddr() (addr string) {
	addr = c.TcpConn.RemoteAddr().String()
	return addr
}
// GetLocalAddr returns the local address of the underlying TCP connection
// in string form.
func (c *Conn) GetLocalAddr() (addr string) {
	addr = c.TcpConn.LocalAddr().String()
	return addr
}
// ReadLine reads from the buffered reader up to and including the next '\n'
// and returns what was read together with any error from the reader.
func (c *Conn) ReadLine() (buff string, err error) {
	return c.Reader.ReadString('\n')
}
// Write sends content over the underlying TCP connection and returns the
// write error, if any. The number of bytes written is not reported.
func (c *Conn) Write(content string) (err error) {
	if _, err = c.TcpConn.Write([]byte(content)); err != nil {
		return err
	}
	return nil
}
// Close closes the underlying TCP connection if one is set. The error
// returned by the connection's Close is discarded.
func (c *Conn) Close() {
	// Nil guard: nothing to close when no TCP connection has been set.
	if c.TcpConn == nil {
		return
	}
	c.TcpConn.Close()
}
// Listen resolves bindAddr:bindPort as an IPv4 TCP address, starts listening
// on it, and returns a Listener whose Conns channel delivers every accepted
// connection wrapped in a *Conn.
//
// It returns a nil Listener and a non-nil error if the address cannot be
// resolved or the listener cannot be created.
//
// Accepting runs in a background goroutine that is never stopped by this
// function. Conns is unbuffered, so the goroutine blocks on each accepted
// connection until a receiver (for example GetConn) takes it.
func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort))
	if err != nil {
		// Must fail here: a nil address passed to ListenTCP would not error
		// but silently bind an automatically chosen port on all interfaces.
		return nil, err
	}

	listener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return nil, err
	}

	l = &Listener{
		Addr:  listener.Addr(),
		Conns: make(chan *Conn),
	}

	go func() {
		for {
			conn, err := listener.AcceptTCP()
			if err != nil {
				// Accept failures are logged and the loop keeps accepting.
				log.Error("Accept new tcp connection error, %v", err)
				continue
			}
			l.Conns <- &Conn{
				TcpConn: conn,
				Reader:  bufio.NewReader(conn),
			}
		}
	}()

	return l, nil
}
// Join copies data between c1 and c2 in both directions, one goroutine per
// direction, and blocks until both copies have finished. When a copy ends,
// that goroutine closes both connections; a copy error is logged as a
// warning.
func Join(c1 *Conn, c2 *Conn) {
	var wg sync.WaitGroup
	wg.Add(2)

	relay := func(dst *Conn, src *Conn) {
		// Deferred calls run in reverse order: wg.Done first, then
		// src.Close, then dst.Close.
		defer dst.Close()
		defer src.Close()
		defer wg.Done()

		if _, err := io.Copy(dst.TcpConn, src.TcpConn); err != nil {
			log.Warn("join conns error, %v", err)
		}
	}

	go relay(c1, c2)
	go relay(c2, c1)

	wg.Wait()
}