diff --git a/src/frp/models/msg/process.go b/src/frp/models/msg/process.go index 4c7783b..43e34ab 100644 --- a/src/frp/models/msg/process.go +++ b/src/frp/models/msg/process.go @@ -15,12 +15,10 @@ package msg import ( - "bufio" "bytes" "encoding/binary" "fmt" "io" - "net" "sync" "frp/models/config" @@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo defer wait.Done() // we don't care about errors here - pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord) + pipeEncrypt(from, to, conf, needRecord) } decryptPipe := func(to *conn.Conn, from *conn.Conn) { @@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo defer wait.Done() // we don't care about errors here - pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord) + pipeDecrypt(to, from, conf, needRecord) } wait.Add(2) @@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) { } // decrypt msg from reader, then write into writer -func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) buf := make([]byte, 5*1024+4) var left, res []byte - var cnt int + var cnt int = -1 // record var flowBytes int64 = 0 @@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) }() } - nreader := bufio.NewReader(r) for { // there may be more than 1 package in variable // and we read more bytes if unpkgMsg returns an error var newBuf []byte if cnt < 0 { - n, err := nreader.Read(buf) + n, err := r.Read(buf) if err != nil { return err } @@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } } - _, err = w.Write(res) + _, err = w.WriteBytes(res) if err != nil { return err } @@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } // recvive msg from reader, then encrypt msg into writer -func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) }() } - nreader := bufio.NewReader(r) buf := make([]byte, 5*1024) for { - n, err := nreader.Read(buf) + n, err := r.Read(buf) if err != nil { return err } @@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } res = pkgMsg(res) - _, err = w.Write(res) + _, err = w.WriteBytes(res) if err != nil { return err } diff --git a/src/frp/models/server/server.go b/src/frp/models/server/server.go index 139c989..e69a279 100644 --- a/src/frp/models/server/server.go +++ b/src/frp/models/server/server.go @@ -154,13 +154,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { } // start another goroutine for join two conns from frpc and user - go func() { + go func(userConn *conn.Conn) { workConn, err := p.getWorkConn() if err != nil { return } - userConn := c // msg will transfer to another without modifying // l means local, r means remote log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), @@ -169,7 +168,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { needRecord := true go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) metric.OpenConnection(p.Name) - }() + }(c) } }(listener) } diff --git a/src/frp/utils/conn/conn.go b/src/frp/utils/conn/conn.go index ed330f6..a398110 100644 --- a/src/frp/utils/conn/conn.go +++ b/src/frp/utils/conn/conn.go @@ -125,6 +125,11 @@ func (c *Conn) GetLocalAddr() (addr string) { return c.TcpConn.LocalAddr().String() } +func (c *Conn) Read(p []byte) (n int, err error) { + n, err = c.Reader.Read(p) + return +} + func (c *Conn) ReadLine() (buff string, err error) { buff, err = c.Reader.ReadString('\n') if err != nil { @@ -138,10 +143,14 @@ func (c *Conn) ReadLine() (buff string, err error) { return buff, err } +func (c *Conn) WriteBytes(content []byte) (n int, err error) { + n, err = c.TcpConn.Write(content) + return +} + func (c *Conn) Write(content string) (err error) { _, err = c.TcpConn.Write([]byte(content)) return err - } func (c *Conn) SetDeadline(t time.Time) error {