diff --git a/models/config/proxy.go b/models/config/proxy.go index 4718a8c..84b9a8e 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -23,6 +23,7 @@ import ( "github.com/fatedier/frp/models/consts" "github.com/fatedier/frp/models/msg" + "github.com/fatedier/frp/utils/util" ini "github.com/vaughan0/go-ini" ) @@ -173,7 +174,8 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) { func (cfg *BindInfoConf) check() (err error) { if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 { - if _, ok := ServerCommonCfg.PrivilegeAllowPorts[cfg.RemotePort]; !ok { + // TODO: once linstenPort used, should remove the port from privilege ports + if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok { return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort) } } diff --git a/models/config/server_common.go b/models/config/server_common.go index 070de8a..24fe0f5 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -19,6 +19,8 @@ import ( "strconv" "strings" + util "github.com/fatedier/frp/utils/util" + "github.com/prometheus/common/log" ini "github.com/vaughan0/go-ini" ) @@ -52,7 +54,7 @@ type ServerCommonConf struct { TcpMux bool // if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected - PrivilegeAllowPorts map[int64]struct{} + PrivilegeAllowPorts [][2]int64 MaxPoolCount int64 HeartBeatTimeout int64 UserConnTimeout int64 @@ -198,47 +200,12 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { return } - cfg.PrivilegeAllowPorts = make(map[int64]struct{}) - tmpStr, ok = conf.Get("common", "privilege_allow_ports") + allowPortsStr, ok := conf.Get("common", "privilege_allow_ports") + // TODO: check if conflicts exist in port ranges if ok { - // e.g. 1000-2000,2001,2002,3000-4000 - portRanges := strings.Split(tmpStr, ",") - for _, portRangeStr := range portRanges { - // 1000-2000 or 2001 - portArray := strings.Split(portRangeStr, "-") - // length: only 1 or 2 is correct - rangeType := len(portArray) - if rangeType == 1 { - // single port - singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64) - if errRet != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) - return - } - ServerCommonCfg.PrivilegeAllowPorts[singlePort] = struct{}{} - } else if rangeType == 2 { - // range ports - min, errRet := strconv.ParseInt(portArray[0], 10, 64) - if errRet != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) - return - } - max, errRet := strconv.ParseInt(portArray[1], 10, 64) - if errRet != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) - return - } - if max < min { - err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") - return - } - for i := min; i <= max; i++ { - cfg.PrivilegeAllowPorts[i] = struct{}{} - } - } else { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") - return - } + cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr) + if err != nil { + return } } } diff --git a/utils/util/util.go b/utils/util/util.go index 3e8bca5..5c44e51 100644 --- a/utils/util/util.go +++ b/utils/util/util.go @@ -19,6 +19,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "strconv" + "strings" ) // RandId return a rand string used in frp. @@ -45,3 +47,66 @@ func GetAuthKey(token string, timestamp int64) (key string) { data := md5Ctx.Sum(nil) return hex.EncodeToString(data) } + +// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges. +func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) { + // for example: 1000-2000,2001,2002,3000-4000 + rangeArray := strings.Split(rangeStr, ",") + for _, portRangeStr := range rangeArray { + // 1000-2000 or 2001 + portArray := strings.Split(portRangeStr, "-") + // length: only 1 or 2 is correct + rangeType := len(portArray) + if rangeType == 1 { + singlePort, err := strconv.ParseInt(portArray[0], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + portRanges = append(portRanges, [2]int64{singlePort, singlePort}) + } else if rangeType == 2 { + min, err := strconv.ParseInt(portArray[0], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + max, err := strconv.ParseInt(portArray[1], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + if max < min { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") + } + portRanges = append(portRanges, [2]int64{min, max}) + } else { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") + } + } + return portRanges, nil +} + +func ContainsPort(portRanges [][2]int64, port int64) bool { + for _, pr := range portRanges { + if port >= pr[0] && port <= pr[1] { + return true + } + } + return false +} + +func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 { + var tmpRanges [][2]int64 + for _, pr := range portRanges { + if port >= pr[0] && port <= pr[1] { + leftRange := [2]int64{pr[0], port - 1} + rightRange := [2]int64{port + 1, pr[1]} + if leftRange[0] <= leftRange[1] { + tmpRanges = append(tmpRanges, leftRange) + } + if rightRange[0] <= rightRange[1] { + tmpRanges = append(tmpRanges, rightRange) + } + } else { + tmpRanges = append(tmpRanges, pr) + } + } + return tmpRanges +}