mirror of
https://gitee.com/IrisVega/frp.git
synced 2024-11-01 22:31:29 +08:00
vhost: use new readClientHello function (#2504)
This commit is contained in:
parent
09f39de74e
commit
86b2e686a5
@ -15,32 +15,12 @@
|
|||||||
package vhost
|
package vhost
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
gnet "github.com/fatedier/golib/net"
|
gnet "github.com/fatedier/golib/net"
|
||||||
"github.com/fatedier/golib/pool"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
typeClientHello uint8 = 1 // Type client hello
|
|
||||||
)
|
|
||||||
|
|
||||||
// TLS extension numbers
|
|
||||||
const (
|
|
||||||
extensionServerName uint16 = 0
|
|
||||||
extensionStatusRequest uint16 = 5
|
|
||||||
extensionSupportedCurves uint16 = 10
|
|
||||||
extensionSupportedPoints uint16 = 11
|
|
||||||
extensionSignatureAlgorithms uint16 = 13
|
|
||||||
extensionALPN uint16 = 16
|
|
||||||
extensionSCT uint16 = 18
|
|
||||||
extensionSessionTicket uint16 = 35
|
|
||||||
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
|
|
||||||
extensionRenegotiationInfo uint16 = 0xff01
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type HTTPSMuxer struct {
|
type HTTPSMuxer struct {
|
||||||
@ -52,142 +32,49 @@ func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, e
|
|||||||
return &HTTPSMuxer{mux}, err
|
return &HTTPSMuxer{mux}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func readHandshake(rd io.Reader) (host string, err error) {
|
|
||||||
data := pool.GetBuf(1024)
|
|
||||||
origin := data
|
|
||||||
defer pool.PutBuf(origin)
|
|
||||||
|
|
||||||
_, err = io.ReadFull(rd, data[:47])
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
length, err := rd.Read(data[47:])
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
length += 47
|
|
||||||
data = data[:length]
|
|
||||||
if uint8(data[5]) != typeClientHello {
|
|
||||||
err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// session
|
|
||||||
sessionIDLen := int(data[43])
|
|
||||||
if sessionIDLen > 32 || len(data) < 44+sessionIDLen {
|
|
||||||
err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIDLen)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data = data[44+sessionIDLen:]
|
|
||||||
if len(data) < 2 {
|
|
||||||
err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// cipher suite numbers
|
|
||||||
cipherSuiteLen := int(data[0])<<8 | int(data[1])
|
|
||||||
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
|
|
||||||
err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data = data[2+cipherSuiteLen:]
|
|
||||||
if len(data) < 1 {
|
|
||||||
err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// compression method
|
|
||||||
compressionMethodsLen := int(data[0])
|
|
||||||
if len(data) < 1+compressionMethodsLen {
|
|
||||||
err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data = data[1+compressionMethodsLen:]
|
|
||||||
if len(data) == 0 {
|
|
||||||
// ClientHello is optionally followed by extension data
|
|
||||||
err = fmt.Errorf("readHandshake: there is no extension data to get servername")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(data) < 2 {
|
|
||||||
err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short", len(data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
extensionsLength := int(data[0])<<8 | int(data[1])
|
|
||||||
data = data[2:]
|
|
||||||
if extensionsLength != len(data) {
|
|
||||||
err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for len(data) != 0 {
|
|
||||||
if len(data) < 4 {
|
|
||||||
err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
extension := uint16(data[0])<<8 | uint16(data[1])
|
|
||||||
length := int(data[2])<<8 | int(data[3])
|
|
||||||
data = data[4:]
|
|
||||||
if len(data) < length {
|
|
||||||
err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch extension {
|
|
||||||
case extensionRenegotiationInfo:
|
|
||||||
if length != 1 || data[0] != 0 {
|
|
||||||
err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case extensionNextProtoNeg:
|
|
||||||
case extensionStatusRequest:
|
|
||||||
case extensionServerName:
|
|
||||||
d := data[:length]
|
|
||||||
if len(d) < 2 {
|
|
||||||
err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
namesLen := int(d[0])<<8 | int(d[1])
|
|
||||||
d = d[2:]
|
|
||||||
if len(d) != namesLen {
|
|
||||||
err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for len(d) > 0 {
|
|
||||||
if len(d) < 3 {
|
|
||||||
err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nameType := d[0]
|
|
||||||
nameLen := int(d[1])<<8 | int(d[2])
|
|
||||||
d = d[3:]
|
|
||||||
if len(d) < nameLen {
|
|
||||||
err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if nameType == 0 {
|
|
||||||
serverName := string(d[:nameLen])
|
|
||||||
host = strings.TrimSpace(serverName)
|
|
||||||
return host, nil
|
|
||||||
}
|
|
||||||
d = d[nameLen:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
data = data[length:]
|
|
||||||
}
|
|
||||||
err = fmt.Errorf("Unknown error")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
||||||
reqInfoMap := make(map[string]string, 0)
|
reqInfoMap := make(map[string]string, 0)
|
||||||
sc, rd := gnet.NewSharedConn(c)
|
sc, rd := gnet.NewSharedConn(c)
|
||||||
host, err := readHandshake(rd)
|
|
||||||
|
clientHello, err := readClientHello(rd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, reqInfoMap, err
|
return nil, reqInfoMap, err
|
||||||
}
|
}
|
||||||
reqInfoMap["Host"] = host
|
|
||||||
|
reqInfoMap["Host"] = clientHello.ServerName
|
||||||
reqInfoMap["Scheme"] = "https"
|
reqInfoMap["Scheme"] = "https"
|
||||||
return sc, reqInfoMap, nil
|
return sc, reqInfoMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
|
||||||
|
var hello *tls.ClientHelloInfo
|
||||||
|
|
||||||
|
// Note that Handshake always fails because the readOnlyConn is not a real connection.
|
||||||
|
// As long as the Client Hello is successfully read, the failure should only happen after GetConfigForClient is called,
|
||||||
|
// so we only care about the error if hello was never set.
|
||||||
|
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
|
||||||
|
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
hello = &tls.ClientHelloInfo{}
|
||||||
|
*hello = *argHello
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}).Handshake()
|
||||||
|
|
||||||
|
if hello == nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hello, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type readOnlyConn struct {
|
||||||
|
reader io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
|
||||||
|
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||||
|
func (conn readOnlyConn) Close() error { return nil }
|
||||||
|
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
|
||||||
|
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
|
||||||
|
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
38
pkg/util/vhost/https_test.go
Normal file
38
pkg/util/vhost/https_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package vhost
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetHTTPSHostname(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
l, err := net.Listen("tcp", ":")
|
||||||
|
require.NoError(err)
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
var conn net.Conn
|
||||||
|
go func() {
|
||||||
|
conn, _ = l.Accept()
|
||||||
|
require.NotNil(conn)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
tls.Dial("tcp", l.Addr().String(), &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: "example.com",
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
_, infos, err := GetHTTPSHostname(conn)
|
||||||
|
require.NoError(err)
|
||||||
|
require.Equal("example.com", infos["Host"])
|
||||||
|
require.Equal("https", infos["Scheme"])
|
||||||
|
}
|
@ -11,7 +11,7 @@ import (
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
bindAddr string
|
bindAddr string
|
||||||
bindPort int
|
bindPort int
|
||||||
hanlder http.Handler
|
handler http.Handler
|
||||||
|
|
||||||
l net.Listener
|
l net.Listener
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
@ -54,14 +54,14 @@ func WithTlsConfig(tlsConfig *tls.Config) Option {
|
|||||||
|
|
||||||
func WithHandler(h http.Handler) Option {
|
func WithHandler(h http.Handler) Option {
|
||||||
return func(s *Server) *Server {
|
return func(s *Server) *Server {
|
||||||
s.hanlder = h
|
s.handler = h
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithResponse(resp []byte) Option {
|
func WithResponse(resp []byte) Option {
|
||||||
return func(s *Server) *Server {
|
return func(s *Server) *Server {
|
||||||
s.hanlder = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write(resp)
|
w.Write(resp)
|
||||||
})
|
})
|
||||||
return s
|
return s
|
||||||
@ -76,7 +76,7 @@ func (s *Server) Run() error {
|
|||||||
addr := net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort))
|
addr := net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort))
|
||||||
hs := &http.Server{
|
hs := &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Handler: s.hanlder,
|
Handler: s.handler,
|
||||||
TLSConfig: s.tlsConfig,
|
TLSConfig: s.tlsConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user