Fix conflicts in fatedier/connection_pool with dev

Conflicts:
	src/frp/cmd/frpc/control.go
	src/frp/cmd/frps/control.go
	src/frp/models/config/config.go
	src/frp/models/server/server.go
This commit is contained in:
fatedier 2016-07-29 23:38:34 +08:00
commit fd3c97a0e9
12 changed files with 229 additions and 55 deletions

View File

@ -55,3 +55,4 @@ local_ip = 127.0.0.1
local_port = 80
use_gzip = true
custom_domains = web03.yourdomain.com
host_header_rewrite = example.com

View File

@ -138,14 +138,14 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
nowTime := time.Now().Unix()
req := &msg.ControlReq{
Type: consts.NewCtlConn,
ProxyName: cli.Name,
UseEncryption: cli.UseEncryption,
UseGzip: cli.UseGzip,
PoolCount: cli.PoolCount,
PrivilegeMode: cli.PrivilegeMode,
ProxyType: cli.Type,
Timestamp: nowTime,
Type: consts.NewCtlConn,
ProxyName: cli.Name,
UseEncryption: cli.UseEncryption,
UseGzip: cli.UseGzip,
PrivilegeMode: cli.PrivilegeMode,
ProxyType: cli.Type,
HostHeaderRewrite: cli.HostHeaderRewrite,
Timestamp: nowTime,
}
if cli.PrivilegeMode {
privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime))

View File

@ -276,6 +276,7 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
// set infomations from frpc
s.UseEncryption = req.UseEncryption
s.UseGzip = req.UseGzip
s.HostHeaderRewrite = req.HostHeaderRewrite
if req.PoolCount > server.MaxPoolCount {
s.PoolCount = server.MaxPoolCount
} else if req.PoolCount < 0 {

View File

@ -140,6 +140,14 @@ func LoadConf(confFile string) (err error) {
proxyClient.UseGzip = true
}
if proxyClient.Type == "http" {
// host_header_rewrite
tmpStr, ok = section["host_header_rewrite"]
if ok {
proxyClient.HostHeaderRewrite = tmpStr
}
}
// privilege_mode
proxyClient.PrivilegeMode = false
tmpStr, ok = section["privilege_mode"]
@ -178,6 +186,7 @@ func LoadConf(confFile string) (err error) {
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name)
}
} else if proxyClient.Type == "http" {
// custom_domains
domainStr, ok := section["custom_domains"]
if ok {
proxyClient.CustomDomains = strings.Split(domainStr, ",")
@ -191,6 +200,7 @@ func LoadConf(confFile string) (err error) {
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name)
}
} else if proxyClient.Type == "https" {
// custom_domains
domainStr, ok := section["custom_domains"]
if ok {
proxyClient.CustomDomains = strings.Split(domainStr, ",")

View File

@ -15,12 +15,13 @@
package config
type BaseConf struct {
Name string
AuthToken string
Type string
UseEncryption bool
UseGzip bool
PrivilegeMode bool
PrivilegeToken string
PoolCount int64
Name string
AuthToken string
Type string
UseEncryption bool
UseGzip bool
PrivilegeMode bool
PrivilegeToken string
PoolCount int64
HostHeaderRewrite string
}

View File

@ -29,12 +29,13 @@ type ControlReq struct {
PoolCount int64 `json:"pool_count"`
// configures used if privilege_mode is enabled
PrivilegeMode bool `json:"privilege_mode"`
PrivilegeKey string `json:"privilege_key"`
ProxyType string `json:"proxy_type"`
RemotePort int64 `json:"remote_port"`
CustomDomains []string `json:"custom_domains, omitempty"`
Timestamp int64 `json:"timestamp"`
PrivilegeMode bool `json:"privilege_mode"`
PrivilegeKey string `json:"privilege_key"`
ProxyType string `json:"proxy_type"`
RemotePort int64 `json:"remote_port"`
CustomDomains []string `json:"custom_domains, omitempty"`
HostHeaderRewrite string `json:"host_header_rewrite"`
Timestamp int64 `json:"timestamp"`
}
type ControlRes struct {

View File

@ -15,12 +15,10 @@
package msg
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"sync"
"frp/models/config"
@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
defer wait.Done()
// we don't care about errors here
pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord)
pipeEncrypt(from, to, conf, needRecord)
}
decryptPipe := func(to *conn.Conn, from *conn.Conn) {
@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
defer wait.Done()
// we don't care about errors here
pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord)
pipeDecrypt(to, from, conf, needRecord)
}
wait.Add(2)
@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) {
}
// decrypt msg from reader, then write into writer
func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
laes := new(pcrypto.Pcrypto)
key := conf.AuthToken
if conf.PrivilegeMode {
@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
buf := make([]byte, 5*1024+4)
var left, res []byte
var cnt int
var cnt int = -1
// record
var flowBytes int64 = 0
@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
}()
}
nreader := bufio.NewReader(r)
for {
// there may be more than 1 package in variable
// and we read more bytes if unpkgMsg returns an error
var newBuf []byte
if cnt < 0 {
n, err := nreader.Read(buf)
n, err := r.Read(buf)
if err != nil {
return err
}
@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
}
}
_, err = w.Write(res)
_, err = w.WriteBytes(res)
if err != nil {
return err
}
@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
}
// recvive msg from reader, then encrypt msg into writer
func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
laes := new(pcrypto.Pcrypto)
key := conf.AuthToken
if conf.PrivilegeMode {
@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
}()
}
nreader := bufio.NewReader(r)
buf := make([]byte, 5*1024)
for {
n, err := nreader.Read(buf)
n, err := r.Read(buf)
if err != nil {
return err
}
@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
}
res = pkgMsg(res)
_, err = w.Write(res)
_, err = w.WriteBytes(res)
if err != nil {
return err
}

View File

@ -65,6 +65,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) {
p.BindAddr = BindAddr
p.ListenPort = req.RemotePort
p.CustomDomains = req.CustomDomains
p.HostHeaderRewrite = req.HostHeaderRewrite
return
}
@ -81,7 +82,7 @@ func (p *ProxyServer) Init() {
func (p *ProxyServer) Compare(p2 *ProxyServer) bool {
if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type ||
p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort {
p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite {
return false
}
if len(p.CustomDomains) != len(p2.CustomDomains) {
@ -115,7 +116,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
p.listeners = append(p.listeners, l)
} else if p.Type == "http" {
for _, domain := range p.CustomDomains {
l, err := VhostHttpMuxer.Listen(domain)
l, err := VhostHttpMuxer.Listen(domain, p.HostHeaderRewrite)
if err != nil {
return err
}
@ -123,7 +124,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
}
} else if p.Type == "https" {
for _, domain := range p.CustomDomains {
l, err := VhostHttpsMuxer.Listen(domain)
l, err := VhostHttpsMuxer.Listen(domain, p.HostHeaderRewrite)
if err != nil {
return err
}
@ -160,14 +161,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
return
}
// start another goroutine for join two connections between frpc and user
go func() {
go func(userConn *conn.Conn) {
workConn, err := p.getWorkConn()
if err != nil {
return
}
userConn := c
// message will be transferred to another without modifying
// l means local, r means remote
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
@ -176,7 +175,8 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
metric.OpenConnection(p.Name)
needRecord := true
go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
}()
metric.OpenConnection(p.Name)
}(c)
}
}(listener)
}

View File

@ -117,6 +117,16 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
return c, nil
}
// if the tcpConn is different with c.TcpConn
// you should call c.Close() first
func (c *Conn) SetTcpConn(tcpConn net.Conn) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.TcpConn = tcpConn
c.closeFlag = false
c.Reader = bufio.NewReader(c.TcpConn)
}
func (c *Conn) GetRemoteAddr() (addr string) {
return c.TcpConn.RemoteAddr().String()
}
@ -125,6 +135,11 @@ func (c *Conn) GetLocalAddr() (addr string) {
return c.TcpConn.LocalAddr().String()
}
func (c *Conn) Read(p []byte) (n int, err error) {
n, err = c.Reader.Read(p)
return
}
func (c *Conn) ReadLine() (buff string, err error) {
buff, err = c.Reader.ReadString('\n')
if err != nil {
@ -138,10 +153,14 @@ func (c *Conn) ReadLine() (buff string, err error) {
return buff, err
}
func (c *Conn) WriteBytes(content []byte) (n int, err error) {
n, err = c.TcpConn.Write(content)
return
}
func (c *Conn) Write(content string) (err error) {
_, err = c.TcpConn.Write([]byte(content))
return err
}
func (c *Conn) SetDeadline(t time.Time) error {

View File

@ -16,8 +16,12 @@ package vhost
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
@ -42,6 +46,123 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
}
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout)
return &HttpMuxer{mux}, err
}
func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) {
sc, rd := newShareConn(c.TcpConn)
var buff []byte
if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
return sc, err
}
err = sc.WriteBuff(buff)
return sc, err
}
func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
buffer := make([]byte, 1024)
request.Read(buffer)
retBuffer, err := parseRequest(buffer, rewriteHost)
return retBuffer, err
}
func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
tp := bytes.NewBuffer(org)
// First line: GET /index.html HTTP/1.0
var b []byte
if b, err = tp.ReadBytes('\n'); err != nil {
return nil, err
}
req := new(http.Request)
// we invoked ReadRequest in GetHttpHostname before, so we ignore error
req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
rawurl := req.RequestURI
// CONNECT www.google.com:443 HTTP/1.1
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
if justAuthority {
rawurl = "http://" + rawurl
}
req.URL, _ = url.ParseRequestURI(rawurl)
if justAuthority {
// Strip the bogus "http://" back off.
req.URL.Scheme = ""
}
// RFC2616: first case
// GET /index.html HTTP/1.1
// Host: www.google.com
if req.URL.Host == "" {
changedBuf, err := changeHostName(tp, rewriteHost)
buf := new(bytes.Buffer)
buf.Write(b)
buf.Write(changedBuf)
return buf.Bytes(), err
}
// RFC2616: second case
// GET http://www.google.com/index.html HTTP/1.1
// Host: doesntmatter
// In this case, any Host line is ignored.
hostPort := strings.Split(req.URL.Host, ":")
if len(hostPort) == 1 {
req.URL.Host = rewriteHost
} else if len(hostPort) == 2 {
req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
}
firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
buf := new(bytes.Buffer)
buf.WriteString(firstLine)
tp.WriteTo(buf)
return buf.Bytes(), err
}
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
s1 := strings.Index(line, " ")
s2 := strings.Index(line[s1+1:], " ")
if s1 < 0 || s2 < 0 {
return
}
s2 += s1 + 1
return line[:s1], line[s1+1 : s2], line[s2+1:], true
}
func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
retBuf := new(bytes.Buffer)
peek := buff.Bytes()
for len(peek) > 0 {
i := bytes.IndexByte(peek, '\n')
if i < 3 {
// Not present (-1) or found within the next few bytes,
// implying we're at the end ("\r\n\r\n" or "\n\n")
return nil, err
}
kv := peek[:i]
j := bytes.IndexByte(kv, ':')
if j < 0 {
return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
}
if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
var hostHeader string
portPos := bytes.IndexByte(kv[j+1:], ':')
if portPos == -1 {
hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
} else {
hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
}
retBuf.WriteString(hostHeader)
peek = peek[i+1:]
break
} else {
retBuf.Write(peek[:i])
retBuf.WriteByte('\n')
}
peek = peek[i+1:]
}
retBuf.Write(peek)
return retBuf.Bytes(), err
}

View File

@ -47,7 +47,7 @@ type HttpsMuxer struct {
}
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout)
return &HttpsMuxer{mux}, err
}

View File

@ -27,37 +27,42 @@ import (
)
type muxFunc func(*conn.Conn) (net.Conn, string, error)
type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
type VhostMuxer struct {
listener *conn.Listener
timeout time.Duration
vhostFunc muxFunc
rewriteFunc hostRewriteFunc
registryMap map[string]*Listener
mutex sync.RWMutex
}
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
mux = &VhostMuxer{
listener: listener,
timeout: timeout,
vhostFunc: vhostFunc,
rewriteFunc: rewriteFunc,
registryMap: make(map[string]*Listener),
}
go mux.run()
return mux, nil
}
func (v *VhostMuxer) Listen(name string) (l *Listener, err error) {
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
v.mutex.Lock()
defer v.mutex.Unlock()
if _, exist := v.registryMap[name]; exist {
return nil, fmt.Errorf("name %s is already bound", name)
return nil, fmt.Errorf("domain name %s is already bound", name)
}
l = &Listener{
name: name,
mux: v,
accept: make(chan *conn.Conn),
name: name,
rewriteHost: rewriteHost,
mux: v,
accept: make(chan *conn.Conn),
}
v.registryMap[name] = l
return l, nil
@ -105,15 +110,16 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
if err = sConn.SetDeadline(time.Time{}); err != nil {
return
}
c.TcpConn = sConn
c.SetTcpConn(sConn)
l.accept <- c
}
type Listener struct {
name string
mux *VhostMuxer // for closing VhostMuxer
accept chan *conn.Conn
name string
rewriteHost string
mux *VhostMuxer // for closing VhostMuxer
accept chan *conn.Conn
}
func (l *Listener) Accept() (*conn.Conn, error) {
@ -121,6 +127,17 @@ func (l *Listener) Accept() (*conn.Conn, error) {
if !ok {
return nil, fmt.Errorf("Listener closed")
}
// if rewriteFunc is exist and rewriteHost is set
// rewrite http requests with a modified host header
if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
fmt.Printf("host rewrite: %s\n", l.rewriteHost)
sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
if err != nil {
return nil, fmt.Errorf("http host header rewrite failed")
}
conn.SetTcpConn(sConn)
}
return conn, nil
}
@ -140,6 +157,7 @@ type sharedConn struct {
buff *bytes.Buffer
}
// the bytes you read in io.Reader, will be reserved in sharedConn
func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
sc := &sharedConn{
Conn: conn,
@ -166,3 +184,9 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) {
sc.Unlock()
return
}
func (sc *sharedConn) WriteBuff(buffer []byte) (err error) {
sc.buff.Reset()
_, err = sc.buff.Write(buffer)
return err
}