config: add some validations (#3610)

This commit is contained in:
fatedier 2023-09-13 18:59:51 +08:00 committed by GitHub
parent 7cd02f5bd8
commit 74255f711e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 144 additions and 20 deletions

View File

@ -117,6 +117,7 @@ var rootCmd = &cobra.Command{
// Do not show command usage here. // Do not show command usage here.
err := runClient(cfgFile) err := runClient(cfgFile)
if err != nil { if err != nil {
fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
return nil return nil
@ -199,7 +200,6 @@ func parseClientCommonCfgFromCmd() (*v1.ClientCommonConfig, error) {
func runClient(cfgFilePath string) error { func runClient(cfgFilePath string) error {
cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath) cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath)
if err != nil { if err != nil {
fmt.Println(err)
return err return err
} }
if isLegacyFormat { if isLegacyFormat {

View File

@ -30,32 +30,32 @@ type ServerConfig struct {
BindAddr string `json:"bindAddr,omitempty"` BindAddr string `json:"bindAddr,omitempty"`
// BindPort specifies the port that the server listens on. By default, this // BindPort specifies the port that the server listens on. By default, this
// value is 7000. // value is 7000.
BindPort int `json:"bindPort,omitempty" validate:"gte=0,lte=65535"` BindPort int `json:"bindPort,omitempty"`
// KCPBindPort specifies the KCP port that the server listens on. If this // KCPBindPort specifies the KCP port that the server listens on. If this
// value is 0, the server will not listen for KCP connections. // value is 0, the server will not listen for KCP connections.
KCPBindPort int `json:"kcpBindPort,omitempty" validate:"gte=0,lte=65535"` KCPBindPort int `json:"kcpBindPort,omitempty"`
// QUICBindPort specifies the QUIC port that the server listens on. // QUICBindPort specifies the QUIC port that the server listens on.
// Set this value to 0 will disable this feature. // Set this value to 0 will disable this feature.
QUICBindPort int `json:"quicBindPort,omitempty" validate:"gte=0,lte=65535"` QUICBindPort int `json:"quicBindPort,omitempty"`
// ProxyBindAddr specifies the address that the proxy binds to. This value // ProxyBindAddr specifies the address that the proxy binds to. This value
// may be the same as BindAddr. // may be the same as BindAddr.
ProxyBindAddr string `json:"proxyBindAddr,omitempty"` ProxyBindAddr string `json:"proxyBindAddr,omitempty"`
// VhostHTTPPort specifies the port that the server listens for HTTP Vhost // VhostHTTPPort specifies the port that the server listens for HTTP Vhost
// requests. If this value is 0, the server will not listen for HTTP // requests. If this value is 0, the server will not listen for HTTP
// requests. // requests.
VhostHTTPPort int `json:"vhostHTTPPort,omitempty" validate:"gte=0,lte=65535"` VhostHTTPPort int `json:"vhostHTTPPort,omitempty"`
// VhostHTTPTimeout specifies the response header timeout for the Vhost // VhostHTTPTimeout specifies the response header timeout for the Vhost
// HTTP server, in seconds. By default, this value is 60. // HTTP server, in seconds. By default, this value is 60.
VhostHTTPTimeout int64 `json:"vhostHTTPTimeout,omitempty"` VhostHTTPTimeout int64 `json:"vhostHTTPTimeout,omitempty"`
// VhostHTTPSPort specifies the port that the server listens for HTTPS // VhostHTTPSPort specifies the port that the server listens for HTTPS
// Vhost requests. If this value is 0, the server will not listen for HTTPS // Vhost requests. If this value is 0, the server will not listen for HTTPS
// requests. // requests.
VhostHTTPSPort int `json:"vhostHTTPSPort,omitempty" validate:"gte=0,lte=65535"` VhostHTTPSPort int `json:"vhostHTTPSPort,omitempty"`
// TCPMuxHTTPConnectPort specifies the port that the server listens for TCP // TCPMuxHTTPConnectPort specifies the port that the server listens for TCP
// HTTP CONNECT requests. If the value is 0, the server will not multiplex TCP // HTTP CONNECT requests. If the value is 0, the server will not multiplex TCP
// requests on one single port. If it's not - it will listen on this value for // requests on one single port. If it's not - it will listen on this value for
// HTTP CONNECT requests. // HTTP CONNECT requests.
TCPMuxHTTPConnectPort int `json:"tcpmuxHTTPConnectPort,omitempty" validate:"gte=0,lte=65535"` TCPMuxHTTPConnectPort int `json:"tcpmuxHTTPConnectPort,omitempty"`
// If TCPMuxPassthrough is true, frps won't do any update on traffic. // If TCPMuxPassthrough is true, frps won't do any update on traffic.
TCPMuxPassthrough bool `json:"tcpmuxPassthrough,omitempty"` TCPMuxPassthrough bool `json:"tcpmuxPassthrough,omitempty"`
// SubDomainHost specifies the domain that will be attached to sub-domains // SubDomainHost specifies the domain that will be attached to sub-domains

View File

@ -29,6 +29,21 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig) (Warning, error) {
warnings Warning warnings Warning
errs error errs error
) )
if !lo.Contains(supportedAuthMethods, c.Auth.Method) {
errs = AppendError(errs, fmt.Errorf("invalid auth method, optional values are %v", supportedAuthMethods))
}
if !lo.Every(supportedAuthAdditionalScopes, c.Auth.AdditionalScopes) {
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", supportedAuthAdditionalScopes))
}
if err := validateLogConfig(&c.Log); err != nil {
errs = AppendError(errs, err)
}
if err := validateWebServerConfig(&c.WebServer); err != nil {
errs = AppendError(errs, err)
}
if c.Transport.HeartbeatTimeout > 0 && c.Transport.HeartbeatInterval > 0 { if c.Transport.HeartbeatTimeout > 0 && c.Transport.HeartbeatInterval > 0 {
if c.Transport.HeartbeatTimeout < c.Transport.HeartbeatInterval { if c.Transport.HeartbeatTimeout < c.Transport.HeartbeatInterval {
errs = AppendError(errs, fmt.Errorf("invalid transport.heartbeatTimeout, heartbeat timeout should not less than heartbeat interval")) errs = AppendError(errs, fmt.Errorf("invalid transport.heartbeatTimeout, heartbeat timeout should not less than heartbeat interval"))
@ -48,8 +63,8 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig) (Warning, error) {
warnings = AppendError(warnings, checkTLSConfig("transport.tls.trustedCaFile", c.Transport.TLS.TrustedCaFile)) warnings = AppendError(warnings, checkTLSConfig("transport.tls.trustedCaFile", c.Transport.TLS.TrustedCaFile))
} }
if !lo.Contains([]string{"tcp", "kcp", "quic", "websocket", "wss"}, c.Transport.Protocol) { if !lo.Contains(supportedTransportProtocols, c.Transport.Protocol) {
errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, only support tcp, kcp, quic, websocket, wss")) errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", supportedTransportProtocols))
} }
for _, f := range c.IncludeConfigFiles { for _, f := range c.IncludeConfigFiles {

View File

@ -17,6 +17,8 @@ package validation
import ( import (
"fmt" "fmt"
"github.com/samber/lo"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
) )
@ -29,13 +31,21 @@ func validateWebServerConfig(c *v1.WebServerConfig) error {
return fmt.Errorf("tls.keyFile must be specified when tls is enabled") return fmt.Errorf("tls.keyFile must be specified when tls is enabled")
} }
} }
return nil
return ValidatePort(c.Port, "webServer.port")
} }
// ValidatePort checks that the network port is in range // ValidatePort checks that the network port is in range
func ValidatePort(port int) error { func ValidatePort(port int, fieldPath string) error {
if 0 <= port && port <= 65535 { if 0 <= port && port <= 65535 {
return nil return nil
} }
return fmt.Errorf("port number %d must be in the range 0..65535", port) return fmt.Errorf("%s: port number %d must be in the range 0..65535", fieldPath, port)
}
func validateLogConfig(c *v1.LogConfig) error {
if !lo.Contains(supportedLogLevels, c.Level) {
return fmt.Errorf("invalid log level, optional values are %v", supportedLogLevels)
}
return nil
} }

View File

@ -39,7 +39,7 @@ func validateProxyBaseConfigForClient(c *v1.ProxyBaseConfig) error {
} }
if c.Plugin.Type == "" { if c.Plugin.Type == "" {
if err := ValidatePort(c.LocalPort); err != nil { if err := ValidatePort(c.LocalPort, "localPort"); err != nil {
return fmt.Errorf("localPort: %v", err) return fmt.Errorf("localPort: %v", err)
} }
} }

View File

@ -15,6 +15,10 @@
package validation package validation
import ( import (
"fmt"
"github.com/samber/lo"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
) )
@ -23,15 +27,32 @@ func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
warnings Warning warnings Warning
errs error errs error
) )
if !lo.Contains(supportedAuthMethods, c.Auth.Method) {
errs = AppendError(errs, fmt.Errorf("invalid auth method, optional values are %v", supportedAuthMethods))
}
if !lo.Every(supportedAuthAdditionalScopes, c.Auth.AdditionalScopes) {
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", supportedAuthAdditionalScopes))
}
if err := validateLogConfig(&c.Log); err != nil {
errs = AppendError(errs, err)
}
if err := validateWebServerConfig(&c.WebServer); err != nil { if err := validateWebServerConfig(&c.WebServer); err != nil {
errs = AppendError(errs, err) errs = AppendError(errs, err)
} }
errs = AppendError(errs, ValidatePort(c.BindPort)) errs = AppendError(errs, ValidatePort(c.BindPort, "bindPort"))
errs = AppendError(errs, ValidatePort(c.KCPBindPort)) errs = AppendError(errs, ValidatePort(c.KCPBindPort, "kcpBindPort"))
errs = AppendError(errs, ValidatePort(c.QUICBindPort)) errs = AppendError(errs, ValidatePort(c.QUICBindPort, "quicBindPort"))
errs = AppendError(errs, ValidatePort(c.VhostHTTPPort)) errs = AppendError(errs, ValidatePort(c.VhostHTTPPort, "vhostHTTPPort"))
errs = AppendError(errs, ValidatePort(c.VhostHTTPSPort)) errs = AppendError(errs, ValidatePort(c.VhostHTTPSPort, "vhostHTTPSPort"))
errs = AppendError(errs, ValidatePort(c.TCPMuxHTTPConnectPort)) errs = AppendError(errs, ValidatePort(c.TCPMuxHTTPConnectPort, "tcpMuxHTTPConnectPort"))
for _, p := range c.HTTPPlugins {
if !lo.Every(supportedHTTPPluginOps, p.Ops) {
errs = AppendError(errs, fmt.Errorf("invalid http plugin ops, optional values are %v", supportedHTTPPluginOps))
}
}
return warnings, errs return warnings, errs
} }

View File

@ -16,6 +16,46 @@ package validation
import ( import (
"errors" "errors"
v1 "github.com/fatedier/frp/pkg/config/v1"
splugin "github.com/fatedier/frp/pkg/plugin/server"
)
var (
supportedTransportProtocols = []string{
"tcp",
"kcp",
"quic",
"websocket",
"wss",
}
supportedAuthMethods = []string{
"token",
"oidc",
}
supportedAuthAdditionalScopes = []v1.AuthScope{
"HeartBeats",
"NewWorkConns",
}
supportedLogLevels = []string{
"trace",
"debug",
"info",
"warn",
"error",
}
supportedHTTPPluginOps = []string{
splugin.OpLogin,
splugin.OpNewProxy,
splugin.OpCloseProxy,
splugin.OpPing,
splugin.OpNewWorkConn,
splugin.OpNewUserConn,
}
) )
type Warning error type Warning error

View File

@ -99,7 +99,7 @@ func (c *TypedVisitorConfig) UnmarshalJSON(b []byte) error {
c.Type = typeStruct.Type c.Type = typeStruct.Type
configurer := NewVisitorConfigurerByType(typeStruct.Type) configurer := NewVisitorConfigurerByType(typeStruct.Type)
if configurer == nil { if configurer == nil {
return fmt.Errorf("unknown visitor type: %s" + typeStruct.Type) return fmt.Errorf("unknown visitor type: %s", typeStruct.Type)
} }
if err := json.Unmarshal(b, configurer); err != nil { if err := json.Unmarshal(b, configurer); err != nil {
return err return err

View File

@ -81,4 +81,42 @@ var _ = ginkgo.Describe("[Feature: Config]", func() {
framework.NewRequestExpect(f).Port(remotePort2).Ensure() framework.NewRequestExpect(f).Port(remotePort2).Ensure()
}) })
}) })
ginkgo.Describe("Support Formats", func() {
ginkgo.It("YAML", func() {
serverConf := fmt.Sprintf(`
bindPort: {{ .%s }}
log:
level: trace
`, port.GenName("Server"))
remotePort := f.AllocPort()
clientConf := fmt.Sprintf(`
serverPort: {{ .%s }}
log:
level: trace
proxies:
- name: tcp
type: tcp
localPort: {{ .%s }}
remotePort: %d
`, port.GenName("Server"), framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure()
})
ginkgo.It("JSON", func() {
serverConf := fmt.Sprintf(`{"bindPort": {{ .%s }}, "log": {"level": "trace"}}`, port.GenName("Server"))
remotePort := f.AllocPort()
clientConf := fmt.Sprintf(`{"serverPort": {{ .%s }}, "log": {"level": "trace"},
"proxies": [{"name": "tcp", "type": "tcp", "localPort": {{ .%s }}, "remotePort": %d}]}`,
port.GenName("Server"), framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure()
})
})
}) })