fix: frps handle multi conn may happen data race (#1768)

This commit is contained in:
Tank 2020-04-19 16:16:24 +08:00 committed by GitHub
parent 5a61fd84ad
commit 7728e35c52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -297,6 +297,68 @@ func (svr *Service) Run() {
svr.HandleListener(svr.listener)
}
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
xl := xlog.FromContextSafe(ctx)
var (
rawMsg msg.Message
err error
)
conn.SetReadDeadline(time.Now().Add(connReadTimeout))
if rawMsg, err = msg.ReadMsg(conn); err != nil {
log.Trace("Failed to read message: %v", err)
conn.Close()
return
}
conn.SetReadDeadline(time.Time{})
switch m := rawMsg.(type) {
case *msg.Login:
// server plugin hook
content := &plugin.LoginContent{
Login: *m,
}
retContent, err := svr.pluginManager.Login(content)
if err == nil {
m = &retContent.Login
err = svr.RegisterControl(conn, m)
}
// If login failed, send error message there.
// Otherwise send success message in control's work goroutine.
if err != nil {
xl.Warn("register control error: %v", err)
msg.WriteMsg(conn, &msg.LoginResp{
Version: version.Full(),
Error: util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient),
})
conn.Close()
}
case *msg.NewWorkConn:
if err := svr.RegisterWorkConn(conn, m); err != nil {
conn.Close()
}
case *msg.NewVisitorConn:
if err = svr.RegisterVisitorConn(conn, m); err != nil {
xl.Warn("register visitor conn error: %v", err)
msg.WriteMsg(conn, &msg.NewVisitorConnResp{
ProxyName: m.ProxyName,
Error: util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient),
})
conn.Close()
} else {
msg.WriteMsg(conn, &msg.NewVisitorConnResp{
ProxyName: m.ProxyName,
Error: "",
})
}
default:
log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String())
conn.Close()
}
}
func (svr *Service) HandleListener(l net.Listener) {
// Listen for incoming connections from client.
for {
@ -307,7 +369,9 @@ func (svr *Service) HandleListener(l net.Listener) {
}
// inject xlog object into net.Conn context
xl := xlog.New()
c = frpNet.NewContextConn(c, xlog.NewContext(context.Background(), xl))
ctx := context.Background()
c = frpNet.NewContextConn(c, xlog.NewContext(ctx, xl))
log.Trace("start check TLS connection...")
originConn := c
@ -320,63 +384,7 @@ func (svr *Service) HandleListener(l net.Listener) {
log.Trace("success check TLS connection")
// Start a new goroutine for dealing connections.
go func(frpConn net.Conn) {
dealFn := func(conn net.Conn) {
var rawMsg msg.Message
conn.SetReadDeadline(time.Now().Add(connReadTimeout))
if rawMsg, err = msg.ReadMsg(conn); err != nil {
log.Trace("Failed to read message: %v", err)
conn.Close()
return
}
conn.SetReadDeadline(time.Time{})
switch m := rawMsg.(type) {
case *msg.Login:
// server plugin hook
content := &plugin.LoginContent{
Login: *m,
}
retContent, err := svr.pluginManager.Login(content)
if err == nil {
m = &retContent.Login
err = svr.RegisterControl(conn, m)
}
// If login failed, send error message there.
// Otherwise send success message in control's work goroutine.
if err != nil {
xl.Warn("register control error: %v", err)
msg.WriteMsg(conn, &msg.LoginResp{
Version: version.Full(),
Error: util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient),
})
conn.Close()
}
case *msg.NewWorkConn:
if err := svr.RegisterWorkConn(conn, m); err != nil {
conn.Close()
}
case *msg.NewVisitorConn:
if err = svr.RegisterVisitorConn(conn, m); err != nil {
xl.Warn("register visitor conn error: %v", err)
msg.WriteMsg(conn, &msg.NewVisitorConnResp{
ProxyName: m.ProxyName,
Error: util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient),
})
conn.Close()
} else {
msg.WriteMsg(conn, &msg.NewVisitorConnResp{
ProxyName: m.ProxyName,
Error: "",
})
}
default:
log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String())
conn.Close()
}
}
go func(ctx context.Context, frpConn net.Conn) {
if svr.cfg.TcpMux {
fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = 20 * time.Second
@ -395,12 +403,12 @@ func (svr *Service) HandleListener(l net.Listener) {
session.Close()
return
}
go dealFn(stream)
go svr.handleConnection(ctx, stream)
}
} else {
dealFn(frpConn)
svr.handleConnection(ctx, frpConn)
}
}(c)
}(ctx, c)
}
}