From cf9193a42988356168f4dd66bbff7939e5fe5699 Mon Sep 17 00:00:00 2001 From: fatedier Date: Tue, 23 Jan 2018 01:29:52 +0800 Subject: [PATCH] newhttp: support websocket --- utils/vhost/newhttp.go | 5 ++++ utils/vhost/reverseproxy.go | 59 +++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/utils/vhost/newhttp.go b/utils/vhost/newhttp.go index 6e20a78..12391ca 100644 --- a/utils/vhost/newhttp.go +++ b/utils/vhost/newhttp.go @@ -79,6 +79,11 @@ func NewHttpReverseProxy() *HttpReverseProxy { return rp.CreateConnection(host, url) }, }, + WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + url := ctx.Value("url").(string) + host := getHostFromAddr(ctx.Value("host").(string)) + return rp.CreateConnection(host, url) + }, BufferPool: newWrapPool(), ErrorLog: log.New(newWrapLogger(), "", 0), } diff --git a/utils/vhost/reverseproxy.go b/utils/vhost/reverseproxy.go index 610f999..365e0f2 100644 --- a/utils/vhost/reverseproxy.go +++ b/utils/vhost/reverseproxy.go @@ -16,6 +16,8 @@ import ( "strings" "sync" "time" + + frpIo "github.com/fatedier/frp/utils/io" ) // onExitFlushLoop is a callback set by tests to detect the state of the @@ -59,6 +61,8 @@ type ReverseProxy struct { // modifies the Response from the backend. // If it returns an error, the proxy returns a StatusBadGateway error. ModifyResponse func(*http.Response) error + + WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error) } // A BufferPool is an interface for getting and returning temporary @@ -139,6 +143,48 @@ var hopHeaders = []string{ } func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if IsWebsocketRequest(req) { + p.serveWebSocket(rw, req) + } else { + p.serveHTTP(rw, req) + } +} + +func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) { + if p.WebSocketDialContext == nil { + rw.WriteHeader(500) + return + } + + req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path)) + req = req.WithContext(context.WithValue(req.Context(), "host", req.Host)) + + targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "") + if err != nil { + rw.WriteHeader(501) + return + } + defer targetConn.Close() + + p.Director(req) + + hijacker, ok := rw.(http.Hijacker) + if !ok { + rw.WriteHeader(500) + return + } + conn, _, errHijack := hijacker.Hijack() + if errHijack != nil { + rw.WriteHeader(500) + return + } + defer conn.Close() + + req.Write(targetConn) + frpIo.Join(conn, targetConn) +} + +func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { transport = http.DefaultTransport @@ -368,3 +414,16 @@ func (m *maxLatencyWriter) flushLoop() { } func (m *maxLatencyWriter) stop() { m.done <- true } + +func IsWebsocketRequest(req *http.Request) bool { + containsHeader := func(name, value string) bool { + items := strings.Split(req.Header.Get(name), ",") + for _, item := range items { + if value == strings.ToLower(strings.TrimSpace(item)) { + return true + } + } + return false + } + return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket") +}