diff --git a/client/visitor.go b/client/visitor.go index 76f28d9..e7a22d8 100644 --- a/client/visitor.go +++ b/client/visitor.go @@ -259,7 +259,11 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { sv.Trace("send all detect msg done") // Listen for visitorConn's address and wait for client connection. - lConn, _ := net.ListenUDP("udp", laddr) + lConn, err := net.ListenUDP("udp", laddr) + if err != nil { + sv.Error("listen on visitorConn's local adress error: %v", err) + return + } lConn.SetReadDeadline(time.Now().Add(5 * time.Second)) sidBuf := pool.GetBuf(1024) n, _, err = lConn.ReadFromUDP(sidBuf) diff --git a/models/config/proxy.go b/models/config/proxy.go index 75eb138..21e8664 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -635,7 +635,7 @@ func (cfg *StcpProxyConf) LoadFromFile(name string, section ini.Section) (err er if tmpStr == "server" || tmpStr == "visitor" { cfg.Role = tmpStr } else { - cfg.Role = "server" + return fmt.Errorf("Parse conf error: incorrect role [%s]", tmpStr) } cfg.Sk = section["sk"] @@ -724,7 +724,7 @@ func (cfg *XtcpProxyConf) LoadFromFile(name string, section ini.Section) (err er if tmpStr == "server" || tmpStr == "visitor" { cfg.Role = tmpStr } else { - cfg.Role = "server" + return fmt.Errorf("Parse conf error: incorrect role [%s]", tmpStr) } cfg.Sk = section["sk"] diff --git a/models/msg/msg.go b/models/msg/msg.go index 22bf5f6..aac0ce7 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -181,5 +181,5 @@ type NatHoleResp struct { } type NatHoleSid struct { - Sid string `json;"sid"` + Sid string `json:"sid"` } diff --git a/models/plugin/http_proxy.go b/models/plugin/http_proxy.go index aaee5a1..f5fed6c 100644 --- a/models/plugin/http_proxy.go +++ b/models/plugin/http_proxy.go @@ -111,7 +111,7 @@ func (hp *HttpProxy) Handle(conn io.ReadWriteCloser) { if realConn, ok := conn.(frpNet.Conn); ok { wrapConn = realConn } else { - wrapConn = frpNet.WrapReadWriteCloserToConn(conn) + wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn) } sc, rd := frpNet.NewShareConn(wrapConn) diff --git a/models/plugin/socks5.go b/models/plugin/socks5.go index d3b82e1..b0f1bb2 100644 --- a/models/plugin/socks5.go +++ b/models/plugin/socks5.go @@ -50,7 +50,7 @@ func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser) { if realConn, ok := conn.(frpNet.Conn); ok { wrapConn = realConn } else { - wrapConn = frpNet.WrapReadWriteCloserToConn(conn) + wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn) } sp.Server.ServeConn(wrapConn) diff --git a/server/manager.go b/server/manager.go index c023d18..ebc0928 100644 --- a/server/manager.go +++ b/server/manager.go @@ -146,7 +146,7 @@ func (vm *VisitorManager) NewConn(name string, conn frpNet.Conn, timestamp int64 if useCompression { rwc = frpIo.WithCompression(rwc) } - err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc)) + err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc, conn)) } else { err = fmt.Errorf("custom listener for [%s] doesn't exist", name) return diff --git a/server/proxy.go b/server/proxy.go index d76358c..8ce1e2b 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -189,13 +189,16 @@ func (pxy *TcpProxy) Close() { type HttpProxy struct { BaseProxy cfg *config.HttpProxyConf + + closeFuncs []func() } func (pxy *HttpProxy) Run() (err error) { - routeConfig := &vhost.VhostRouteConfig{ - RewriteHost: pxy.cfg.HostHeaderRewrite, - Username: pxy.cfg.HttpUser, - Password: pxy.cfg.HttpPwd, + routeConfig := vhost.VhostRouteConfig{ + RewriteHost: pxy.cfg.HostHeaderRewrite, + Username: pxy.cfg.HttpUser, + Password: pxy.cfg.HttpPwd, + CreateConnFn: pxy.GetRealConn, } locations := pxy.cfg.Locations @@ -206,13 +209,16 @@ func (pxy *HttpProxy) Run() (err error) { routeConfig.Domain = domain for _, location := range locations { routeConfig.Location = location - l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig) + err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) if err != nil { return err } - l.AddLogPrefix(pxy.name) + tmpDomain := routeConfig.Domain + tmpLocation := routeConfig.Location + pxy.closeFuncs = append(pxy.closeFuncs, func() { + pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation) + }) pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) - pxy.listeners = append(pxy.listeners, l) } } @@ -220,17 +226,18 @@ func (pxy *HttpProxy) Run() (err error) { routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost for _, location := range locations { routeConfig.Location = location - l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig) + err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) if err != nil { return err } - l.AddLogPrefix(pxy.name) + tmpDomain := routeConfig.Domain + tmpLocation := routeConfig.Location + pxy.closeFuncs = append(pxy.closeFuncs, func() { + pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation) + }) pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) - pxy.listeners = append(pxy.listeners, l) } } - - pxy.startListenHandler(pxy, HandleUserTcpConnection) return } @@ -238,8 +245,33 @@ func (pxy *HttpProxy) GetConf() config.ProxyConf { return pxy.cfg } +func (pxy *HttpProxy) GetRealConn() (workConn frpNet.Conn, err error) { + tmpConn, errRet := pxy.GetWorkConnFromPool() + if errRet != nil { + err = errRet + return + } + + var rwc io.ReadWriteCloser = tmpConn + if pxy.cfg.UseEncryption { + rwc, err = frpIo.WithEncryption(rwc, []byte(config.ServerCommonCfg.PrivilegeToken)) + if err != nil { + pxy.Error("create encryption stream error: %v", err) + return + } + } + if pxy.cfg.UseCompression { + rwc = frpIo.WithCompression(rwc) + } + workConn = frpNet.WrapReadWriteCloserToConn(rwc, tmpConn) + return +} + func (pxy *HttpProxy) Close() { pxy.BaseProxy.Close() + for _, closeFn := range pxy.closeFuncs { + closeFn() + } } type HttpsProxy struct { diff --git a/server/service.go b/server/service.go index 7799d3e..5997f3d 100644 --- a/server/service.go +++ b/server/service.go @@ -16,6 +16,8 @@ package server import ( "fmt" + "net" + "net/http" "time" "github.com/fatedier/frp/assets" @@ -44,12 +46,11 @@ type Service struct { // Accept connections using kcp. kcpListener frpNet.Listener - // For http proxies, route requests to different clients by hostname and other infomation. - VhostHttpMuxer *vhost.HttpMuxer - // For https proxies, route requests to different clients by hostname and other infomation. VhostHttpsMuxer *vhost.HttpsMuxer + httpReverseProxy *vhost.HttpReverseProxy + // Manage all controllers. ctlManager *ControlManager @@ -93,22 +94,26 @@ func NewService() (svr *Service, err error) { err = fmt.Errorf("Listen on kcp address udp [%s:%d] error: %v", cfg.BindAddr, cfg.KcpBindPort, err) return } - log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.BindPort) + log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KcpBindPort) } // Create http vhost muxer. if cfg.VhostHttpPort > 0 { - var l frpNet.Listener - l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpPort) + rp := vhost.NewHttpReverseProxy() + svr.httpReverseProxy = rp + + address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) + server := &http.Server{ + Addr: address, + Handler: rp, + } + var l net.Listener + l, err = net.Listen("tcp", address) if err != nil { err = fmt.Errorf("Create vhost http listener error, %v", err) return } - svr.VhostHttpMuxer, err = vhost.NewHttpMuxer(l, 30*time.Second) - if err != nil { - err = fmt.Errorf("Create vhost httpMuxer error, %v", err) - return - } + go server.Serve(l) log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) } diff --git a/utils/net/conn.go b/utils/net/conn.go index 392fb98..c1f6f46 100644 --- a/utils/net/conn.go +++ b/utils/net/conn.go @@ -49,32 +49,50 @@ func WrapConn(c net.Conn) Conn { type WrapReadWriteCloserConn struct { io.ReadWriteCloser log.Logger + + underConn net.Conn } -func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser) Conn { +func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser, underConn net.Conn) Conn { return &WrapReadWriteCloserConn{ ReadWriteCloser: rwc, Logger: log.NewPrefixLogger(""), + underConn: underConn, } } func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr { + if conn.underConn != nil { + return conn.underConn.LocalAddr() + } return (*net.TCPAddr)(nil) } func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr { + if conn.underConn != nil { + return conn.underConn.RemoteAddr() + } return (*net.TCPAddr)(nil) } func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error { + if conn.underConn != nil { + return conn.underConn.SetDeadline(t) + } return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error { + if conn.underConn != nil { + return conn.underConn.SetReadDeadline(t) + } return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error { + if conn.underConn != nil { + return conn.underConn.SetWriteDeadline(t) + } return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } diff --git a/utils/version/version.go b/utils/version/version.go index 2edcf07..95e5fc4 100644 --- a/utils/version/version.go +++ b/utils/version/version.go @@ -19,7 +19,7 @@ import ( "strings" ) -var version string = "0.14.0" +var version string = "0.14.1" func Full() string { return version diff --git a/utils/vhost/newhttp.go b/utils/vhost/newhttp.go new file mode 100644 index 0000000..6e20a78 --- /dev/null +++ b/utils/vhost/newhttp.go @@ -0,0 +1,186 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + "bytes" + "context" + "errors" + "log" + "net" + "net/http" + "strings" + "sync" + "time" + + frpLog "github.com/fatedier/frp/utils/log" + "github.com/fatedier/frp/utils/pool" +) + +var ( + responseHeaderTimeout = time.Duration(30) * time.Second + + ErrRouterConfigConflict = errors.New("router config conflict") + ErrNoDomain = errors.New("no such domain") +) + +func getHostFromAddr(addr string) (host string) { + strs := strings.Split(addr, ":") + if len(strs) > 1 { + host = strs[0] + } else { + host = addr + } + return +} + +type HttpReverseProxy struct { + proxy *ReverseProxy + tr *http.Transport + + vhostRouter *VhostRouters + + cfgMu sync.RWMutex +} + +func NewHttpReverseProxy() *HttpReverseProxy { + rp := &HttpReverseProxy{ + vhostRouter: NewVhostRouters(), + } + proxy := &ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = "http" + url := req.Context().Value("url").(string) + host := getHostFromAddr(req.Context().Value("host").(string)) + host = rp.GetRealHost(host, url) + if host != "" { + req.Host = host + } + req.URL.Host = req.Host + }, + Transport: &http.Transport{ + ResponseHeaderTimeout: responseHeaderTimeout, + DisableKeepAlives: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + url := ctx.Value("url").(string) + host := getHostFromAddr(ctx.Value("host").(string)) + return rp.CreateConnection(host, url) + }, + }, + BufferPool: newWrapPool(), + ErrorLog: log.New(newWrapLogger(), "", 0), + } + rp.proxy = proxy + return rp +} + +func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error { + rp.cfgMu.Lock() + defer rp.cfgMu.Unlock() + _, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location) + if ok { + return ErrRouterConfigConflict + } else { + rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg) + } + return nil +} + +func (rp *HttpReverseProxy) UnRegister(domain string, location string) { + rp.cfgMu.Lock() + defer rp.cfgMu.Unlock() + rp.vhostRouter.Del(domain, location) +} + +func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host string) { + vr, ok := rp.getVhost(domain, location) + if ok { + host = vr.payload.(*VhostRouteConfig).RewriteHost + } + return +} + +func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) { + vr, ok := rp.getVhost(domain, location) + if ok { + fn := vr.payload.(*VhostRouteConfig).CreateConnFn + if fn != nil { + return fn() + } + } + return nil, ErrNoDomain +} + +func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool { + vr, ok := rp.getVhost(domain, location) + if ok { + checkUser := vr.payload.(*VhostRouteConfig).Username + checkPasswd := vr.payload.(*VhostRouteConfig).Password + if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) { + return false + } + } + return true +} + +func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) { + rp.cfgMu.RLock() + defer rp.cfgMu.RUnlock() + + // first we check the full hostname + // if not exist, then check the wildcard_domain such as *.example.com + vr, ok = rp.vhostRouter.Get(domain, location) + if ok { + return + } + + domainSplit := strings.Split(domain, ".") + if len(domainSplit) < 3 { + return vr, false + } + domainSplit[0] = "*" + domain = strings.Join(domainSplit, ".") + vr, ok = rp.vhostRouter.Get(domain, location) + return +} + +func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + domain := getHostFromAddr(req.Host) + location := req.URL.Path + user, passwd, _ := req.BasicAuth() + if !rp.CheckAuth(domain, location, user, passwd) { + rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + rp.proxy.ServeHTTP(rw, req) +} + +type wrapPool struct{} + +func newWrapPool() *wrapPool { return &wrapPool{} } + +func (p *wrapPool) Get() []byte { return pool.GetBuf(32 * 1024) } + +func (p *wrapPool) Put(buf []byte) { pool.PutBuf(buf) } + +type wrapLogger struct{} + +func newWrapLogger() *wrapLogger { return &wrapLogger{} } + +func (l *wrapLogger) Write(p []byte) (n int, err error) { + frpLog.Warn("%s", string(bytes.TrimRight(p, "\n"))) + return len(p), nil +} diff --git a/utils/vhost/reverseproxy.go b/utils/vhost/reverseproxy.go new file mode 100644 index 0000000..610f999 --- /dev/null +++ b/utils/vhost/reverseproxy.go @@ -0,0 +1,370 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package vhost + +import ( + "context" + "io" + "log" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Director must not access the provided Request + // after returning. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.Logger + + // BufferPool optionally specifies a buffer pool to + // get byte slices for use by io.CopyBuffer when + // copying HTTP response bodies. + BufferPool BufferPool + + // ModifyResponse is an optional function that + // modifies the Response from the backend. + // If it returns an error, the proxy returns a StatusBadGateway error. + ModifyResponse func(*http.Response) error +} + +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by io.CopyBuffer. +type BufferPool interface { + Get() []byte + Put([]byte) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// NewSingleHostReverseProxy does not rewrite the Host header. +// To rewrite Host headers, use ReverseProxy directly with a custom +// Director policy. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } + } + return &ReverseProxy{Director: director} +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx := req.Context() + if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + + outreq.Header = cloneHeader(req.Header) + + // Modify for frp + outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path)) + outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host)) + + p.Director(outreq) + outreq.Close = false + + // Remove hop-by-hop headers listed in the "Connection" header. + // See RFC 2616, section 14.10. + if c := outreq.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + outreq.Header.Del(f) + } + } + } + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + outreq.Header.Del(h) + } + } + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusNotFound) + rw.Write([]byte(NotFound)) + return + } + + // Remove hop-by-hop headers listed in the + // "Connection" header of the response. + if c := res.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + res.Header.Del(f) + } + } + } + + for _, h := range hopHeaders { + res.Header.Del(h) + } + + if p.ModifyResponse != nil { + if err := p.ModifyResponse(res); err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) + return + } + } + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + if fl, ok := rw.(http.Flusher); ok { + fl.Flush() + } + } + p.copyResponse(rw, res.Body) + res.Body.Close() // close now, instead of defer, to populate res.Trailer + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), + } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw + } + } + + var buf []byte + if p.BufferPool != nil { + buf = p.BufferPool.Get() + } + p.copyBuffer(dst, src, buf) + if p.BufferPool != nil { + p.BufferPool.Put(buf) + } +} + +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + return written, rerr + } + } +} + +func (p *ReverseProxy) logf(format string, args ...interface{}) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration + + mu sync.Mutex // protects Write + Flush + done chan bool +} + +func (m *maxLatencyWriter) Write(p []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.dst.Write(p) +} + +func (m *maxLatencyWriter) flushLoop() { + t := time.NewTicker(m.latency) + defer t.Stop() + for { + select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return + case <-t.C: + m.mu.Lock() + m.dst.Flush() + m.mu.Unlock() + } + } +} + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/utils/vhost/router.go b/utils/vhost/router.go index 975119e..37a34fb 100644 --- a/utils/vhost/router.go +++ b/utils/vhost/router.go @@ -14,7 +14,8 @@ type VhostRouters struct { type VhostRouter struct { domain string location string - listener *Listener + + payload interface{} } func NewVhostRouters() *VhostRouters { @@ -23,7 +24,7 @@ func NewVhostRouters() *VhostRouters { } } -func (r *VhostRouters) Add(domain, location string, l *Listener) { +func (r *VhostRouters) Add(domain, location string, payload interface{}) { r.mutex.Lock() defer r.mutex.Unlock() @@ -35,7 +36,7 @@ func (r *VhostRouters) Add(domain, location string, l *Listener) { vr := &VhostRouter{ domain: domain, location: location, - listener: l, + payload: payload, } vrs = append(vrs, vr) diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index 7e11d72..4920d32 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -50,12 +50,16 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut return mux, nil } +type CreateConnFunc func() (frpNet.Conn, error) + type VhostRouteConfig struct { Domain string Location string RewriteHost string Username string Password string + + CreateConnFn CreateConnFunc } // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil @@ -91,7 +95,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) { // if not exist, then check the wildcard_domain such as *.example.com vr, found := v.registryRouter.Get(name, path) if found { - return vr.listener, true + return vr.payload.(*Listener), true } domainSplit := strings.Split(name, ".") @@ -106,7 +110,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) { return } - return vr.listener, true + return vr.payload.(*Listener), true } func (v *VhostMuxer) run() {