add host_header_rewrite in frpc.ini to rewrite your requests with a modified Host header

This commit is contained in:
fatedier 2016-07-26 00:18:19 +08:00
parent d2e1cfa5bc
commit 452e02adab
11 changed files with 93 additions and 79 deletions

View File

@ -52,3 +52,4 @@ local_ip = 127.0.0.1
local_port = 80 local_port = 80
use_gzip = true use_gzip = true
custom_domains = web03.yourdomain.com custom_domains = web03.yourdomain.com
host_header_rewrite = example.com

View File

@ -138,15 +138,14 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
nowTime := time.Now().Unix() nowTime := time.Now().Unix()
req := &msg.ControlReq{ req := &msg.ControlReq{
Type: consts.NewCtlConn, Type: consts.NewCtlConn,
ProxyName: cli.Name, ProxyName: cli.Name,
UseEncryption: cli.UseEncryption, UseEncryption: cli.UseEncryption,
UseGzip: cli.UseGzip, UseGzip: cli.UseGzip,
PrivilegeMode: cli.PrivilegeMode, PrivilegeMode: cli.PrivilegeMode,
ProxyType: cli.Type, ProxyType: cli.Type,
LocalIp: cli.LocalIp, HostHeaderRewrite: cli.HostHeaderRewrite,
LocalPort: cli.LocalPort, Timestamp: nowTime,
Timestamp: nowTime,
} }
if cli.PrivilegeMode { if cli.PrivilegeMode {
privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime)) privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime))

View File

@ -276,8 +276,7 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
// set infomations from frpc // set infomations from frpc
s.UseEncryption = req.UseEncryption s.UseEncryption = req.UseEncryption
s.UseGzip = req.UseGzip s.UseGzip = req.UseGzip
s.ClientIp = req.LocalIp s.HostHeaderRewrite = req.HostHeaderRewrite
s.ClientPort = req.LocalPort
// start proxy and listen for user connections, no block // start proxy and listen for user connections, no block
err := s.Start(c) err := s.Start(c)

View File

@ -140,6 +140,14 @@ func LoadConf(confFile string) (err error) {
proxyClient.UseGzip = true proxyClient.UseGzip = true
} }
if proxyClient.Type == "http" {
// host_header_rewrite
tmpStr, ok = section["host_header_rewrite"]
if ok {
proxyClient.HostHeaderRewrite = tmpStr
}
}
// privilege_mode // privilege_mode
proxyClient.PrivilegeMode = false proxyClient.PrivilegeMode = false
tmpStr, ok = section["privilege_mode"] tmpStr, ok = section["privilege_mode"]
@ -167,6 +175,7 @@ func LoadConf(confFile string) (err error) {
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name) return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name)
} }
} else if proxyClient.Type == "http" { } else if proxyClient.Type == "http" {
// custom_domains
domainStr, ok := section["custom_domains"] domainStr, ok := section["custom_domains"]
if ok { if ok {
proxyClient.CustomDomains = strings.Split(domainStr, ",") proxyClient.CustomDomains = strings.Split(domainStr, ",")
@ -180,6 +189,7 @@ func LoadConf(confFile string) (err error) {
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name) return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name)
} }
} else if proxyClient.Type == "https" { } else if proxyClient.Type == "https" {
// custom_domains
domainStr, ok := section["custom_domains"] domainStr, ok := section["custom_domains"]
if ok { if ok {
proxyClient.CustomDomains = strings.Split(domainStr, ",") proxyClient.CustomDomains = strings.Split(domainStr, ",")

View File

@ -15,14 +15,12 @@
package config package config
type BaseConf struct { type BaseConf struct {
Name string Name string
AuthToken string AuthToken string
Type string Type string
UseEncryption bool UseEncryption bool
UseGzip bool UseGzip bool
PrivilegeMode bool PrivilegeMode bool
PrivilegeToken string PrivilegeToken string
ClientIp string HostHeaderRewrite string
ClientPort int64
ServerPort int64
} }

View File

@ -26,16 +26,15 @@ type ControlReq struct {
AuthKey string `json:"auth_key"` AuthKey string `json:"auth_key"`
UseEncryption bool `json:"use_encryption"` UseEncryption bool `json:"use_encryption"`
UseGzip bool `json:"use_gzip"` UseGzip bool `json:"use_gzip"`
LocalIp string `json:"local_ip"`
LocalPort int64 `json:"local_port"`
// configures used if privilege_mode is enabled // configures used if privilege_mode is enabled
PrivilegeMode bool `json:"privilege_mode"` PrivilegeMode bool `json:"privilege_mode"`
PrivilegeKey string `json:"privilege_key"` PrivilegeKey string `json:"privilege_key"`
ProxyType string `json:"proxy_type"` ProxyType string `json:"proxy_type"`
RemotePort int64 `json:"remote_port"` RemotePort int64 `json:"remote_port"`
CustomDomains []string `json:"custom_domains, omitempty"` CustomDomains []string `json:"custom_domains, omitempty"`
Timestamp int64 `json:"timestamp"` HostHeaderRewrite string `json:"host_header_rewrite"`
Timestamp int64 `json:"timestamp"`
} }
type ControlRes struct { type ControlRes struct {

View File

@ -64,7 +64,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) {
p.BindAddr = BindAddr p.BindAddr = BindAddr
p.ListenPort = req.RemotePort p.ListenPort = req.RemotePort
p.CustomDomains = req.CustomDomains p.CustomDomains = req.CustomDomains
p.ServerPort = VhostHttpPort p.HostHeaderRewrite = req.HostHeaderRewrite
return return
} }
@ -80,7 +80,7 @@ func (p *ProxyServer) Init() {
func (p *ProxyServer) Compare(p2 *ProxyServer) bool { func (p *ProxyServer) Compare(p2 *ProxyServer) bool {
if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type || if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type ||
p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort { p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite {
return false return false
} }
if len(p.CustomDomains) != len(p2.CustomDomains) { if len(p.CustomDomains) != len(p2.CustomDomains) {
@ -114,7 +114,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
p.listeners = append(p.listeners, l) p.listeners = append(p.listeners, l)
} else if p.Type == "http" { } else if p.Type == "http" {
for _, domain := range p.CustomDomains { for _, domain := range p.CustomDomains {
l, err := VhostHttpMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort) l, err := VhostHttpMuxer.Listen(domain, p.HostHeaderRewrite)
if err != nil { if err != nil {
return err return err
} }
@ -122,7 +122,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
} }
} else if p.Type == "https" { } else if p.Type == "https" {
for _, domain := range p.CustomDomains { for _, domain := range p.CustomDomains {
l, err := VhostHttpsMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort) l, err := VhostHttpsMuxer.Listen(domain, p.HostHeaderRewrite)
if err != nil { if err != nil {
return err return err
} }

View File

@ -117,8 +117,13 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
return c, nil return c, nil
} }
// if the tcpConn is different with c.TcpConn
// you should call c.Close() first
func (c *Conn) SetTcpConn(tcpConn net.Conn) { func (c *Conn) SetTcpConn(tcpConn net.Conn) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.TcpConn = tcpConn c.TcpConn = tcpConn
c.closeFlag = false
c.Reader = bufio.NewReader(c.TcpConn) c.Reader = bufio.NewReader(c.TcpConn)
} }

View File

@ -26,7 +26,6 @@ import (
"time" "time"
"frp/utils/conn" "frp/utils/conn"
"frp/utils/log"
) )
type HttpMuxer struct { type HttpMuxer struct {
@ -47,31 +46,28 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
} }
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) { func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout) mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout)
return &HttpMuxer{mux}, err return &HttpMuxer{mux}, err
} }
func HostNameRewrite(c *conn.Conn, clientHost string) (_ net.Conn, err error) { func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) {
log.Info("HostNameRewrite, clientHost: %s", clientHost)
sc, rd := newShareConn(c.TcpConn) sc, rd := newShareConn(c.TcpConn)
var buff []byte var buff []byte
if buff, err = hostNameRewrite(rd, clientHost); err != nil { if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
return sc, err return sc, err
} }
err = sc.WriteBuff(buff) err = sc.WriteBuff(buff)
return sc, err return sc, err
} }
func hostNameRewrite(request io.Reader, clientHost string) (_ []byte, err error) { func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
buffer := make([]byte, 1024) buffer := make([]byte, 1024)
request.Read(buffer) request.Read(buffer)
log.Debug("before hostNameRewrite:\n %s", string(buffer)) retBuffer, err := parseRequest(buffer, rewriteHost)
retBuffer, err := parseRequest(buffer, clientHost)
log.Debug("after hostNameRewrite:\n %s", string(retBuffer))
return retBuffer, err return retBuffer, err
} }
func parseRequest(org []byte, clientHost string) (ret []byte, err error) { func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
tp := bytes.NewBuffer(org) tp := bytes.NewBuffer(org)
// First line: GET /index.html HTTP/1.0 // First line: GET /index.html HTTP/1.0
var b []byte var b []byte
@ -79,10 +75,10 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
return nil, err return nil, err
} }
req := new(http.Request) req := new(http.Request)
//we invoked ReadRequest in GetHttpHostname before, so we ignore error // we invoked ReadRequest in GetHttpHostname before, so we ignore error
req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b)) req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
rawurl := req.RequestURI rawurl := req.RequestURI
//CONNECT www.google.com:443 HTTP/1.1 // CONNECT www.google.com:443 HTTP/1.1
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
if justAuthority { if justAuthority {
rawurl = "http://" + rawurl rawurl = "http://" + rawurl
@ -97,7 +93,7 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
// GET /index.html HTTP/1.1 // GET /index.html HTTP/1.1
// Host: www.google.com // Host: www.google.com
if req.URL.Host == "" { if req.URL.Host == "" {
changedBuf, err := changeHostName(tp, clientHost) changedBuf, err := changeHostName(tp, rewriteHost)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.Write(b) buf.Write(b)
buf.Write(changedBuf) buf.Write(changedBuf)
@ -108,7 +104,12 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
// GET http://www.google.com/index.html HTTP/1.1 // GET http://www.google.com/index.html HTTP/1.1
// Host: doesntmatter // Host: doesntmatter
// In this case, any Host line is ignored. // In this case, any Host line is ignored.
req.URL.Host = clientHost hostPort := strings.Split(req.URL.Host, ":")
if len(hostPort) == 1 {
req.URL.Host = rewriteHost
} else if len(hostPort) == 2 {
req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
}
firstLine := req.Method + " " + req.URL.String() + " " + req.Proto firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.WriteString(firstLine) buf.WriteString(firstLine)
@ -128,7 +129,7 @@ func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
return line[:s1], line[s1+1 : s2], line[s2+1:], true return line[:s1], line[s1+1 : s2], line[s2+1:], true
} }
func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error) { func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
retBuf := new(bytes.Buffer) retBuf := new(bytes.Buffer)
peek := buff.Bytes() peek := buff.Bytes()
@ -145,7 +146,13 @@ func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error)
return nil, fmt.Errorf("malformed MIME header line: " + string(kv)) return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
} }
if strings.Contains(strings.ToLower(string(kv[:j])), "host") { if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
hostHeader := fmt.Sprintf("Host: %s\n", clientHost) var hostHeader string
portPos := bytes.IndexByte(kv[j+1:], ':')
if portPos == -1 {
hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
} else {
hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
}
retBuf.WriteString(hostHeader) retBuf.WriteString(hostHeader)
peek = peek[i+1:] peek = peek[i+1:]
break break

View File

@ -47,7 +47,7 @@ type HttpsMuxer struct {
} }
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) { func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout) mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout)
return &HttpsMuxer{mux}, err return &HttpsMuxer{mux}, err
} }

View File

@ -27,41 +27,42 @@ import (
) )
type muxFunc func(*conn.Conn) (net.Conn, string, error) type muxFunc func(*conn.Conn) (net.Conn, string, error)
type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
type VhostMuxer struct { type VhostMuxer struct {
listener *conn.Listener listener *conn.Listener
timeout time.Duration timeout time.Duration
vhostFunc muxFunc vhostFunc muxFunc
rewriteFunc hostRewriteFunc
registryMap map[string]*Listener registryMap map[string]*Listener
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) { func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
mux = &VhostMuxer{ mux = &VhostMuxer{
listener: listener, listener: listener,
timeout: timeout, timeout: timeout,
vhostFunc: vhostFunc, vhostFunc: vhostFunc,
rewriteFunc: rewriteFunc,
registryMap: make(map[string]*Listener), registryMap: make(map[string]*Listener),
} }
go mux.run() go mux.run()
return mux, nil return mux, nil
} }
func (v *VhostMuxer) Listen(name, proxytype, clientIp string, clientPort, serverPort int64) (l *Listener, err error) { // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
v.mutex.Lock() v.mutex.Lock()
defer v.mutex.Unlock() defer v.mutex.Unlock()
if _, exist := v.registryMap[name]; exist { if _, exist := v.registryMap[name]; exist {
return nil, fmt.Errorf("name %s is already bound", name) return nil, fmt.Errorf("domain name %s is already bound", name)
} }
l = &Listener{ l = &Listener{
name: name, name: name,
mux: v, rewriteHost: rewriteHost,
accept: make(chan *conn.Conn), mux: v,
proxyType: proxytype, accept: make(chan *conn.Conn),
clientIp: clientIp,
clientPort: clientPort,
serverPort: serverPort,
} }
v.registryMap[name] = l v.registryMap[name] = l
return l, nil return l, nil
@ -115,13 +116,10 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
} }
type Listener struct { type Listener struct {
name string name string
mux *VhostMuxer // for closing VhostMuxer rewriteHost string
accept chan *conn.Conn mux *VhostMuxer // for closing VhostMuxer
proxyType string //suppor http host rewrite accept chan *conn.Conn
clientIp string
clientPort int64
serverPort int64
} }
func (l *Listener) Accept() (*conn.Conn, error) { func (l *Listener) Accept() (*conn.Conn, error) {
@ -129,19 +127,16 @@ func (l *Listener) Accept() (*conn.Conn, error) {
if !ok { if !ok {
return nil, fmt.Errorf("Listener closed") return nil, fmt.Errorf("Listener closed")
} }
if net.ParseIP(l.clientIp) == nil && l.proxyType == "http" {
if (l.name != l.clientIp) || (l.serverPort != l.clientPort) { // if rewriteFunc is exist and rewriteHost is set
clientHost := l.clientIp // rewrite http requests with a modified host header
if l.clientPort != 80 { if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
strPort := fmt.Sprintf(":%d", l.clientPort) fmt.Printf("host rewrite: %s\n", l.rewriteHost)
clientHost += strPort sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
} if err != nil {
retConn, err := HostNameRewrite(conn, clientHost) return nil, fmt.Errorf("http host header rewrite failed")
if err != nil {
return nil, fmt.Errorf("http host rewrite failed")
}
conn.SetTcpConn(retConn)
} }
conn.SetTcpConn(sConn)
} }
return conn, nil return conn, nil
} }
@ -162,6 +157,7 @@ type sharedConn struct {
buff *bytes.Buffer buff *bytes.Buffer
} }
// the bytes you read in io.Reader, will be reserved in sharedConn
func newShareConn(conn net.Conn) (*sharedConn, io.Reader) { func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
sc := &sharedConn{ sc := &sharedConn{
Conn: conn, Conn: conn,