vhost: set DisableKeepAlives = false and fix websocket not work

This commit is contained in:
fatedier 2021-01-18 21:49:44 +08:00
parent c842558ace
commit 46f809d711
3 changed files with 48 additions and 18 deletions

View File

@ -74,4 +74,4 @@ func hasPort(host string) bool {
return true return true
} }
return host[0] == '[' && strings.Contains(host, "]:") return host[0] == '[' && strings.Contains(host, "]:")
} }

View File

@ -17,6 +17,7 @@ package vhost
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -59,20 +60,25 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
req.URL.Scheme = "http" req.URL.Scheme = "http"
url := req.Context().Value(RouteInfoURL).(string) url := req.Context().Value(RouteInfoURL).(string)
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string)) oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
host := rp.GetRealHost(oldHost, url) rc := rp.GetRouteConfig(oldHost, url)
if host != "" { if rc != nil {
req.Host = host if rc.RewriteHost != "" {
} req.Host = rc.RewriteHost
req.URL.Host = req.Host }
// Set {domain}.{location} as URL host here to let http transport reuse connections.
req.URL.Host = rc.Domain + "." + base64.StdEncoding.EncodeToString([]byte(rc.Location))
headers := rp.GetHeaders(oldHost, url) for k, v := range rc.Headers {
for k, v := range headers { req.Header.Set(k, v)
req.Header.Set(k, v) }
} else {
req.URL.Host = req.Host
} }
}, },
Transport: &http.Transport{ Transport: &http.Transport{
ResponseHeaderTimeout: rp.responseHeaderTimeout, ResponseHeaderTimeout: rp.responseHeaderTimeout,
DisableKeepAlives: true, IdleConnTimeout: 60 * time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
url := ctx.Value(RouteInfoURL).(string) url := ctx.Value(RouteInfoURL).(string)
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string)) host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
@ -107,6 +113,14 @@ func (rp *HTTPReverseProxy) UnRegister(domain string, location string) {
rp.vhostRouter.Del(domain, location) rp.vhostRouter.Del(domain, location)
} }
func (rp *HTTPReverseProxy) GetRouteConfig(domain string, location string) *RouteConfig {
vr, ok := rp.getVhost(domain, location)
if ok {
return vr.payload.(*RouteConfig)
}
return nil
}
func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) { func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) {
vr, ok := rp.getVhost(domain, location) vr, ok := rp.getVhost(domain, location)
if ok { if ok {

View File

@ -139,6 +139,7 @@ func TestHealthCheck(t *testing.T) {
} }
httpSvc3 := mock.NewHTTPServer(15005, func(w http.ResponseWriter, r *http.Request) { httpSvc3 := mock.NewHTTPServer(15005, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second)
w.Write([]byte("http3")) w.Write([]byte("http3"))
}) })
err = httpSvc3.Start() err = httpSvc3.Start()
@ -147,6 +148,7 @@ func TestHealthCheck(t *testing.T) {
} }
httpSvc4 := mock.NewHTTPServer(15006, func(w http.ResponseWriter, r *http.Request) { httpSvc4 := mock.NewHTTPServer(15006, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second)
w.Write([]byte("http4")) w.Write([]byte("http4"))
}) })
err = httpSvc4.Start() err = httpSvc4.Start()
@ -277,16 +279,30 @@ func TestHealthCheck(t *testing.T) {
// ****** load balancing type http ****** // ****** load balancing type http ******
result = make([]string, 0) result = make([]string, 0)
var wait sync.WaitGroup
var mu sync.Mutex
wait.Add(2)
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "") go func() {
assert.NoError(err) defer wait.Done()
assert.Equal(200, code) code, body, _, err := util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
result = append(result, body) assert.NoError(err)
assert.Equal(200, code)
mu.Lock()
result = append(result, body)
mu.Unlock()
}()
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "") go func() {
assert.NoError(err) defer wait.Done()
assert.Equal(200, code) code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
result = append(result, body) assert.NoError(err)
assert.Equal(200, code)
mu.Lock()
result = append(result, body)
mu.Unlock()
}()
wait.Wait()
assert.Contains(result, "http3") assert.Contains(result, "http3")
assert.Contains(result, "http4") assert.Contains(result, "http4")