diff --git a/server/service.go b/server/service.go index 8de44ae..7add79e 100644 --- a/server/service.go +++ b/server/service.go @@ -297,6 +297,68 @@ func (svr *Service) Run() { svr.HandleListener(svr.listener) } +func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) { + xl := xlog.FromContextSafe(ctx) + + var ( + rawMsg msg.Message + err error + ) + + conn.SetReadDeadline(time.Now().Add(connReadTimeout)) + if rawMsg, err = msg.ReadMsg(conn); err != nil { + log.Trace("Failed to read message: %v", err) + conn.Close() + return + } + conn.SetReadDeadline(time.Time{}) + + switch m := rawMsg.(type) { + case *msg.Login: + // server plugin hook + content := &plugin.LoginContent{ + Login: *m, + } + retContent, err := svr.pluginManager.Login(content) + if err == nil { + m = &retContent.Login + err = svr.RegisterControl(conn, m) + } + + // If login failed, send error message there. + // Otherwise send success message in control's work goroutine. + if err != nil { + xl.Warn("register control error: %v", err) + msg.WriteMsg(conn, &msg.LoginResp{ + Version: version.Full(), + Error: util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient), + }) + conn.Close() + } + case *msg.NewWorkConn: + if err := svr.RegisterWorkConn(conn, m); err != nil { + conn.Close() + } + case *msg.NewVisitorConn: + if err = svr.RegisterVisitorConn(conn, m); err != nil { + xl.Warn("register visitor conn error: %v", err) + msg.WriteMsg(conn, &msg.NewVisitorConnResp{ + ProxyName: m.ProxyName, + Error: util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient), + }) + conn.Close() + } else { + msg.WriteMsg(conn, &msg.NewVisitorConnResp{ + ProxyName: m.ProxyName, + Error: "", + }) + } + default: + log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String()) + conn.Close() + } +} + func (svr *Service) HandleListener(l net.Listener) { // Listen for incoming connections from client. for { @@ -307,7 +369,9 @@ func (svr *Service) HandleListener(l net.Listener) { } // inject xlog object into net.Conn context xl := xlog.New() - c = frpNet.NewContextConn(c, xlog.NewContext(context.Background(), xl)) + ctx := context.Background() + + c = frpNet.NewContextConn(c, xlog.NewContext(ctx, xl)) log.Trace("start check TLS connection...") originConn := c @@ -320,63 +384,7 @@ func (svr *Service) HandleListener(l net.Listener) { log.Trace("success check TLS connection") // Start a new goroutine for dealing connections. - go func(frpConn net.Conn) { - dealFn := func(conn net.Conn) { - var rawMsg msg.Message - conn.SetReadDeadline(time.Now().Add(connReadTimeout)) - if rawMsg, err = msg.ReadMsg(conn); err != nil { - log.Trace("Failed to read message: %v", err) - conn.Close() - return - } - conn.SetReadDeadline(time.Time{}) - - switch m := rawMsg.(type) { - case *msg.Login: - // server plugin hook - content := &plugin.LoginContent{ - Login: *m, - } - retContent, err := svr.pluginManager.Login(content) - if err == nil { - m = &retContent.Login - err = svr.RegisterControl(conn, m) - } - - // If login failed, send error message there. - // Otherwise send success message in control's work goroutine. - if err != nil { - xl.Warn("register control error: %v", err) - msg.WriteMsg(conn, &msg.LoginResp{ - Version: version.Full(), - Error: util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient), - }) - conn.Close() - } - case *msg.NewWorkConn: - if err := svr.RegisterWorkConn(conn, m); err != nil { - conn.Close() - } - case *msg.NewVisitorConn: - if err = svr.RegisterVisitorConn(conn, m); err != nil { - xl.Warn("register visitor conn error: %v", err) - msg.WriteMsg(conn, &msg.NewVisitorConnResp{ - ProxyName: m.ProxyName, - Error: util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient), - }) - conn.Close() - } else { - msg.WriteMsg(conn, &msg.NewVisitorConnResp{ - ProxyName: m.ProxyName, - Error: "", - }) - } - default: - log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String()) - conn.Close() - } - } - + go func(ctx context.Context, frpConn net.Conn) { if svr.cfg.TcpMux { fmuxCfg := fmux.DefaultConfig() fmuxCfg.KeepAliveInterval = 20 * time.Second @@ -395,12 +403,12 @@ func (svr *Service) HandleListener(l net.Listener) { session.Close() return } - go dealFn(stream) + go svr.handleConnection(ctx, stream) } } else { - dealFn(frpConn) + svr.handleConnection(ctx, frpConn) } - }(c) + }(ctx, c) } }