utils/vhost: update for vhost_https

This commit is contained in:
fatedier 2016-06-24 15:43:58 +08:00
parent 0a50c3bd82
commit b14441d5cd
6 changed files with 108 additions and 121 deletions

View File

@ -21,12 +21,14 @@ bind_addr = 0.0.0.0
listen_port = 6000 listen_port = 6000
[web01] [web01]
type = https # if type equals http, vhost_http_port must be set
type = http
auth_token = 123 auth_token = 123
# if proxy type equals http, custom_domains must be set separated by commas # if proxy type equals http, custom_domains must be set separated by commas
custom_domains = web01.yourdomain.com,web01.yourdomain2.com custom_domains = web01.yourdomain.com,web01.yourdomain2.com
[web02] [web02]
type = http # if type equals https, vhost_https_port must be set
type = https
auth_token = 123 auth_token = 123
custom_domains = web02.yourdomain.com custom_domains = web02.yourdomain.com

View File

@ -225,11 +225,16 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
} }
// check if vhost_port is set // check if vhost_port is set
if s.Type == "http" && server.VhostMuxer == nil { if s.Type == "http" && server.VhostHttpMuxer == nil {
info = fmt.Sprintf("ProxyName [%s], type [http] not support when vhost_http_port is not set", req.ProxyName) info = fmt.Sprintf("ProxyName [%s], type [http] not support when vhost_http_port is not set", req.ProxyName)
log.Warn(info) log.Warn(info)
return return
} }
if s.Type == "https" && server.VhostHttpsMuxer == nil {
info = fmt.Sprintf("ProxyName [%s], type [https] not support when vhost_https_port is not set", req.ProxyName)
log.Warn(info)
return
}
// set infomations from frpc // set infomations from frpc
s.UseEncryption = req.UseEncryption s.UseEncryption = req.UseEncryption

View File

@ -31,8 +31,8 @@ var (
ConfigFile string = "./frps.ini" ConfigFile string = "./frps.ini"
BindAddr string = "0.0.0.0" BindAddr string = "0.0.0.0"
BindPort int64 = 7000 BindPort int64 = 7000
VhostHttpPort int64 = 0 // if VhostHttpPort equals 0, don't listen a public port for http VhostHttpPort int64 = 0 // if VhostHttpPort equals 0, don't listen a public port for http protocol
VhostHttpsPort int64 = 0 // if VhostHttpsPort equals 0, don't listen a public port for http VhostHttpsPort int64 = 0 // if VhostHttpsPort equals 0, don't listen a public port for https protocol
DashboardPort int64 = 0 // if DashboardPort equals 0, dashboard is not available DashboardPort int64 = 0 // if DashboardPort equals 0, dashboard is not available
LogFile string = "console" LogFile string = "console"
LogWay string = "console" // console or file LogWay string = "console" // console or file
@ -102,7 +102,6 @@ func loadCommonConf(confFile string) error {
} else { } else {
VhostHttpsPort = 0 VhostHttpsPort = 0
} }
vhost.VhostHttpsPort = VhostHttpsPort
tmpStr, ok = conf.Get("common", "dashboard_port") tmpStr, ok = conf.Get("common", "dashboard_port")
if ok { if ok {
@ -183,34 +182,25 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
// for http // for http
domainStr, ok := section["custom_domains"] domainStr, ok := section["custom_domains"]
if ok { if ok {
var suffix string
if VhostHttpPort != 80 {
suffix = fmt.Sprintf(":%d", VhostHttpPort)
}
proxyServer.CustomDomains = strings.Split(domainStr, ",") proxyServer.CustomDomains = strings.Split(domainStr, ",")
if len(proxyServer.CustomDomains) == 0 { if len(proxyServer.CustomDomains) == 0 {
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyServer.Name) return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyServer.Name)
} }
for i, domain := range proxyServer.CustomDomains { for i, domain := range proxyServer.CustomDomains {
proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain))
} }
} }
} else if proxyServer.Type == "https" { } else if proxyServer.Type == "https" {
// for https // for https
domainStr, ok := section["custom_domains"] domainStr, ok := section["custom_domains"]
if ok { if ok {
var suffix string
if VhostHttpsPort != 443 {
suffix = fmt.Sprintf(":%d", VhostHttpsPort)
}
proxyServer.CustomDomains = strings.Split(domainStr, ",") proxyServer.CustomDomains = strings.Split(domainStr, ",")
if len(proxyServer.CustomDomains) == 0 { if len(proxyServer.CustomDomains) == 0 {
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyServer.Name) return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals https", proxyServer.Name)
} }
for i, domain := range proxyServer.CustomDomains { for i, domain := range proxyServer.CustomDomains {
proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain))
} }
log.Info("proxyServer: %+v", proxyServer.CustomDomains)
} }
} }
proxyServers[proxyServer.Name] = proxyServer proxyServers[proxyServer.Name] = proxyServer

View File

@ -0,0 +1,47 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vhost
import (
"bufio"
"net"
"net/http"
"strings"
"time"
"frp/utils/conn"
)
type HttpMuxer struct {
*VhostMuxer
}
func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
sc, rd := newShareConn(c.TcpConn)
request, err := http.ReadRequest(bufio.NewReader(rd))
if err != nil {
return sc, "", err
}
tmpArr := strings.Split(request.Host, ":")
routerName = tmpArr[0]
request.Body.Close()
return sc, routerName, nil
}
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
return &HttpMuxer{mux}, err
}

View File

@ -15,25 +15,13 @@
package vhost package vhost
import ( import (
_ "bufio"
_ "bytes"
_ "crypto/tls"
"errors"
"fmt" "fmt"
"frp/utils/conn"
"frp/utils/log"
"io" "io"
_ "io/ioutil"
"net" "net"
_ "net/http"
"strings" "strings"
_ "sync"
"time" "time"
)
var ( "frp/utils/conn"
maxHandshake int64 = 65536 // maximum handshake we support (protocol max is 16 MB)
VhostHttpsPort int64 = 443
) )
const ( const (
@ -58,160 +46,140 @@ type HttpsMuxer struct {
*VhostMuxer *VhostMuxer
} }
/* func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
RFC document: http://tools.ietf.org/html/rfc5246 mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
*/ return &HttpsMuxer{mux}, err
func errMsgToLog(format string, a ...interface{}) error {
errMsg := fmt.Sprintf(format, a...)
log.Warn(errMsg)
return errors.New(errMsg)
} }
func readHandshake(rd io.Reader) (string, error) { func readHandshake(rd io.Reader) (host string, err error) {
data := make([]byte, 1024) data := make([]byte, 1024)
length, err := rd.Read(data) length, err := rd.Read(data)
if err != nil { if err != nil {
return "", errMsgToLog("read err:%v", err) return
} else { } else {
if length < 47 { if length < 47 {
return "", errMsgToLog("readHandshake: proto length[%d] is too short", length) err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
return
} }
} }
data = data[:length] data = data[:length]
//log.Warn("data: %+v", data)
if uint8(data[5]) != typeClientHello { if uint8(data[5]) != typeClientHello {
return "", errMsgToLog("readHandshake: type[%d] is not clientHello", uint16(data[5])) err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
return
} }
//version and random
//tlsVersion := uint16(data[9])<<8 | uint16(data[10])
//random := data[11:43]
// session // session
sessionIdLen := int(data[43]) sessionIdLen := int(data[43])
if sessionIdLen > 32 || len(data) < 44+sessionIdLen { if sessionIdLen > 32 || len(data) < 44+sessionIdLen {
return "", errMsgToLog("readHandshake: sessionIdLen[%d] is long", sessionIdLen) err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
return
} }
data = data[44+sessionIdLen:] data = data[44+sessionIdLen:]
if len(data) < 2 { if len(data) < 2 {
return "", errMsgToLog("readHandshake: dataLen[%d] after session is short", len(data)) err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
return
} }
// cipher suite numbers // cipher suite numbers
cipherSuiteLen := int(data[0])<<8 | int(data[1]) cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
//return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", sessionIdLen) err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
return "", errMsgToLog("readHandshake: dataLen[%d] after cipher suite is short", len(data)) return
} }
data = data[2+cipherSuiteLen:] data = data[2+cipherSuiteLen:]
if len(data) < 1 { if len(data) < 1 {
return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen) err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
return
} }
// compression method // compression method
compressionMethodsLen := int(data[0]) compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen { if len(data) < 1+compressionMethodsLen {
return "", errMsgToLog("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen) err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
//return false return
} }
data = data[1+compressionMethodsLen:] data = data[1+compressionMethodsLen:]
if len(data) == 0 { if len(data) == 0 {
// ClientHello is optionally followed by extension data // ClientHello is optionally followed by extension data
//return true err = fmt.Errorf("readHandshake: there is no extension data to get servername")
return "", errMsgToLog("readHandshake: there is no extension data to get servername") return
} }
if len(data) < 2 { if len(data) < 2 {
return "", errMsgToLog("readHandshake: extension dataLen[%d] is too short") err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short")
return
} }
extensionsLength := int(data[0])<<8 | int(data[1]) extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:] data = data[2:]
if extensionsLength != len(data) { if extensionsLength != len(data) {
return "", errMsgToLog("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data)) err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
return
} }
for len(data) != 0 { for len(data) != 0 {
if len(data) < 4 { if len(data) < 4 {
return "", errMsgToLog("readHandshake: extensionsDataLen[%d] is too short", len(data)) err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
return
} }
extension := uint16(data[0])<<8 | uint16(data[1]) extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3]) length := int(data[2])<<8 | int(data[3])
data = data[4:] data = data[4:]
if len(data) < length { if len(data) < length {
return "", errMsgToLog("readHandshake: extensionLen[%d] is long", length) err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
//return false return
} }
switch extension { switch extension {
case extensionRenegotiationInfo: case extensionRenegotiationInfo:
if length != 1 || data[0] != 0 { if length != 1 || data[0] != 0 {
return "", errMsgToLog("readHandshake: extension reNegotiationInfoLen[%d] is short", length) err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
return
} }
case extensionNextProtoNeg: case extensionNextProtoNeg:
case extensionStatusRequest: case extensionStatusRequest:
case extensionServerName: case extensionServerName:
d := data[:length] d := data[:length]
if len(d) < 2 { if len(d) < 2 {
return "", errMsgToLog("readHandshake: remiaining dataLen[%d] is short", len(d)) err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
return
} }
namesLen := int(d[0])<<8 | int(d[1]) namesLen := int(d[0])<<8 | int(d[1])
d = d[2:] d = d[2:]
if len(d) != namesLen { if len(d) != namesLen {
return "", errMsgToLog("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d)) err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
return
} }
for len(d) > 0 { for len(d) > 0 {
if len(d) < 3 { if len(d) < 3 {
return "", errMsgToLog("readHandshake: extension serverNameLen[%d] is short", len(d)) err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
return
} }
nameType := d[0] nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2]) nameLen := int(d[1])<<8 | int(d[2])
d = d[3:] d = d[3:]
if len(d) < nameLen { if len(d) < nameLen {
return "", errMsgToLog("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d)) err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
return
} }
if nameType == 0 { if nameType == 0 {
suffix := ""
if VhostHttpsPort != 443 {
suffix = fmt.Sprintf(":%d", VhostHttpsPort)
}
serverName := string(d[:nameLen]) serverName := string(d[:nameLen])
domain := strings.ToLower(strings.TrimSpace(serverName)) + suffix host = strings.TrimSpace(serverName)
return domain, nil return host, nil
break
} }
d = d[nameLen:] d = d[nameLen:]
} }
} }
data = data[length:] data = data[length:]
} }
//return "test.codermao.com:8082", nil err = fmt.Errorf("Unknow error")
return "", errMsgToLog("Unknow error") return
} }
func GetHttpsHostname(c *conn.Conn) (sc net.Conn, routerName string, err error) { func GetHttpsHostname(c *conn.Conn) (sc net.Conn, routerName string, err error) {
log.Info("GetHttpsHostname")
sc, rd := newShareConn(c.TcpConn) sc, rd := newShareConn(c.TcpConn)
host, err := readHandshake(rd) host, err := readHandshake(rd)
if err != nil { if err != nil {
return sc, "", err return sc, "", err
} }
/*
if _, ok := c.TcpConn.(*tls.Conn); ok {
log.Warn("convert to tlsConn success")
} else {
log.Warn("convert to tlsConn error")
}*/
//tcpConn.
log.Debug("GetHttpsHostname: %s", host)
return sc, host, nil return sc, host, nil
} }
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
return &HttpsMuxer{mux}, err
}

View File

@ -15,12 +15,10 @@
package vhost package vhost
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -99,7 +97,6 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
} }
name = strings.ToLower(name) name = strings.ToLower(name)
l, ok := v.getListener(name) l, ok := v.getListener(name)
if !ok { if !ok {
return return
@ -113,28 +110,6 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
l.accept <- c l.accept <- c
} }
type HttpMuxer struct {
*VhostMuxer
}
func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
sc, rd := newShareConn(c.TcpConn)
request, err := http.ReadRequest(bufio.NewReader(rd))
if err != nil {
return sc, "", err
}
routerName = request.Host
request.Body.Close()
return sc, routerName, nil
}
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
return &HttpMuxer{mux}, err
}
type Listener struct { type Listener struct {
name string name string
mux *VhostMuxer // for closing VhostMuxer mux *VhostMuxer // for closing VhostMuxer