add host_header_rewrite in frpc.ini to rewrite your requests with a modified Host header

This commit is contained in:
fatedier 2016-07-26 00:18:19 +08:00
parent d2e1cfa5bc
commit 452e02adab
11 changed files with 93 additions and 79 deletions

View File

@ -52,3 +52,4 @@ local_ip = 127.0.0.1
local_port = 80 local_port = 80
use_gzip = true use_gzip = true
custom_domains = web03.yourdomain.com custom_domains = web03.yourdomain.com
host_header_rewrite = example.com

View File

@ -144,8 +144,7 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
UseGzip: cli.UseGzip, UseGzip: cli.UseGzip,
PrivilegeMode: cli.PrivilegeMode, PrivilegeMode: cli.PrivilegeMode,
ProxyType: cli.Type, ProxyType: cli.Type,
LocalIp: cli.LocalIp, HostHeaderRewrite: cli.HostHeaderRewrite,
LocalPort: cli.LocalPort,
Timestamp: nowTime, Timestamp: nowTime,
} }
if cli.PrivilegeMode { if cli.PrivilegeMode {

View File

@ -276,8 +276,7 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
// set infomations from frpc // set infomations from frpc
s.UseEncryption = req.UseEncryption s.UseEncryption = req.UseEncryption
s.UseGzip = req.UseGzip s.UseGzip = req.UseGzip
s.ClientIp = req.LocalIp s.HostHeaderRewrite = req.HostHeaderRewrite
s.ClientPort = req.LocalPort
// start proxy and listen for user connections, no block // start proxy and listen for user connections, no block
err := s.Start(c) err := s.Start(c)

View File

@ -140,6 +140,14 @@ func LoadConf(confFile string) (err error) {
proxyClient.UseGzip = true proxyClient.UseGzip = true
} }
if proxyClient.Type == "http" {
// host_header_rewrite
tmpStr, ok = section["host_header_rewrite"]
if ok {
proxyClient.HostHeaderRewrite = tmpStr
}
}
// privilege_mode // privilege_mode
proxyClient.PrivilegeMode = false proxyClient.PrivilegeMode = false
tmpStr, ok = section["privilege_mode"] tmpStr, ok = section["privilege_mode"]
@ -167,6 +175,7 @@ func LoadConf(confFile string) (err error) {
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name) return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name)
} }
} else if proxyClient.Type == "http" { } else if proxyClient.Type == "http" {
// custom_domains
domainStr, ok := section["custom_domains"] domainStr, ok := section["custom_domains"]
if ok { if ok {
proxyClient.CustomDomains = strings.Split(domainStr, ",") proxyClient.CustomDomains = strings.Split(domainStr, ",")
@ -180,6 +189,7 @@ func LoadConf(confFile string) (err error) {
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name) return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name)
} }
} else if proxyClient.Type == "https" { } else if proxyClient.Type == "https" {
// custom_domains
domainStr, ok := section["custom_domains"] domainStr, ok := section["custom_domains"]
if ok { if ok {
proxyClient.CustomDomains = strings.Split(domainStr, ",") proxyClient.CustomDomains = strings.Split(domainStr, ",")

View File

@ -22,7 +22,5 @@ type BaseConf struct {
UseGzip bool UseGzip bool
PrivilegeMode bool PrivilegeMode bool
PrivilegeToken string PrivilegeToken string
ClientIp string HostHeaderRewrite string
ClientPort int64
ServerPort int64
} }

View File

@ -26,8 +26,6 @@ type ControlReq struct {
AuthKey string `json:"auth_key"` AuthKey string `json:"auth_key"`
UseEncryption bool `json:"use_encryption"` UseEncryption bool `json:"use_encryption"`
UseGzip bool `json:"use_gzip"` UseGzip bool `json:"use_gzip"`
LocalIp string `json:"local_ip"`
LocalPort int64 `json:"local_port"`
// configures used if privilege_mode is enabled // configures used if privilege_mode is enabled
PrivilegeMode bool `json:"privilege_mode"` PrivilegeMode bool `json:"privilege_mode"`
@ -35,6 +33,7 @@ type ControlReq struct {
ProxyType string `json:"proxy_type"` ProxyType string `json:"proxy_type"`
RemotePort int64 `json:"remote_port"` RemotePort int64 `json:"remote_port"`
CustomDomains []string `json:"custom_domains, omitempty"` CustomDomains []string `json:"custom_domains, omitempty"`
HostHeaderRewrite string `json:"host_header_rewrite"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }

View File

@ -64,7 +64,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) {
p.BindAddr = BindAddr p.BindAddr = BindAddr
p.ListenPort = req.RemotePort p.ListenPort = req.RemotePort
p.CustomDomains = req.CustomDomains p.CustomDomains = req.CustomDomains
p.ServerPort = VhostHttpPort p.HostHeaderRewrite = req.HostHeaderRewrite
return return
} }
@ -80,7 +80,7 @@ func (p *ProxyServer) Init() {
func (p *ProxyServer) Compare(p2 *ProxyServer) bool { func (p *ProxyServer) Compare(p2 *ProxyServer) bool {
if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type || if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type ||
p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort { p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite {
return false return false
} }
if len(p.CustomDomains) != len(p2.CustomDomains) { if len(p.CustomDomains) != len(p2.CustomDomains) {
@ -114,7 +114,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
p.listeners = append(p.listeners, l) p.listeners = append(p.listeners, l)
} else if p.Type == "http" { } else if p.Type == "http" {
for _, domain := range p.CustomDomains { for _, domain := range p.CustomDomains {
l, err := VhostHttpMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort) l, err := VhostHttpMuxer.Listen(domain, p.HostHeaderRewrite)
if err != nil { if err != nil {
return err return err
} }
@ -122,7 +122,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
} }
} else if p.Type == "https" { } else if p.Type == "https" {
for _, domain := range p.CustomDomains { for _, domain := range p.CustomDomains {
l, err := VhostHttpsMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort) l, err := VhostHttpsMuxer.Listen(domain, p.HostHeaderRewrite)
if err != nil { if err != nil {
return err return err
} }

View File

@ -117,8 +117,13 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
return c, nil return c, nil
} }
// if the tcpConn is different with c.TcpConn
// you should call c.Close() first
func (c *Conn) SetTcpConn(tcpConn net.Conn) { func (c *Conn) SetTcpConn(tcpConn net.Conn) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.TcpConn = tcpConn c.TcpConn = tcpConn
c.closeFlag = false
c.Reader = bufio.NewReader(c.TcpConn) c.Reader = bufio.NewReader(c.TcpConn)
} }

View File

@ -26,7 +26,6 @@ import (
"time" "time"
"frp/utils/conn" "frp/utils/conn"
"frp/utils/log"
) )
type HttpMuxer struct { type HttpMuxer struct {
@ -47,31 +46,28 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
} }
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) { func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout) mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout)
return &HttpMuxer{mux}, err return &HttpMuxer{mux}, err
} }
func HostNameRewrite(c *conn.Conn, clientHost string) (_ net.Conn, err error) { func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) {
log.Info("HostNameRewrite, clientHost: %s", clientHost)
sc, rd := newShareConn(c.TcpConn) sc, rd := newShareConn(c.TcpConn)
var buff []byte var buff []byte
if buff, err = hostNameRewrite(rd, clientHost); err != nil { if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
return sc, err return sc, err
} }
err = sc.WriteBuff(buff) err = sc.WriteBuff(buff)
return sc, err return sc, err
} }
func hostNameRewrite(request io.Reader, clientHost string) (_ []byte, err error) { func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
buffer := make([]byte, 1024) buffer := make([]byte, 1024)
request.Read(buffer) request.Read(buffer)
log.Debug("before hostNameRewrite:\n %s", string(buffer)) retBuffer, err := parseRequest(buffer, rewriteHost)
retBuffer, err := parseRequest(buffer, clientHost)
log.Debug("after hostNameRewrite:\n %s", string(retBuffer))
return retBuffer, err return retBuffer, err
} }
func parseRequest(org []byte, clientHost string) (ret []byte, err error) { func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
tp := bytes.NewBuffer(org) tp := bytes.NewBuffer(org)
// First line: GET /index.html HTTP/1.0 // First line: GET /index.html HTTP/1.0
var b []byte var b []byte
@ -79,10 +75,10 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
return nil, err return nil, err
} }
req := new(http.Request) req := new(http.Request)
//we invoked ReadRequest in GetHttpHostname before, so we ignore error // we invoked ReadRequest in GetHttpHostname before, so we ignore error
req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b)) req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
rawurl := req.RequestURI rawurl := req.RequestURI
//CONNECT www.google.com:443 HTTP/1.1 // CONNECT www.google.com:443 HTTP/1.1
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
if justAuthority { if justAuthority {
rawurl = "http://" + rawurl rawurl = "http://" + rawurl
@ -97,7 +93,7 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
// GET /index.html HTTP/1.1 // GET /index.html HTTP/1.1
// Host: www.google.com // Host: www.google.com
if req.URL.Host == "" { if req.URL.Host == "" {
changedBuf, err := changeHostName(tp, clientHost) changedBuf, err := changeHostName(tp, rewriteHost)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.Write(b) buf.Write(b)
buf.Write(changedBuf) buf.Write(changedBuf)
@ -108,7 +104,12 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
// GET http://www.google.com/index.html HTTP/1.1 // GET http://www.google.com/index.html HTTP/1.1
// Host: doesntmatter // Host: doesntmatter
// In this case, any Host line is ignored. // In this case, any Host line is ignored.
req.URL.Host = clientHost hostPort := strings.Split(req.URL.Host, ":")
if len(hostPort) == 1 {
req.URL.Host = rewriteHost
} else if len(hostPort) == 2 {
req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
}
firstLine := req.Method + " " + req.URL.String() + " " + req.Proto firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.WriteString(firstLine) buf.WriteString(firstLine)
@ -128,7 +129,7 @@ func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
return line[:s1], line[s1+1 : s2], line[s2+1:], true return line[:s1], line[s1+1 : s2], line[s2+1:], true
} }
func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error) { func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
retBuf := new(bytes.Buffer) retBuf := new(bytes.Buffer)
peek := buff.Bytes() peek := buff.Bytes()
@ -145,7 +146,13 @@ func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error)
return nil, fmt.Errorf("malformed MIME header line: " + string(kv)) return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
} }
if strings.Contains(strings.ToLower(string(kv[:j])), "host") { if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
hostHeader := fmt.Sprintf("Host: %s\n", clientHost) var hostHeader string
portPos := bytes.IndexByte(kv[j+1:], ':')
if portPos == -1 {
hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
} else {
hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
}
retBuf.WriteString(hostHeader) retBuf.WriteString(hostHeader)
peek = peek[i+1:] peek = peek[i+1:]
break break

View File

@ -47,7 +47,7 @@ type HttpsMuxer struct {
} }
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) { func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout) mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout)
return &HttpsMuxer{mux}, err return &HttpsMuxer{mux}, err
} }

View File

@ -27,41 +27,42 @@ import (
) )
type muxFunc func(*conn.Conn) (net.Conn, string, error) type muxFunc func(*conn.Conn) (net.Conn, string, error)
type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
type VhostMuxer struct { type VhostMuxer struct {
listener *conn.Listener listener *conn.Listener
timeout time.Duration timeout time.Duration
vhostFunc muxFunc vhostFunc muxFunc
rewriteFunc hostRewriteFunc
registryMap map[string]*Listener registryMap map[string]*Listener
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) { func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
mux = &VhostMuxer{ mux = &VhostMuxer{
listener: listener, listener: listener,
timeout: timeout, timeout: timeout,
vhostFunc: vhostFunc, vhostFunc: vhostFunc,
rewriteFunc: rewriteFunc,
registryMap: make(map[string]*Listener), registryMap: make(map[string]*Listener),
} }
go mux.run() go mux.run()
return mux, nil return mux, nil
} }
func (v *VhostMuxer) Listen(name, proxytype, clientIp string, clientPort, serverPort int64) (l *Listener, err error) { // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
v.mutex.Lock() v.mutex.Lock()
defer v.mutex.Unlock() defer v.mutex.Unlock()
if _, exist := v.registryMap[name]; exist { if _, exist := v.registryMap[name]; exist {
return nil, fmt.Errorf("name %s is already bound", name) return nil, fmt.Errorf("domain name %s is already bound", name)
} }
l = &Listener{ l = &Listener{
name: name, name: name,
rewriteHost: rewriteHost,
mux: v, mux: v,
accept: make(chan *conn.Conn), accept: make(chan *conn.Conn),
proxyType: proxytype,
clientIp: clientIp,
clientPort: clientPort,
serverPort: serverPort,
} }
v.registryMap[name] = l v.registryMap[name] = l
return l, nil return l, nil
@ -116,12 +117,9 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
type Listener struct { type Listener struct {
name string name string
rewriteHost string
mux *VhostMuxer // for closing VhostMuxer mux *VhostMuxer // for closing VhostMuxer
accept chan *conn.Conn accept chan *conn.Conn
proxyType string //suppor http host rewrite
clientIp string
clientPort int64
serverPort int64
} }
func (l *Listener) Accept() (*conn.Conn, error) { func (l *Listener) Accept() (*conn.Conn, error) {
@ -129,19 +127,16 @@ func (l *Listener) Accept() (*conn.Conn, error) {
if !ok { if !ok {
return nil, fmt.Errorf("Listener closed") return nil, fmt.Errorf("Listener closed")
} }
if net.ParseIP(l.clientIp) == nil && l.proxyType == "http" {
if (l.name != l.clientIp) || (l.serverPort != l.clientPort) { // if rewriteFunc is exist and rewriteHost is set
clientHost := l.clientIp // rewrite http requests with a modified host header
if l.clientPort != 80 { if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
strPort := fmt.Sprintf(":%d", l.clientPort) fmt.Printf("host rewrite: %s\n", l.rewriteHost)
clientHost += strPort sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
}
retConn, err := HostNameRewrite(conn, clientHost)
if err != nil { if err != nil {
return nil, fmt.Errorf("http host rewrite failed") return nil, fmt.Errorf("http host header rewrite failed")
}
conn.SetTcpConn(retConn)
} }
conn.SetTcpConn(sConn)
} }
return conn, nil return conn, nil
} }
@ -162,6 +157,7 @@ type sharedConn struct {
buff *bytes.Buffer buff *bytes.Buffer
} }
// the bytes you read in io.Reader, will be reserved in sharedConn
func newShareConn(conn net.Conn) (*sharedConn, io.Reader) { func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
sc := &sharedConn{ sc := &sharedConn{
Conn: conn, Conn: conn,