add connection read timeout

This commit is contained in:
fatedier 2017-05-10 00:46:42 +08:00
parent 71f7caa1ee
commit a84dd05351
2 changed files with 19 additions and 1 deletions

View File

@ -30,6 +30,10 @@ import (
"github.com/fatedier/frp/utils/version" "github.com/fatedier/frp/utils/version"
) )
const (
connReadTimeout time.Duration = 10 * time.Second
)
type Control struct { type Control struct {
// frpc service // frpc service
svr *Service svr *Service
@ -144,7 +148,7 @@ func (ctl *Control) NewWorkConn() {
// dispatch this work connection to related proxy // dispatch this work connection to related proxy
if pxy, ok := ctl.proxies[startMsg.ProxyName]; ok { if pxy, ok := ctl.proxies[startMsg.ProxyName]; ok {
workConn.Info("start a new work connection") workConn.Info("start a new work connection, localAddr: %s remoteAddr: %s", workConn.LocalAddr().String(), workConn.RemoteAddr().String())
go pxy.InWorkConn(workConn) go pxy.InWorkConn(workConn)
} else { } else {
workConn.Close() workConn.Close()
@ -168,6 +172,12 @@ func (ctl *Control) login() (err error) {
return err return err
} }
defer func() {
if err != nil {
conn.Close()
}
}()
now := time.Now().Unix() now := time.Now().Unix()
ctl.loginMsg.PrivilegeKey = util.GetAuthKey(config.ClientCommonCfg.PrivilegeToken, now) ctl.loginMsg.PrivilegeKey = util.GetAuthKey(config.ClientCommonCfg.PrivilegeToken, now)
ctl.loginMsg.Timestamp = now ctl.loginMsg.Timestamp = now
@ -178,9 +188,11 @@ func (ctl *Control) login() (err error) {
} }
var loginRespMsg msg.LoginResp var loginRespMsg msg.LoginResp
conn.SetReadDeadline(time.Now().Add(connReadTimeout))
if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil { if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil {
return err return err
} }
conn.SetReadDeadline(time.Time{})
if loginRespMsg.Error != "" { if loginRespMsg.Error != "" {
err = fmt.Errorf("%s", loginRespMsg.Error) err = fmt.Errorf("%s", loginRespMsg.Error)

View File

@ -28,6 +28,10 @@ import (
"github.com/fatedier/frp/utils/vhost" "github.com/fatedier/frp/utils/vhost"
) )
const (
connReadTimeout time.Duration = 10 * time.Second
)
var ServerService *Service var ServerService *Service
// Server service. // Server service.
@ -121,11 +125,13 @@ func (svr *Service) Run() {
// Start a new goroutine for dealing connections. // Start a new goroutine for dealing connections.
go func(frpConn net.Conn) { go func(frpConn net.Conn) {
var rawMsg msg.Message var rawMsg msg.Message
frpConn.SetReadDeadline(time.Now().Add(connReadTimeout))
if rawMsg, err = msg.ReadMsg(frpConn); err != nil { if rawMsg, err = msg.ReadMsg(frpConn); err != nil {
log.Warn("Failed to read message: %v", err) log.Warn("Failed to read message: %v", err)
frpConn.Close() frpConn.Close()
return return
} }
frpConn.SetReadDeadline(time.Time{})
switch m := rawMsg.(type) { switch m := rawMsg.(type) {
case *msg.Login: case *msg.Login: