diff --git a/models/config/server_common.go b/models/config/server_common.go index df6b7a1..58a7252 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -128,6 +128,9 @@ type ServerCommonConf struct { // may proxy to. If this value is 0, no limit will be applied. By default, // this value is 0. MaxPortsPerClient int64 `json:"max_ports_per_client"` + // TlsOnly specifies whether to only accept TLS-encrypted connections. By + // default, the value is false. + TlsOnly bool `json:"tls_only"` // HeartBeatTimeout specifies the maximum time to wait for a heartbeat // before terminating the connection. It is not recommended to change this // value. By default, this value is 90. @@ -167,6 +170,7 @@ func GetDefaultServerConf() ServerCommonConf { AllowPorts: make(map[int]struct{}), MaxPoolCount: 5, MaxPortsPerClient: 0, + TlsOnly: false, HeartBeatTimeout: 90, UserConnTimeout: 10, Custom404Page: "", @@ -378,6 +382,12 @@ func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error cfg.HeartBeatTimeout = v } } + + if tmpStr, ok = conf.Get("common", "tls_only"); ok && tmpStr == "true" { + cfg.TlsOnly = true + } else { + cfg.TlsOnly = false + } return } diff --git a/server/service.go b/server/service.go index 122555a..314d59b 100644 --- a/server/service.go +++ b/server/service.go @@ -284,7 +284,7 @@ func (svr *Service) HandleListener(l net.Listener) { log.Trace("start check TLS connection...") originConn := c - c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, connReadTimeout) + c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TlsOnly, connReadTimeout) if err != nil { log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err) originConn.Close() diff --git a/utils/net/tls.go b/utils/net/tls.go index b9fca31..d327122 100644 --- a/utils/net/tls.go +++ b/utils/net/tls.go @@ -16,6 +16,7 @@ package net import ( "crypto/tls" + "fmt" "net" "time" @@ -32,7 +33,7 @@ func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out net.Conn) { return } -func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, timeout time.Duration) (out net.Conn, err error) { +func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration) (out net.Conn, err error) { sc, r := gnet.NewSharedConnSize(c, 2) buf := make([]byte, 1) var n int @@ -46,6 +47,10 @@ func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, t if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE { out = tls.Server(c, tlsConfig) } else { + if tlsOnly { + err = fmt.Errorf("non-TLS connection received on a TlsOnly server") + return + } out = sc } return