Strict configuration parsing (#3773)

* Test configuration loading more precisely

* Add strict configuration parsing
This commit is contained in:
Aarni Koskela 2023-11-16 09:42:49 +02:00 committed by GitHub
parent 184223cb2f
commit e8deb65c4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 119 additions and 47 deletions

View File

@ -57,7 +57,7 @@ func (svr *Service) apiReload(w http.ResponseWriter, _ *http.Request) {
} }
}() }()
cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.cfgFile) cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.cfgFile, svr.strictConfig)
if err != nil { if err != nil {
res.Code = 400 res.Code = 400
res.Msg = err.Error() res.Msg = err.Error()

View File

@ -70,6 +70,9 @@ type Service struct {
// string if no configuration file was used. // string if no configuration file was used.
cfgFile string cfgFile string
// Whether strict configuration parsing had been requested.
strictConfig bool
// service context // service context
ctx context.Context ctx context.Context
// call cancel to stop service // call cancel to stop service
@ -82,11 +85,13 @@ func NewService(
pxyCfgs []v1.ProxyConfigurer, pxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer, visitorCfgs []v1.VisitorConfigurer,
cfgFile string, cfgFile string,
strictConfig bool,
) *Service { ) *Service {
return &Service{ return &Service{
authSetter: auth.NewAuthSetter(cfg.Auth), authSetter: auth.NewAuthSetter(cfg.Auth),
cfg: cfg, cfg: cfg,
cfgFile: cfgFile, cfgFile: cfgFile,
strictConfig: strictConfig,
pxyCfgs: pxyCfgs, pxyCfgs: pxyCfgs,
visitorCfgs: visitorCfgs, visitorCfgs: visitorCfgs,
ctx: context.Background(), ctx: context.Background(),

View File

@ -52,7 +52,7 @@ func NewAdminCommand(name, short string, handler func(*v1.ClientCommonConfig) er
Use: name, Use: name,
Short: short, Short: short,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
cfg, _, _, _, err := config.LoadClientConfig(cfgFile) cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfig)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -48,7 +48,7 @@ var natholeDiscoveryCmd = &cobra.Command{
Short: "Discover nathole information from stun server", Short: "Discover nathole information from stun server",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
// ignore error here, because we can use command line pameters // ignore error here, because we can use command line pameters
cfg, _, _, _, err := config.LoadClientConfig(cfgFile) cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfig)
if err != nil { if err != nil {
cfg = &v1.ClientCommonConfig{} cfg = &v1.ClientCommonConfig{}
} }

View File

@ -84,7 +84,7 @@ func NewProxyCommand(name string, c v1.ProxyConfigurer, clientCfg *v1.ClientComm
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
err := startService(clientCfg, []v1.ProxyConfigurer{c}, nil, "") err := startService(clientCfg, []v1.ProxyConfigurer{c}, nil, "", strictConfig)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -110,7 +110,7 @@ func NewVisitorCommand(name string, c v1.VisitorConfigurer, clientCfg *v1.Client
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
err := startService(clientCfg, nil, []v1.VisitorConfigurer{c}, "") err := startService(clientCfg, nil, []v1.VisitorConfigurer{c}, "", strictConfig)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -39,12 +39,14 @@ var (
cfgFile string cfgFile string
cfgDir string cfgDir string
showVersion bool showVersion bool
strictConfig bool
) )
func init() { func init() {
rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "./frpc.ini", "config file of frpc") rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "./frpc.ini", "config file of frpc")
rootCmd.PersistentFlags().StringVarP(&cfgDir, "config_dir", "", "", "config directory, run one frpc service for each file in config directory") rootCmd.PersistentFlags().StringVarP(&cfgDir, "config_dir", "", "", "config directory, run one frpc service for each file in config directory")
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frpc") rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frpc")
rootCmd.PersistentFlags().BoolVarP(&strictConfig, "strict_config", "", false, "strict config parsing mode")
} }
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
@ -108,7 +110,7 @@ func handleTermSignal(svr *client.Service) {
} }
func runClient(cfgFilePath string) error { func runClient(cfgFilePath string) error {
cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath) cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath, strictConfig)
if err != nil { if err != nil {
return err return err
} }
@ -120,11 +122,14 @@ func runClient(cfgFilePath string) error {
warning, err := validation.ValidateAllClientConfig(cfg, pxyCfgs, visitorCfgs) warning, err := validation.ValidateAllClientConfig(cfg, pxyCfgs, visitorCfgs)
if warning != nil { if warning != nil {
fmt.Printf("WARNING: %v\n", warning) fmt.Printf("WARNING: %v\n", warning)
if strictConfig {
return fmt.Errorf("warning: %v", warning)
}
} }
if err != nil { if err != nil {
return err return err
} }
return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath) return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath, strictConfig)
} }
func startService( func startService(
@ -132,6 +137,7 @@ func startService(
pxyCfgs []v1.ProxyConfigurer, pxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer, visitorCfgs []v1.VisitorConfigurer,
cfgFile string, cfgFile string,
strictConfig bool,
) error { ) error {
log.InitLog(cfg.Log.To, cfg.Log.Level, cfg.Log.MaxDays, cfg.Log.DisablePrintColor) log.InitLog(cfg.Log.To, cfg.Log.Level, cfg.Log.MaxDays, cfg.Log.DisablePrintColor)
@ -139,7 +145,7 @@ func startService(
log.Info("start frpc service for config file [%s]", cfgFile) log.Info("start frpc service for config file [%s]", cfgFile)
defer log.Info("frpc service for config file [%s] stopped", cfgFile) defer log.Info("frpc service for config file [%s] stopped", cfgFile)
} }
svr := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile) svr := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile, strictConfig)
shouldGracefulClose := cfg.Transport.Protocol == "kcp" || cfg.Transport.Protocol == "quic" shouldGracefulClose := cfg.Transport.Protocol == "kcp" || cfg.Transport.Protocol == "quic"
// Capture the exit signal if we use kcp or quic. // Capture the exit signal if we use kcp or quic.

View File

@ -37,7 +37,7 @@ var verifyCmd = &cobra.Command{
return nil return nil
} }
cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile) cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile, strictConfig)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -32,6 +32,7 @@ import (
var ( var (
cfgFile string cfgFile string
showVersion bool showVersion bool
strictConfig bool
serverCfg v1.ServerConfig serverCfg v1.ServerConfig
) )
@ -39,6 +40,7 @@ var (
func init() { func init() {
rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file of frps") rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file of frps")
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps") rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps")
rootCmd.PersistentFlags().BoolVarP(&strictConfig, "strict_config", "", false, "strict config parsing mode")
RegisterServerConfigFlags(rootCmd, &serverCfg) RegisterServerConfigFlags(rootCmd, &serverCfg)
} }
@ -58,7 +60,7 @@ var rootCmd = &cobra.Command{
err error err error
) )
if cfgFile != "" { if cfgFile != "" {
svrCfg, isLegacyFormat, err = config.LoadServerConfig(cfgFile) svrCfg, isLegacyFormat, err = config.LoadServerConfig(cfgFile, strictConfig)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -36,7 +36,7 @@ var verifyCmd = &cobra.Command{
fmt.Println("frps: the configuration file is not specified") fmt.Println("frps: the configuration file is not specified")
return nil return nil
} }
svrCfg, _, err := config.LoadServerConfig(cfgFile) svrCfg, _, err := config.LoadServerConfig(cfgFile, strictConfig)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -27,7 +27,7 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"gopkg.in/ini.v1" "gopkg.in/ini.v1"
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/yaml" yaml "k8s.io/apimachinery/pkg/util/yaml"
"github.com/fatedier/frp/pkg/config/legacy" "github.com/fatedier/frp/pkg/config/legacy"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
@ -100,27 +100,40 @@ func LoadFileContentWithTemplate(path string, values *Values) ([]byte, error) {
return RenderWithTemplate(b, values) return RenderWithTemplate(b, values)
} }
func LoadConfigureFromFile(path string, c any) error { func LoadConfigureFromFile(path string, c any, strict bool) error {
content, err := LoadFileContentWithTemplate(path, GetValues()) content, err := LoadFileContentWithTemplate(path, GetValues())
if err != nil { if err != nil {
return err return err
} }
return LoadConfigure(content, c) return LoadConfigure(content, c, strict)
} }
// LoadConfigure loads configuration from bytes and unmarshal into c. // LoadConfigure loads configuration from bytes and unmarshal into c.
// Now it supports json, yaml and toml format. // Now it supports json, yaml and toml format.
func LoadConfigure(b []byte, c any) error { func LoadConfigure(b []byte, c any, strict bool) error {
var tomlObj interface{} var tomlObj interface{}
// Try to unmarshal as TOML first; swallow errors from that (assume it's not valid TOML).
// TODO: caller should probably be able to specify the format, so we don't need to swallow errors.
if err := toml.Unmarshal(b, &tomlObj); err == nil { if err := toml.Unmarshal(b, &tomlObj); err == nil {
b, err = json.Marshal(&tomlObj) b, err = json.Marshal(&tomlObj)
if err != nil { if err != nil {
return err return err
} }
} }
decoder := yaml.NewYAMLOrJSONDecoder(bytes.NewBuffer(b), 4096) // If the buffer smells like JSON (first non-whitespace character is '{'), unmarshal as JSON directly.
if yaml.IsJSONBuffer(b) {
decoder := json.NewDecoder(bytes.NewBuffer(b))
if strict {
decoder.DisallowUnknownFields()
}
return decoder.Decode(c) return decoder.Decode(c)
} }
// It wasn't JSON. Unmarshal as YAML.
if strict {
return yaml.UnmarshalStrict(b, c)
}
return yaml.Unmarshal(b, c)
}
func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.ProxyConfigurer, error) { func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.ProxyConfigurer, error) {
m.ProxyType = util.EmptyOr(m.ProxyType, string(v1.ProxyTypeTCP)) m.ProxyType = util.EmptyOr(m.ProxyType, string(v1.ProxyTypeTCP))
@ -139,7 +152,7 @@ func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.
return configurer, nil return configurer, nil
} }
func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) { func LoadServerConfig(path string, strict bool) (*v1.ServerConfig, bool, error) {
var ( var (
svrCfg *v1.ServerConfig svrCfg *v1.ServerConfig
isLegacyFormat bool isLegacyFormat bool
@ -158,7 +171,7 @@ func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) {
isLegacyFormat = true isLegacyFormat = true
} else { } else {
svrCfg = &v1.ServerConfig{} svrCfg = &v1.ServerConfig{}
if err := LoadConfigureFromFile(path, svrCfg); err != nil { if err := LoadConfigureFromFile(path, svrCfg, strict); err != nil {
return nil, false, err return nil, false, err
} }
} }
@ -168,7 +181,7 @@ func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) {
return svrCfg, isLegacyFormat, nil return svrCfg, isLegacyFormat, nil
} }
func LoadClientConfig(path string) ( func LoadClientConfig(path string, strict bool) (
*v1.ClientCommonConfig, *v1.ClientCommonConfig,
[]v1.ProxyConfigurer, []v1.ProxyConfigurer,
[]v1.VisitorConfigurer, []v1.VisitorConfigurer,
@ -196,7 +209,7 @@ func LoadClientConfig(path string) (
isLegacyFormat = true isLegacyFormat = true
} else { } else {
allCfg := v1.ClientConfig{} allCfg := v1.ClientConfig{}
if err := LoadConfigureFromFile(path, &allCfg); err != nil { if err := LoadConfigureFromFile(path, &allCfg, strict); err != nil {
return nil, nil, nil, false, err return nil, nil, nil, false, err
} }
cliCfg = &allCfg.ClientCommonConfig cliCfg = &allCfg.ClientCommonConfig
@ -211,7 +224,7 @@ func LoadClientConfig(path string) (
// Load additional config from includes. // Load additional config from includes.
// legacy ini format alredy handle this in ParseClientConfig. // legacy ini format alredy handle this in ParseClientConfig.
if len(cliCfg.IncludeConfigFiles) > 0 && !isLegacyFormat { if len(cliCfg.IncludeConfigFiles) > 0 && !isLegacyFormat {
extPxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat) extPxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat, strict)
if err != nil { if err != nil {
return nil, nil, nil, isLegacyFormat, err return nil, nil, nil, isLegacyFormat, err
} }
@ -242,7 +255,7 @@ func LoadClientConfig(path string) (
return cliCfg, pxyCfgs, visitorCfgs, isLegacyFormat, nil return cliCfg, pxyCfgs, visitorCfgs, isLegacyFormat, nil
} }
func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) { func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool, strict bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) {
pxyCfgs := make([]v1.ProxyConfigurer, 0) pxyCfgs := make([]v1.ProxyConfigurer, 0)
visitorCfgs := make([]v1.VisitorConfigurer, 0) visitorCfgs := make([]v1.VisitorConfigurer, 0)
for _, path := range paths { for _, path := range paths {
@ -265,7 +278,7 @@ func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool) ([]v1.Prox
if matched, _ := filepath.Match(filepath.Join(absDir, filepath.Base(path)), absFile); matched { if matched, _ := filepath.Match(filepath.Join(absDir, filepath.Base(path)), absFile); matched {
// support yaml/json/toml // support yaml/json/toml
cfg := v1.ClientConfig{} cfg := v1.ClientConfig{}
if err := LoadConfigureFromFile(absFile, &cfg); err != nil { if err := LoadConfigureFromFile(absFile, &cfg, strict); err != nil {
return nil, nil, fmt.Errorf("load additional config from %s error: %v", absFile, err) return nil, nil, fmt.Errorf("load additional config from %s error: %v", absFile, err)
} }
for _, c := range cfg.Proxies { for _, c := range cfg.Proxies {

View File

@ -15,6 +15,7 @@
package config package config
import ( import (
"strings"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -22,9 +23,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
) )
func TestLoadConfigure(t *testing.T) { const tomlServerContent = `
require := require.New(t)
content := `
bindAddr = "127.0.0.1" bindAddr = "127.0.0.1"
kcpBindPort = 7000 kcpBindPort = 7000
quicBindPort = 7001 quicBindPort = 7001
@ -33,8 +32,34 @@ custom404Page = "/abc.html"
transport.tcpKeepalive = 10 transport.tcpKeepalive = 10
` `
const yamlServerContent = `
bindAddr: 127.0.0.1
kcpBindPort: 7000
quicBindPort: 7001
tcpmuxHTTPConnectPort: 7005
custom404Page: /abc.html
transport:
tcpKeepalive: 10
`
const jsonServerContent = `
{
"bindAddr": "127.0.0.1",
"kcpBindPort": 7000,
"quicBindPort": 7001,
"tcpmuxHTTPConnectPort": 7005,
"custom404Page": "/abc.html",
"transport": {
"tcpKeepalive": 10
}
}
`
func TestLoadServerConfig(t *testing.T) {
for _, content := range []string{tomlServerContent, yamlServerContent, jsonServerContent} {
svrCfg := v1.ServerConfig{} svrCfg := v1.ServerConfig{}
err := LoadConfigure([]byte(content), &svrCfg) err := LoadConfigure([]byte(content), &svrCfg, true)
require := require.New(t)
require.NoError(err) require.NoError(err)
require.EqualValues("127.0.0.1", svrCfg.BindAddr) require.EqualValues("127.0.0.1", svrCfg.BindAddr)
require.EqualValues(7000, svrCfg.KCPBindPort) require.EqualValues(7000, svrCfg.KCPBindPort)
@ -43,3 +68,24 @@ transport.tcpKeepalive = 10
require.EqualValues("/abc.html", svrCfg.Custom404Page) require.EqualValues("/abc.html", svrCfg.Custom404Page)
require.EqualValues(10, svrCfg.Transport.TCPKeepAlive) require.EqualValues(10, svrCfg.Transport.TCPKeepAlive)
} }
}
// Test that loading in strict mode fails when the config is invalid.
func TestLoadServerConfigErrorMode(t *testing.T) {
for strict := range []bool{false, true} {
for _, content := range []string{tomlServerContent, yamlServerContent, jsonServerContent} {
// Break the content with an innocent typo
brokenContent := strings.Replace(content, "bindAddr", "bindAdur", 1)
svrCfg := v1.ServerConfig{}
err := LoadConfigure([]byte(brokenContent), &svrCfg, strict == 1)
require := require.New(t)
if strict == 1 {
require.ErrorContains(err, "bindAdur")
} else {
require.NoError(err)
// BindAddr didn't get parsed because of the typo.
require.EqualValues("", svrCfg.BindAddr)
}
}
}
}