From 86b2e686a5db67873d4d618925b496c081e0b735 Mon Sep 17 00:00:00 2001 From: fatedier Date: Tue, 3 Aug 2021 22:58:03 +0800 Subject: [PATCH] vhost: use new readClientHello function (#2504) --- pkg/util/vhost/https.go | 189 +++++----------------- pkg/util/vhost/https_test.go | 38 +++++ test/e2e/mock/server/httpserver/server.go | 8 +- 3 files changed, 80 insertions(+), 155 deletions(-) create mode 100644 pkg/util/vhost/https_test.go diff --git a/pkg/util/vhost/https.go b/pkg/util/vhost/https.go index a2b980b..dd20739 100644 --- a/pkg/util/vhost/https.go +++ b/pkg/util/vhost/https.go @@ -15,32 +15,12 @@ package vhost import ( - "fmt" + "crypto/tls" "io" "net" - "strings" "time" gnet "github.com/fatedier/golib/net" - "github.com/fatedier/golib/pool" -) - -const ( - typeClientHello uint8 = 1 // Type client hello -) - -// TLS extension numbers -const ( - extensionServerName uint16 = 0 - extensionStatusRequest uint16 = 5 - extensionSupportedCurves uint16 = 10 - extensionSupportedPoints uint16 = 11 - extensionSignatureAlgorithms uint16 = 13 - extensionALPN uint16 = 16 - extensionSCT uint16 = 18 - extensionSessionTicket uint16 = 35 - extensionNextProtoNeg uint16 = 13172 // not IANA assigned - extensionRenegotiationInfo uint16 = 0xff01 ) type HTTPSMuxer struct { @@ -52,142 +32,49 @@ func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, e return &HTTPSMuxer{mux}, err } -func readHandshake(rd io.Reader) (host string, err error) { - data := pool.GetBuf(1024) - origin := data - defer pool.PutBuf(origin) - - _, err = io.ReadFull(rd, data[:47]) - if err != nil { - return - } - - length, err := rd.Read(data[47:]) - if err != nil { - return - } - length += 47 - data = data[:length] - if uint8(data[5]) != typeClientHello { - err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5])) - return - } - - // session - sessionIDLen := int(data[43]) - if sessionIDLen > 32 || len(data) < 44+sessionIDLen { - err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIDLen) - return - } - data = data[44+sessionIDLen:] - if len(data) < 2 { - err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data)) - return - } - - // cipher suite numbers - cipherSuiteLen := int(data[0])<<8 | int(data[1]) - if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { - err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data)) - return - } - data = data[2+cipherSuiteLen:] - if len(data) < 1 { - err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen) - return - } - - // compression method - compressionMethodsLen := int(data[0]) - if len(data) < 1+compressionMethodsLen { - err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen) - return - } - - data = data[1+compressionMethodsLen:] - if len(data) == 0 { - // ClientHello is optionally followed by extension data - err = fmt.Errorf("readHandshake: there is no extension data to get servername") - return - } - if len(data) < 2 { - err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short", len(data)) - return - } - - extensionsLength := int(data[0])<<8 | int(data[1]) - data = data[2:] - if extensionsLength != len(data) { - err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data)) - return - } - for len(data) != 0 { - if len(data) < 4 { - err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data)) - return - } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { - err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length) - return - } - - switch extension { - case extensionRenegotiationInfo: - if length != 1 || data[0] != 0 { - err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length) - return - } - case extensionNextProtoNeg: - case extensionStatusRequest: - case extensionServerName: - d := data[:length] - if len(d) < 2 { - err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d)) - return - } - namesLen := int(d[0])<<8 | int(d[1]) - d = d[2:] - if len(d) != namesLen { - err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d)) - return - } - for len(d) > 0 { - if len(d) < 3 { - err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d)) - return - } - nameType := d[0] - nameLen := int(d[1])<<8 | int(d[2]) - d = d[3:] - if len(d) < nameLen { - err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d)) - return - } - if nameType == 0 { - serverName := string(d[:nameLen]) - host = strings.TrimSpace(serverName) - return host, nil - } - d = d[nameLen:] - } - } - data = data[length:] - } - err = fmt.Errorf("Unknown error") - return -} - func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) { reqInfoMap := make(map[string]string, 0) sc, rd := gnet.NewSharedConn(c) - host, err := readHandshake(rd) + + clientHello, err := readClientHello(rd) if err != nil { return nil, reqInfoMap, err } - reqInfoMap["Host"] = host + + reqInfoMap["Host"] = clientHello.ServerName reqInfoMap["Scheme"] = "https" return sc, reqInfoMap, nil } + +func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) { + var hello *tls.ClientHelloInfo + + // Note that Handshake always fails because the readOnlyConn is not a real connection. + // As long as the Client Hello is successfully read, the failure should only happen after GetConfigForClient is called, + // so we only care about the error if hello was never set. + err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + hello = &tls.ClientHelloInfo{} + *hello = *argHello + return nil, nil + }, + }).Handshake() + + if hello == nil { + return nil, err + } + return hello, nil +} + +type readOnlyConn struct { + reader io.Reader +} + +func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) } +func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } +func (conn readOnlyConn) Close() error { return nil } +func (conn readOnlyConn) LocalAddr() net.Addr { return nil } +func (conn readOnlyConn) RemoteAddr() net.Addr { return nil } +func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/pkg/util/vhost/https_test.go b/pkg/util/vhost/https_test.go new file mode 100644 index 0000000..47fb9da --- /dev/null +++ b/pkg/util/vhost/https_test.go @@ -0,0 +1,38 @@ +package vhost + +import ( + "crypto/tls" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestGetHTTPSHostname(t *testing.T) { + require := require.New(t) + + l, err := net.Listen("tcp", ":") + require.NoError(err) + defer l.Close() + + var conn net.Conn + go func() { + conn, _ = l.Accept() + require.NotNil(conn) + }() + + go func() { + time.Sleep(100 * time.Millisecond) + tls.Dial("tcp", l.Addr().String(), &tls.Config{ + InsecureSkipVerify: true, + ServerName: "example.com", + }) + }() + + time.Sleep(200 * time.Millisecond) + _, infos, err := GetHTTPSHostname(conn) + require.NoError(err) + require.Equal("example.com", infos["Host"]) + require.Equal("https", infos["Scheme"]) +} diff --git a/test/e2e/mock/server/httpserver/server.go b/test/e2e/mock/server/httpserver/server.go index 90b3d39..f35c119 100644 --- a/test/e2e/mock/server/httpserver/server.go +++ b/test/e2e/mock/server/httpserver/server.go @@ -11,7 +11,7 @@ import ( type Server struct { bindAddr string bindPort int - hanlder http.Handler + handler http.Handler l net.Listener tlsConfig *tls.Config @@ -54,14 +54,14 @@ func WithTlsConfig(tlsConfig *tls.Config) Option { func WithHandler(h http.Handler) Option { return func(s *Server) *Server { - s.hanlder = h + s.handler = h return s } } func WithResponse(resp []byte) Option { return func(s *Server) *Server { - s.hanlder = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write(resp) }) return s @@ -76,7 +76,7 @@ func (s *Server) Run() error { addr := net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort)) hs := &http.Server{ Addr: addr, - Handler: s.hanlder, + Handler: s.handler, TLSConfig: s.tlsConfig, }