fix: frps handle multi conn may happen data race (#1768)

This commit is contained in:
Tank 2020-04-19 16:16:24 +08:00 committed by GitHub
parent 5a61fd84ad
commit 7728e35c52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 69 additions and 61 deletions

View File

@ -297,32 +297,14 @@ func (svr *Service) Run() {
svr.HandleListener(svr.listener) svr.HandleListener(svr.listener)
} }
func (svr *Service) HandleListener(l net.Listener) { func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
// Listen for incoming connections from client. xl := xlog.FromContextSafe(ctx)
for {
c, err := l.Accept()
if err != nil {
log.Warn("Listener for incoming connections from client closed")
return
}
// inject xlog object into net.Conn context
xl := xlog.New()
c = frpNet.NewContextConn(c, xlog.NewContext(context.Background(), xl))
log.Trace("start check TLS connection...") var (
originConn := c rawMsg msg.Message
c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TlsOnly, connReadTimeout) err error
if err != nil { )
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
originConn.Close()
continue
}
log.Trace("success check TLS connection")
// Start a new goroutine for dealing connections.
go func(frpConn net.Conn) {
dealFn := func(conn net.Conn) {
var rawMsg msg.Message
conn.SetReadDeadline(time.Now().Add(connReadTimeout)) conn.SetReadDeadline(time.Now().Add(connReadTimeout))
if rawMsg, err = msg.ReadMsg(conn); err != nil { if rawMsg, err = msg.ReadMsg(conn); err != nil {
log.Trace("Failed to read message: %v", err) log.Trace("Failed to read message: %v", err)
@ -375,8 +357,34 @@ func (svr *Service) HandleListener(l net.Listener) {
log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String()) log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String())
conn.Close() conn.Close()
} }
} }
func (svr *Service) HandleListener(l net.Listener) {
// Listen for incoming connections from client.
for {
c, err := l.Accept()
if err != nil {
log.Warn("Listener for incoming connections from client closed")
return
}
// inject xlog object into net.Conn context
xl := xlog.New()
ctx := context.Background()
c = frpNet.NewContextConn(c, xlog.NewContext(ctx, xl))
log.Trace("start check TLS connection...")
originConn := c
c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TlsOnly, connReadTimeout)
if err != nil {
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
originConn.Close()
continue
}
log.Trace("success check TLS connection")
// Start a new goroutine for dealing connections.
go func(ctx context.Context, frpConn net.Conn) {
if svr.cfg.TcpMux { if svr.cfg.TcpMux {
fmuxCfg := fmux.DefaultConfig() fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = 20 * time.Second fmuxCfg.KeepAliveInterval = 20 * time.Second
@ -395,12 +403,12 @@ func (svr *Service) HandleListener(l net.Listener) {
session.Close() session.Close()
return return
} }
go dealFn(stream) go svr.handleConnection(ctx, stream)
} }
} else { } else {
dealFn(frpConn) svr.handleConnection(ctx, frpConn)
} }
}(c) }(ctx, c)
} }
} }