Merge pull request #50 from fatedier/fatedier/fix_package_loss

frp/models/msg: fix a bug if local service write to socket immediatel…
This commit is contained in:
fatedier 2016-07-20 16:37:45 +08:00 committed by GitHub
commit 4067591a4d
4 changed files with 27 additions and 18 deletions

View File

@ -15,12 +15,10 @@
package msg package msg
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"net"
"sync" "sync"
"frp/models/config" "frp/models/config"
@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
defer wait.Done() defer wait.Done()
// we don't care about errors here // we don't care about errors here
pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord) pipeEncrypt(from, to, conf, needRecord)
} }
decryptPipe := func(to *conn.Conn, from *conn.Conn) { decryptPipe := func(to *conn.Conn, from *conn.Conn) {
@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
defer wait.Done() defer wait.Done()
// we don't care about errors here // we don't care about errors here
pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord) pipeDecrypt(to, from, conf, needRecord)
} }
wait.Add(2) wait.Add(2)
@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) {
} }
// decrypt msg from reader, then write into writer // decrypt msg from reader, then write into writer
func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
laes := new(pcrypto.Pcrypto) laes := new(pcrypto.Pcrypto)
key := conf.AuthToken key := conf.AuthToken
if conf.PrivilegeMode { if conf.PrivilegeMode {
@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
buf := make([]byte, 5*1024+4) buf := make([]byte, 5*1024+4)
var left, res []byte var left, res []byte
var cnt int var cnt int = -1
// record // record
var flowBytes int64 = 0 var flowBytes int64 = 0
@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
}() }()
} }
nreader := bufio.NewReader(r)
for { for {
// there may be more than 1 package in variable // there may be more than 1 package in variable
// and we read more bytes if unpkgMsg returns an error // and we read more bytes if unpkgMsg returns an error
var newBuf []byte var newBuf []byte
if cnt < 0 { if cnt < 0 {
n, err := nreader.Read(buf) n, err := r.Read(buf)
if err != nil { if err != nil {
return err return err
} }
@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
} }
} }
_, err = w.Write(res) _, err = w.WriteBytes(res)
if err != nil { if err != nil {
return err return err
} }
@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
} }
// recvive msg from reader, then encrypt msg into writer // recvive msg from reader, then encrypt msg into writer
func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
laes := new(pcrypto.Pcrypto) laes := new(pcrypto.Pcrypto)
key := conf.AuthToken key := conf.AuthToken
if conf.PrivilegeMode { if conf.PrivilegeMode {
@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
}() }()
} }
nreader := bufio.NewReader(r)
buf := make([]byte, 5*1024) buf := make([]byte, 5*1024)
for { for {
n, err := nreader.Read(buf) n, err := r.Read(buf)
if err != nil { if err != nil {
return err return err
} }
@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
} }
res = pkgMsg(res) res = pkgMsg(res)
_, err = w.Write(res) _, err = w.WriteBytes(res)
if err != nil { if err != nil {
return err return err
} }

View File

@ -154,13 +154,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
} }
// start another goroutine for join two conns from frpc and user // start another goroutine for join two conns from frpc and user
go func() { go func(userConn *conn.Conn) {
workConn, err := p.getWorkConn() workConn, err := p.getWorkConn()
if err != nil { if err != nil {
return return
} }
userConn := c
// msg will transfer to another without modifying // msg will transfer to another without modifying
// l means local, r means remote // l means local, r means remote
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
@ -169,7 +168,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
needRecord := true needRecord := true
go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
metric.OpenConnection(p.Name) metric.OpenConnection(p.Name)
}() }(c)
} }
}(listener) }(listener)
} }

View File

@ -117,6 +117,11 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
return c, nil return c, nil
} }
func (c *Conn) SetTcpConn(tcpConn net.Conn) {
c.TcpConn = tcpConn
c.Reader = bufio.NewReader(c.TcpConn)
}
func (c *Conn) GetRemoteAddr() (addr string) { func (c *Conn) GetRemoteAddr() (addr string) {
return c.TcpConn.RemoteAddr().String() return c.TcpConn.RemoteAddr().String()
} }
@ -125,6 +130,11 @@ func (c *Conn) GetLocalAddr() (addr string) {
return c.TcpConn.LocalAddr().String() return c.TcpConn.LocalAddr().String()
} }
func (c *Conn) Read(p []byte) (n int, err error) {
n, err = c.Reader.Read(p)
return
}
func (c *Conn) ReadLine() (buff string, err error) { func (c *Conn) ReadLine() (buff string, err error) {
buff, err = c.Reader.ReadString('\n') buff, err = c.Reader.ReadString('\n')
if err != nil { if err != nil {
@ -138,10 +148,14 @@ func (c *Conn) ReadLine() (buff string, err error) {
return buff, err return buff, err
} }
func (c *Conn) WriteBytes(content []byte) (n int, err error) {
n, err = c.TcpConn.Write(content)
return
}
func (c *Conn) Write(content string) (err error) { func (c *Conn) Write(content string) (err error) {
_, err = c.TcpConn.Write([]byte(content)) _, err = c.TcpConn.Write([]byte(content))
return err return err
} }
func (c *Conn) SetDeadline(t time.Time) error { func (c *Conn) SetDeadline(t time.Time) error {

View File

@ -105,7 +105,7 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
if err = sConn.SetDeadline(time.Time{}); err != nil { if err = sConn.SetDeadline(time.Time{}); err != nil {
return return
} }
c.TcpConn = sConn c.SetTcpConn(sConn)
l.accept <- c l.accept <- c
} }