diff --git a/src/cmd/frps/main.go b/src/cmd/frps/main.go index 8e94d8d..3459ce0 100644 --- a/src/cmd/frps/main.go +++ b/src/cmd/frps/main.go @@ -15,6 +15,7 @@ package main import ( + "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -38,7 +39,7 @@ var usage string = `frps is the server of frp Usage: frps [-c config_file] [-L log_file] [--log-level=] [--addr=] - frps --reload + frps [-c config_file] --reload frps -h | --help frps -v | --version @@ -68,7 +69,18 @@ func main() { // reload check if args["--reload"] != nil { if args["--reload"].(bool) { - resp, err := http.Get("http://" + server.BindAddr + ":" + fmt.Sprintf("%d", server.DashboardPort) + "/api/reload") + req, err := http.NewRequest("GET", "http://"+server.BindAddr+":"+fmt.Sprintf("%d", server.DashboardPort)+"/api/reload", nil) + if err != nil { + fmt.Printf("frps reload error: %v\n", err) + os.Exit(1) + } + + authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(server.DashboardUsername+":"+server.DashboardPassword)) + + req.Header.Add("Authorization", authStr) + defaultClient := &http.Client{} + resp, err := defaultClient.Do(req) + if err != nil { fmt.Printf("frps reload error: %v\n", err) os.Exit(1) diff --git a/src/models/server/dashboard.go b/src/models/server/dashboard.go index 4be7a6d..6960108 100644 --- a/src/models/server/dashboard.go +++ b/src/models/server/dashboard.go @@ -15,9 +15,11 @@ package server import ( + "encoding/base64" "fmt" "net" "net/http" + "strings" "time" "github.com/fatedier/frp/src/assets" @@ -32,13 +34,14 @@ func RunDashboardServer(addr string, port int64) (err error) { // url router mux := http.NewServeMux() // api, see dashboard_api.go - mux.HandleFunc("/api/reload", apiReload) + // mux.HandleFunc("/api/reload", apiReload) + mux.HandleFunc("/api/reload", use(apiReload, basicAuth)) mux.HandleFunc("/api/proxies", apiProxies) // view, see dashboard_view.go mux.Handle("/favicon.ico", http.FileServer(assets.FileSystem)) mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(assets.FileSystem))) - mux.HandleFunc("/", viewDashboard) + mux.HandleFunc("/", use(viewDashboard, basicAuth)) address := fmt.Sprintf("%s:%d", addr, port) server := &http.Server{ @@ -58,3 +61,43 @@ func RunDashboardServer(addr string, port int64) (err error) { go server.Serve(ln) return } + +func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc { + for _, m := range middleware { + h = m(h) + } + + return h +} + +func basicAuth(h http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 { + http.Error(w, "Not authorized", 401) + return + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + http.Error(w, err.Error(), 401) + return + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + http.Error(w, "Not authorized", 401) + return + } + + if pair[0] != DashboardUsername || pair[1] != DashboardPassword { + http.Error(w, "Not authorized", 401) + return + } + + h.ServeHTTP(w, r) + } +}