diff --git a/.gitignore b/.gitignore index fab4548..e237cc4 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,6 @@ _testmain.go # Self bin/ +# Cache +*.swp +*.swo diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..68fda64 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,12 @@ +sudo: false +language: go + +go: + - 1.4.2 + - 1.5.2 + +install: + - make + +script: + - make test diff --git a/Makefile b/Makefile index 06a5054..0151542 100644 --- a/Makefile +++ b/Makefile @@ -2,14 +2,20 @@ export PATH := $(GOPATH)/bin:$(PATH) all: build -build: godep frps frpc +build: godep fmt frps frpc godep: @go get github.com/tools/godep godep restore +fmt: + @godep go fmt ./... + frps: godep go build -o bin/frps ./cmd/frps frpc: godep go build -o bin/frpc ./cmd/frpc + +test: + @godep go test ./... diff --git a/README.md b/README.md index f6f1b7f..766797f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ # frp + +[![Build Status](https://travis-ci.org/fatedier/frp.svg)](https://travis-ci.org/fatedier/frp) + A fast reverse proxy. diff --git a/cmd/frpc/control.go b/cmd/frpc/control.go index e917a93..c328555 100644 --- a/cmd/frpc/control.go +++ b/cmd/frpc/control.go @@ -1,30 +1,73 @@ package main import ( + "encoding/json" + "fmt" "io" "sync" - "encoding/json" + "time" - "frp/pkg/models" - "frp/pkg/utils/conn" - "frp/pkg/utils/log" + "github.com/fatedier/frp/models/client" + "github.com/fatedier/frp/models/consts" + "github.com/fatedier/frp/models/msg" + "github.com/fatedier/frp/utils/conn" + "github.com/fatedier/frp/utils/log" ) -func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) { +var isHeartBeatContinue bool = true + +func ControlProcess(cli *client.ProxyClient, wait *sync.WaitGroup) { defer wait.Done() - c := &conn.Conn{} - err := c.ConnectServer(ServerAddr, ServerPort) + c, err := loginToServer(cli) if err != nil { - log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, ServerAddr, ServerPort, err) + log.Error("ProxyName [%s], connect to server failed!", cli.Name) return } defer c.Close() - req := &models.ClientCtlReq{ - Type: models.ControlConn, - ProxyName: cli.Name, - Passwd: cli.Passwd, + for { + // ignore response content now + _, err := c.ReadLine() + if err == io.EOF { + isHeartBeatContinue = false + log.Debug("ProxyName [%s], server close this control conn", cli.Name) + var sleepTime time.Duration = 1 + for { + log.Debug("ProxyName [%s], try to reconnect to server[%s:%d]...", cli.Name, client.ServerAddr, client.ServerPort) + tmpConn, err := loginToServer(cli) + if err == nil { + c.Close() + c = tmpConn + break + } + + if sleepTime < 60 { + sleepTime = sleepTime * 2 + } + time.Sleep(sleepTime * time.Second) + } + continue + } else if err != nil { + log.Warn("ProxyName [%s], read from server error, %v", cli.Name, err) + continue + } + + cli.StartTunnel(client.ServerAddr, client.ServerPort) + } +} + +func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) { + c, err = conn.ConnectServer(client.ServerAddr, client.ServerPort) + if err != nil { + log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, client.ServerAddr, client.ServerPort, err) + return + } + + req := &msg.ClientCtlReq{ + Type: consts.CtlConn, + ProxyName: cli.Name, + Passwd: cli.Passwd, } buf, _ := json.Marshal(req) err = c.Write(string(buf) + "\n") @@ -40,7 +83,7 @@ func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) { } log.Debug("ProxyName [%s], read [%s]", cli.Name, res) - clientCtlRes := &models.ClientCtlRes{} + clientCtlRes := &msg.ClientCtlRes{} if err = json.Unmarshal([]byte(res), &clientCtlRes); err != nil { log.Error("ProxyName [%s], format server response error, %v", cli.Name, err) return @@ -48,20 +91,28 @@ func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) { if clientCtlRes.Code != 0 { log.Error("ProxyName [%s], start proxy error, %s", cli.Name, clientCtlRes.Msg) - return + return c, fmt.Errorf("%s", clientCtlRes.Msg) } - for { - // ignore response content now - _, err := c.ReadLine() - if err == io.EOF { - log.Debug("ProxyName [%s], server close this control conn", cli.Name) - break - } else if err != nil { - log.Warn("ProxyName [%s], read from server error, %v", cli.Name, err) - continue - } + go startHeartBeat(c) + log.Debug("ProxyName [%s], connect to server[%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort) - cli.StartTunnel(ServerAddr, ServerPort) - } + return +} + +func startHeartBeat(c *conn.Conn) { + log.Debug("Start to send heartbeat") + for { + time.Sleep(time.Duration(client.HeartBeatInterval) * time.Second) + if !c.IsClosed() { + err := c.Write("\n") + if err != nil { + log.Error("Send hearbeat to server failed! Err:%s", err.Error()) + continue + } + } else { + break + } + } + log.Info("heartbeat exit") } diff --git a/cmd/frpc/main.go b/cmd/frpc/main.go index 7f07282..c17f3e7 100644 --- a/cmd/frpc/main.go +++ b/cmd/frpc/main.go @@ -3,23 +3,24 @@ package main import ( "os" "sync" - - "frp/pkg/utils/log" + + "github.com/fatedier/frp/models/client" + "github.com/fatedier/frp/utils/log" ) func main() { - err := LoadConf("./frpc.ini") + err := client.LoadConf("./frpc.ini") if err != nil { os.Exit(-1) } - log.InitLog(LogWay, LogFile, LogLevel) + log.InitLog(client.LogWay, client.LogFile, client.LogLevel) // wait until all control goroutine exit var wait sync.WaitGroup - wait.Add(len(ProxyClients)) + wait.Add(len(client.ProxyClients)) - for _, client := range ProxyClients { + for _, client := range client.ProxyClients { go ControlProcess(client, &wait) } diff --git a/cmd/frps/control.go b/cmd/frps/control.go index 62d141e..02ff86e 100644 --- a/cmd/frps/control.go +++ b/cmd/frps/control.go @@ -1,12 +1,16 @@ package main import ( - "fmt" "encoding/json" + "fmt" + "io" + "time" - "frp/pkg/utils/log" - "frp/pkg/utils/conn" - "frp/pkg/models" + "github.com/fatedier/frp/models/consts" + "github.com/fatedier/frp/models/msg" + "github.com/fatedier/frp/models/server" + "github.com/fatedier/frp/utils/conn" + "github.com/fatedier/frp/utils/log" ) func ProcessControlConn(l *conn.Listener) { @@ -17,7 +21,7 @@ func ProcessControlConn(l *conn.Listener) { } } -// control connection from every client and server +// connection from every client and server func controlWorker(c *conn.Conn) { // the first message is from client to server // if error, close connection @@ -28,107 +32,146 @@ func controlWorker(c *conn.Conn) { } log.Debug("get: %s", res) - clientCtlReq := &models.ClientCtlReq{} - clientCtlRes := &models.ClientCtlRes{} + clientCtlReq := &msg.ClientCtlReq{} + clientCtlRes := &msg.ClientCtlRes{} if err := json.Unmarshal([]byte(res), &clientCtlReq); err != nil { log.Warn("Parse err: %v : %s", err, res) return } // check - succ, msg, needRes := checkProxy(clientCtlReq, c) + succ, info, needRes := checkProxy(clientCtlReq, c) if !succ { clientCtlRes.Code = 1 - clientCtlRes.Msg = msg + clientCtlRes.Msg = info } - + if needRes { + // control conn + defer c.Close() + buf, _ := json.Marshal(clientCtlRes) err = c.Write(string(buf) + "\n") if err != nil { log.Warn("Write error, %v", err) + time.Sleep(1 * time.Second) + return } } else { - // work conn, just return + // work conn, just return return } - defer c.Close() // others is from server to client - server, ok := ProxyServers[clientCtlReq.ProxyName] + s, ok := server.ProxyServers[clientCtlReq.ProxyName] if !ok { log.Warn("ProxyName [%s] is not exist", clientCtlReq.ProxyName) return } - serverCtlReq := &models.ClientCtlReq{} - serverCtlReq.Type = models.WorkConn + // read control msg from client + go readControlMsgFromClient(s, c) + + serverCtlReq := &msg.ClientCtlReq{} + serverCtlReq.Type = consts.WorkConn for { - server.WaitUserConn() + closeFlag := s.WaitUserConn() + if closeFlag { + log.Debug("ProxyName [%s], goroutine for dealing user conn is closed", s.Name) + break + } buf, _ := json.Marshal(serverCtlReq) err = c.Write(string(buf) + "\n") if err != nil { - log.Warn("ProxyName [%s], write to client error, proxy exit", server.Name) - server.Close() + log.Warn("ProxyName [%s], write to client error, proxy exit", s.Name) + s.Close() return } - log.Debug("ProxyName [%s], write to client to add work conn success", server.Name) + log.Debug("ProxyName [%s], write to client to add work conn success", s.Name) } + log.Info("ProxyName [%s], I'm dead!", s.Name) return } -func checkProxy(req *models.ClientCtlReq, c *conn.Conn) (succ bool, msg string, needRes bool) { +func checkProxy(req *msg.ClientCtlReq, c *conn.Conn) (succ bool, info string, needRes bool) { succ = false needRes = true // check if proxy name exist - server, ok := ProxyServers[req.ProxyName] + s, ok := server.ProxyServers[req.ProxyName] if !ok { - msg = fmt.Sprintf("ProxyName [%s] is not exist", req.ProxyName) - log.Warn(msg) + info = fmt.Sprintf("ProxyName [%s] is not exist", req.ProxyName) + log.Warn(info) return } // check password - if req.Passwd != server.Passwd { - msg = fmt.Sprintf("ProxyName [%s], password is not correct", req.ProxyName) - log.Warn(msg) + if req.Passwd != s.Passwd { + info = fmt.Sprintf("ProxyName [%s], password is not correct", req.ProxyName) + log.Warn(info) return } - + // control conn - if req.Type == models.ControlConn { - if server.Status != models.Idle { - msg = fmt.Sprintf("ProxyName [%s], already in use", req.ProxyName) - log.Warn(msg) + if req.Type == consts.CtlConn { + if s.Status != consts.Idle { + info = fmt.Sprintf("ProxyName [%s], already in use", req.ProxyName) + log.Warn(info) return } // start proxy and listen for user conn, no block - err := server.Start() + err := s.Start() if err != nil { - msg = fmt.Sprintf("ProxyName [%s], start proxy error: %v", req.ProxyName, err.Error()) - log.Warn(msg) + info = fmt.Sprintf("ProxyName [%s], start proxy error: %v", req.ProxyName, err.Error()) + log.Warn(info) return } log.Info("ProxyName [%s], start proxy success", req.ProxyName) - } else if req.Type == models.WorkConn { - // work conn + } else if req.Type == consts.WorkConn { + // work conn needRes = false - if server.Status != models.Working { + if s.Status != consts.Working { log.Warn("ProxyName [%s], is not working when it gets one new work conn", req.ProxyName) return } - server.CliConnChan <- c + s.CliConnChan <- c } else { - msg = fmt.Sprintf("ProxyName [%s], type [%d] unsupport", req.ProxyName) - log.Warn(msg) + info = fmt.Sprintf("ProxyName [%s], type [%d] unsupport", req.ProxyName, req.Type) + log.Warn(info) return } succ = true return } + +func readControlMsgFromClient(s *server.ProxyServer, c *conn.Conn) { + isContinueRead := true + f := func() { + isContinueRead = false + c.Close() + s.Close() + } + timer := time.AfterFunc(time.Duration(server.HeartBeatTimeout)*time.Second, f) + defer timer.Stop() + + for isContinueRead { + _, err := c.ReadLine() + if err != nil { + if err == io.EOF { + log.Warn("ProxyName [%s], client is dead!", s.Name) + c.Close() + s.Close() + break + } + log.Error("ProxyName [%s], read error: %v", s.Name, err) + continue + } + + timer.Reset(time.Duration(server.HeartBeatTimeout) * time.Second) + } +} diff --git a/cmd/frps/main.go b/cmd/frps/main.go index 1288622..e21f927 100644 --- a/cmd/frps/main.go +++ b/cmd/frps/main.go @@ -3,19 +3,20 @@ package main import ( "os" - "frp/pkg/utils/log" - "frp/pkg/utils/conn" + "github.com/fatedier/frp/models/server" + "github.com/fatedier/frp/utils/conn" + "github.com/fatedier/frp/utils/log" ) func main() { - err := LoadConf("./frps.ini") + err := server.LoadConf("./frps.ini") if err != nil { os.Exit(-1) } - log.InitLog(LogWay, LogFile, LogLevel) + log.InitLog(server.LogWay, server.LogFile, server.LogLevel) - l, err := conn.Listen(BindAddr, BindPort) + l, err := conn.Listen(server.BindAddr, server.BindPort) if err != nil { log.Error("Create listener error, %v", err) os.Exit(-1) diff --git a/conf/frpc.ini b/conf/frpc.ini index d2ba710..f6df4b6 100644 --- a/conf/frpc.ini +++ b/conf/frpc.ini @@ -4,9 +4,9 @@ server_addr = 127.0.0.1 bind_port = 7000 log_file = ./frpc.log # debug, info, warn, error -log_level = info +log_level = debug # file, console -log_way = file +log_way = console # test1即为name [test1] diff --git a/conf/frps.ini b/conf/frps.ini index f6a6995..0c44cb1 100644 --- a/conf/frps.ini +++ b/conf/frps.ini @@ -4,9 +4,9 @@ bind_addr = 0.0.0.0 bind_port = 7000 log_file = ./frps.log # debug, info, warn, error -log_level = info +log_level = debug # file, console -log_way = file +log_way = console # test1即为name [test1] diff --git a/pkg/models/client.go b/models/client/client.go similarity index 67% rename from pkg/models/client.go rename to models/client/client.go index 1f01d50..81a0448 100644 --- a/pkg/models/client.go +++ b/models/client/client.go @@ -1,21 +1,22 @@ -package models +package client import ( "encoding/json" - "frp/pkg/utils/conn" - "frp/pkg/utils/log" + "github.com/fatedier/frp/models/consts" + "github.com/fatedier/frp/models/msg" + "github.com/fatedier/frp/utils/conn" + "github.com/fatedier/frp/utils/log" ) type ProxyClient struct { - Name string - Passwd string - LocalPort int64 + Name string + Passwd string + LocalPort int64 } func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { - c = &conn.Conn{} - err = c.ConnectServer("127.0.0.1", p.LocalPort) + c, err = conn.ConnectServer("127.0.0.1", p.LocalPort) if err != nil { log.Error("ProxyName [%s], connect to local port error, %v", p.Name, err) } @@ -23,23 +24,22 @@ func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { } func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) { - c = &conn.Conn{} - defer func(){ + defer func() { if err != nil { c.Close() } }() - err = c.ConnectServer(addr, port) + c, err = conn.ConnectServer(addr, port) if err != nil { log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", p.Name, addr, port, err) return } - req := &ClientCtlReq{ - Type: WorkConn, - ProxyName: p.Name, - Passwd: p.Passwd, + req := &msg.ClientCtlReq{ + Type: consts.WorkConn, + ProxyName: p.Name, + Passwd: p.Passwd, } buf, _ := json.Marshal(req) @@ -63,8 +63,9 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro return } + // l means local, r means remote log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(), - remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr()) + remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr()) go conn.Join(localConn, remoteConn) return nil } diff --git a/cmd/frpc/config.go b/models/client/config.go similarity index 81% rename from cmd/frpc/config.go rename to models/client/config.go index a26fe6a..063b216 100644 --- a/cmd/frpc/config.go +++ b/models/client/config.go @@ -1,25 +1,23 @@ -package main +package client import ( "fmt" "strconv" - "frp/pkg/models" - ini "github.com/vaughan0/go-ini" ) // common config var ( - ServerAddr string = "0.0.0.0" - ServerPort int64 = 7000 - LogFile string = "./frpc.log" - LogLevel string = "warn" - LogWay string = "file" + ServerAddr string = "0.0.0.0" + ServerPort int64 = 7000 + LogFile string = "./frpc.log" + LogLevel string = "warn" + LogWay string = "file" + HeartBeatInterval int64 = 5 ) -var ProxyClients map[string]*models.ProxyClient = make(map[string]*models.ProxyClient) - +var ProxyClients map[string]*ProxyClient = make(map[string]*ProxyClient) func LoadConf(confFile string) (err error) { var tmpStr string @@ -59,7 +57,7 @@ func LoadConf(confFile string) (err error) { // servers for name, section := range conf { if name != "common" { - proxyClient := &models.ProxyClient{} + proxyClient := &ProxyClient{} proxyClient.Name = name proxyClient.Passwd, ok = section["passwd"] diff --git a/models/consts/consts.go b/models/consts/consts.go new file mode 100644 index 0000000..51dfe20 --- /dev/null +++ b/models/consts/consts.go @@ -0,0 +1,13 @@ +package consts + +// server status +const ( + Idle = iota + Working +) + +// connection type +const ( + CtlConn = iota + WorkConn +) diff --git a/models/msg/msg.go b/models/msg/msg.go new file mode 100644 index 0000000..6555296 --- /dev/null +++ b/models/msg/msg.go @@ -0,0 +1,20 @@ +package msg + +type GeneralRes struct { + Code int64 `json:"code"` + Msg string `json:"msg"` +} + +type ClientCtlReq struct { + Type int64 `json:"type"` + ProxyName string `json:"proxy_name"` + Passwd string `json:"passwd"` +} + +type ClientCtlRes struct { + GeneralRes +} + +type ServerCtlReq struct { + Type int64 `json:"type"` +} diff --git a/cmd/frps/config.go b/models/server/config.go similarity index 82% rename from cmd/frps/config.go rename to models/server/config.go index feb07d5..f9e974e 100644 --- a/cmd/frps/config.go +++ b/models/server/config.go @@ -1,25 +1,23 @@ -package main +package server import ( "fmt" "strconv" - "frp/pkg/models" - ini "github.com/vaughan0/go-ini" ) // common config var ( - BindAddr string = "0.0.0.0" - BindPort int64 = 9527 - LogFile string = "./frps.log" - LogLevel string = "warn" - LogWay string = "file" + BindAddr string = "0.0.0.0" + BindPort int64 = 9527 + LogFile string = "./frps.log" + LogLevel string = "warn" + LogWay string = "file" + HeartBeatTimeout int64 = 30 ) -var ProxyServers map[string]*models.ProxyServer = make(map[string]*models.ProxyServer) - +var ProxyServers map[string]*ProxyServer = make(map[string]*ProxyServer) func LoadConf(confFile string) (err error) { var tmpStr string @@ -59,7 +57,7 @@ func LoadConf(confFile string) (err error) { // servers for name, section := range conf { if name != "common" { - proxyServer := &models.ProxyServer{} + proxyServer := &ProxyServer{} proxyServer.Name = name proxyServer.Passwd, ok = section["passwd"] diff --git a/models/server/server.go b/models/server/server.go new file mode 100644 index 0000000..889b2d7 --- /dev/null +++ b/models/server/server.go @@ -0,0 +1,130 @@ +package server + +import ( + "container/list" + "sync" + + "github.com/fatedier/frp/models/consts" + "github.com/fatedier/frp/utils/conn" + "github.com/fatedier/frp/utils/log" +) + +type ProxyServer struct { + Name string + Passwd string + BindAddr string + ListenPort int64 + Status int64 + CliConnChan chan *conn.Conn // get client conns from control goroutine + + listener *conn.Listener // accept new connection from remote users + ctlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel + userConnList *list.List // store user conns + mutex sync.Mutex +} + +func (p *ProxyServer) Init() { + p.Status = consts.Idle + p.CliConnChan = make(chan *conn.Conn) + p.ctlMsgChan = make(chan int64) + p.userConnList = list.New() +} + +func (p *ProxyServer) Lock() { + p.mutex.Lock() +} + +func (p *ProxyServer) Unlock() { + p.mutex.Unlock() +} + +// start listening for user conns +func (p *ProxyServer) Start() (err error) { + p.Init() + p.listener, err = conn.Listen(p.BindAddr, p.ListenPort) + if err != nil { + return err + } + + p.Status = consts.Working + + // start a goroutine for listener + go func() { + for { + // block + // if listener is closed, get nil + c := p.listener.GetConn() + if c == nil { + log.Info("ProxyName [%s], listener is closed", p.Name) + return + } + log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr()) + + // insert into list + p.Lock() + if p.Status != consts.Working { + log.Debug("ProxyName [%s] is not working, new user conn close", p.Name) + c.Close() + p.Unlock() + return + } + p.userConnList.PushBack(c) + p.Unlock() + + // put msg to control conn + p.ctlMsgChan <- 1 + } + }() + + // start another goroutine for join two conns from client and user + go func() { + for { + cliConn, ok := <-p.CliConnChan + if !ok { + return + } + + p.Lock() + element := p.userConnList.Front() + + var userConn *conn.Conn + if element != nil { + userConn = element.Value.(*conn.Conn) + p.userConnList.Remove(element) + } else { + cliConn.Close() + p.Unlock() + continue + } + p.Unlock() + + // msg will transfer to another without modifying + // l means local, r means remote + log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", cliConn.GetLocalAddr(), cliConn.GetRemoteAddr(), + userConn.GetLocalAddr(), userConn.GetRemoteAddr()) + go conn.Join(cliConn, userConn) + } + }() + + return nil +} + +func (p *ProxyServer) Close() { + p.Lock() + p.Status = consts.Idle + p.listener.Close() + close(p.ctlMsgChan) + close(p.CliConnChan) + p.userConnList = list.New() + p.Unlock() +} + +func (p *ProxyServer) WaitUserConn() (closeFlag bool) { + closeFlag = false + + _, ok := <-p.ctlMsgChan + if !ok { + closeFlag = true + } + return +} diff --git a/pkg/models/msg.go b/pkg/models/msg.go deleted file mode 100644 index 0062556..0000000 --- a/pkg/models/msg.go +++ /dev/null @@ -1,27 +0,0 @@ -package models - -type GeneralRes struct { - Code int64 `json:"code"` - Msg string `json:"msg"` -} - -// type -const ( - ControlConn = iota - WorkConn -) - -type ClientCtlReq struct { - Type int64 `json:"type"` - ProxyName string `json:"proxy_name"` - Passwd string `json:"passwd"` -} - -type ClientCtlRes struct { - GeneralRes -} - - -type ServerCtlReq struct { - Type int64 `json:"type"` -} diff --git a/pkg/models/server.go b/pkg/models/server.go deleted file mode 100644 index bd6baa8..0000000 --- a/pkg/models/server.go +++ /dev/null @@ -1,116 +0,0 @@ -package models - -import ( - "sync" - "container/list" - - "frp/pkg/utils/conn" - "frp/pkg/utils/log" -) - -const ( - Idle = iota - Working -) - -type ProxyServer struct { - Name string - Passwd string - BindAddr string - ListenPort int64 - - Status int64 - Listener *conn.Listener // accept new connection from remote users - CtlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel - CliConnChan chan *conn.Conn // get client conns from control goroutine - UserConnList *list.List // store user conns - Mutex sync.Mutex -} - -func (p *ProxyServer) Init() { - p.Status = Idle - p.CtlMsgChan = make(chan int64) - p.CliConnChan = make(chan *conn.Conn) - p.UserConnList = list.New() -} - -func (p *ProxyServer) Lock() { - p.Mutex.Lock() -} - -func (p *ProxyServer) Unlock() { - p.Mutex.Unlock() -} - -// start listening for user conns -func (p *ProxyServer) Start() (err error) { - p.Listener, err = conn.Listen(p.BindAddr, p.ListenPort) - if err != nil { - return err - } - - p.Status = Working - - // start a goroutine for listener - go func() { - for { - // block - c := p.Listener.GetConn() - log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr()) - - // put to list - p.Lock() - if p.Status != Working { - log.Debug("ProxyName [%s] is not working, new user conn close", p.Name) - c.Close() - p.Unlock() - return - } - p.UserConnList.PushBack(c) - p.Unlock() - - // put msg to control conn - p.CtlMsgChan <- 1 - } - }() - - // start another goroutine for join two conns from client and user - go func() { - for { - cliConn := <-p.CliConnChan - p.Lock() - element := p.UserConnList.Front() - - var userConn *conn.Conn - if element != nil { - userConn = element.Value.(*conn.Conn) - p.UserConnList.Remove(element) - } else { - cliConn.Close() - continue - } - p.Unlock() - - // msg will transfer to another without modifying - log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", cliConn.GetLocalAddr(), cliConn.GetRemoteAddr(), - userConn.GetLocalAddr(), userConn.GetRemoteAddr()) - go conn.Join(cliConn, userConn) - } - }() - - return nil -} - -func (p *ProxyServer) Close() { - p.Lock() - p.Status = Idle - p.CtlMsgChan = make(chan int64) - p.CliConnChan = make(chan *conn.Conn) - p.UserConnList = list.New() - p.Unlock() -} - -func (p *ProxyServer) WaitUserConn() (res int64) { - res = <-p.CtlMsgChan - return -} diff --git a/utils/broadcast/broadcast.go b/utils/broadcast/broadcast.go new file mode 100644 index 0000000..4d45012 --- /dev/null +++ b/utils/broadcast/broadcast.go @@ -0,0 +1,73 @@ +package broadcast + +type Broadcast struct { + listeners []chan interface{} + reg chan (chan interface{}) + unreg chan (chan interface{}) + in chan interface{} + stop chan int64 + stopStatus bool +} + +func NewBroadcast() *Broadcast { + b := &Broadcast{ + listeners: make([]chan interface{}, 0), + reg: make(chan (chan interface{})), + unreg: make(chan (chan interface{})), + in: make(chan interface{}), + stop: make(chan int64), + stopStatus: false, + } + + go func() { + for { + select { + case l := <-b.unreg: + // remove L from b.listeners + // this operation is slow: O(n) but not used frequently + // unlike iterating over listeners + oldListeners := b.listeners + b.listeners = make([]chan interface{}, 0, len(oldListeners)) + for _, oldL := range oldListeners { + if l != oldL { + b.listeners = append(b.listeners, oldL) + } + } + + case l := <-b.reg: + b.listeners = append(b.listeners, l) + + case item := <-b.in: + for _, l := range b.listeners { + l <- item + } + + case _ = <-b.stop: + b.stopStatus = true + break + } + } + }() + + return b +} + +func (b *Broadcast) In() chan interface{} { + return b.in +} + +func (b *Broadcast) Reg() chan interface{} { + listener := make(chan interface{}) + b.reg <- listener + return listener +} + +func (b *Broadcast) UnReg(listener chan interface{}) { + b.unreg <- listener +} + +func (b *Broadcast) Close() { + if b.stopStatus == false { + b.stop <- 1 + } +} diff --git a/utils/broadcast/broadcast_test.go b/utils/broadcast/broadcast_test.go new file mode 100644 index 0000000..3354adc --- /dev/null +++ b/utils/broadcast/broadcast_test.go @@ -0,0 +1,63 @@ +package broadcast + +import ( + "sync" + "testing" + "time" +) + +var ( + totalNum int = 5 + succNum int = 0 + mutex sync.Mutex +) + +func TestBroadcast(t *testing.T) { + b := NewBroadcast() + if b == nil { + t.Errorf("New Broadcast error, nil return") + } + defer b.Close() + + var wait sync.WaitGroup + wait.Add(totalNum) + for i := 0; i < totalNum; i++ { + go worker(b, &wait) + } + + time.Sleep(1e6 * 20) + msg := "test" + b.In() <- msg + + wait.Wait() + if succNum != totalNum { + t.Errorf("TotalNum %d, FailNum(timeout) %d", totalNum, totalNum-succNum) + } +} + +func worker(b *Broadcast, wait *sync.WaitGroup) { + defer wait.Done() + msgChan := b.Reg() + + // exit if nothing got in 2 seconds + timeout := make(chan bool, 1) + go func() { + time.Sleep(time.Duration(2) * time.Second) + timeout <- true + }() + + select { + case item := <-msgChan: + msg := item.(string) + if msg == "test" { + mutex.Lock() + succNum++ + mutex.Unlock() + } else { + break + } + + case <-timeout: + break + } +} diff --git a/pkg/utils/conn/conn.go b/utils/conn/conn.go similarity index 60% rename from pkg/utils/conn/conn.go rename to utils/conn/conn.go index 60929ac..4cf6762 100644 --- a/pkg/utils/conn/conn.go +++ b/utils/conn/conn.go @@ -1,43 +1,97 @@ package conn import ( - "fmt" - "net" "bufio" - "sync" + "fmt" "io" + "net" + "sync" - "frp/pkg/utils/log" + "github.com/fatedier/frp/utils/log" ) type Listener struct { - Addr net.Addr - Conns chan *Conn + addr net.Addr + l *net.TCPListener + conns chan *Conn + closeFlag bool } -// wait util get one +func Listen(bindAddr string, bindPort int64) (l *Listener, err error) { + tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort)) + listener, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return l, err + } + + l = &Listener{ + addr: listener.Addr(), + l: listener, + conns: make(chan *Conn), + closeFlag: false, + } + + go func() { + for { + conn, err := l.l.AcceptTCP() + if err != nil { + if l.closeFlag { + return + } + continue + } + + c := &Conn{ + TcpConn: conn, + closeFlag: false, + } + c.Reader = bufio.NewReader(c.TcpConn) + l.conns <- c + } + }() + return l, err +} + +// wait util get one new connection or close +// if listener is closed, return nil func (l *Listener) GetConn() (conn *Conn) { - conn = <-l.Conns + var ok bool + conn, ok = <-l.conns + if !ok { + return nil + } return conn } -type Conn struct { - TcpConn *net.TCPConn - Reader *bufio.Reader +func (l *Listener) Close() { + if l.l != nil && l.closeFlag == false { + l.closeFlag = true + l.l.Close() + close(l.conns) + } } -func (c *Conn) ConnectServer(host string, port int64) (err error) { +// wrap for TCPConn +type Conn struct { + TcpConn *net.TCPConn + Reader *bufio.Reader + closeFlag bool +} + +func ConnectServer(host string, port int64) (c *Conn, err error) { + c = &Conn{} servertAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", host, port)) if err != nil { - return err + return } conn, err := net.DialTCP("tcp", nil, servertAddr) if err != nil { - return err + return } c.TcpConn = conn c.Reader = bufio.NewReader(c.TcpConn) - return nil + c.closeFlag = false + return c, nil } func (c *Conn) GetRemoteAddr() (addr string) { @@ -50,6 +104,9 @@ func (c *Conn) GetLocalAddr() (addr string) { func (c *Conn) ReadLine() (buff string, err error) { buff, err = c.Reader.ReadString('\n') + if err == io.EOF { + c.closeFlag = true + } return buff, err } @@ -59,40 +116,17 @@ func (c *Conn) Write(content string) (err error) { } func (c *Conn) Close() { - c.TcpConn.Close() + if c.TcpConn != nil { + c.closeFlag = true + c.TcpConn.Close() + } } -func Listen(bindAddr string, bindPort int64) (l *Listener, err error) { - tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort)) - listener, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return l, err - } - - l = &Listener{ - Addr: listener.Addr(), - Conns: make(chan *Conn), - } - - go func() { - for { - conn, err := listener.AcceptTCP() - if err != nil { - log.Error("Accept new tcp connection error, %v", err) - continue - } - - c := &Conn{ - TcpConn: conn, - } - c.Reader = bufio.NewReader(c.TcpConn) - l.Conns <- c - } - }() - return l, err +func (c *Conn) IsClosed() bool { + return c.closeFlag } -// will block until conn close +// will block until connection close func Join(c1 *Conn, c2 *Conn) { var wait sync.WaitGroup pipe := func(to *Conn, from *Conn) { diff --git a/pkg/utils/log/log.go b/utils/log/log.go similarity index 94% rename from pkg/utils/log/log.go rename to utils/log/log.go index 1a55c3c..f6587cd 100644 --- a/pkg/utils/log/log.go +++ b/utils/log/log.go @@ -22,7 +22,7 @@ func SetLogFile(logWay string, logFile string) { if logWay == "console" { Log.SetLogger("console", "") } else { - Log.SetLogger("file", `{"filename": "` + logFile + `"}`) + Log.SetLogger("file", `{"filename": "`+logFile+`"}`) } } diff --git a/utils/pcrypto/pcrypto.go b/utils/pcrypto/pcrypto.go new file mode 100644 index 0000000..729db90 --- /dev/null +++ b/utils/pcrypto/pcrypto.go @@ -0,0 +1,87 @@ +package pcrypto + +import ( + "bytes" + "compress/gzip" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io/ioutil" +) + +type Pcrypto struct { + pkey []byte + paes cipher.Block +} + +func (pc *Pcrypto) Init(key []byte) error { + var err error + pc.pkey = PKCS7Padding(key, aes.BlockSize) + pc.paes, err = aes.NewCipher(pc.pkey) + + return err +} + +func (pc *Pcrypto) Encrypto(src []byte) ([]byte, error) { + // aes + src = PKCS7Padding(src, aes.BlockSize) + blockMode := cipher.NewCBCEncrypter(pc.paes, pc.pkey) + crypted := make([]byte, len(src)) + blockMode.CryptBlocks(crypted, src) + + // gzip + var zbuf bytes.Buffer + zwr := gzip.NewWriter(&zbuf) + defer zwr.Close() + zwr.Write(crypted) + zwr.Flush() + + // base64 + return []byte(base64.StdEncoding.EncodeToString(zbuf.Bytes())), nil +} + +func (pc *Pcrypto) Decrypto(str []byte) ([]byte, error) { + // base64 + data, err := base64.StdEncoding.DecodeString(string(str)) + if err != nil { + return nil, err + } + + // gunzip + zbuf := bytes.NewBuffer(data) + zrd, _ := gzip.NewReader(zbuf) + defer zrd.Close() + data, _ = ioutil.ReadAll(zrd) + + // aes + decryptText, err := hex.DecodeString(fmt.Sprintf("%x", data)) + if err != nil { + return nil, err + } + + if len(decryptText)%aes.BlockSize != 0 { + return nil, errors.New("crypto/cipher: ciphertext is not a multiple of the block size") + } + + blockMode := cipher.NewCBCDecrypter(pc.paes, pc.pkey) + + blockMode.CryptBlocks(decryptText, decryptText) + decryptText = PKCS7UnPadding(decryptText) + + return decryptText, nil +} + +func PKCS7Padding(ciphertext []byte, blockSize int) []byte { + padding := blockSize - len(ciphertext)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padtext...) +} + +func PKCS7UnPadding(origData []byte) []byte { + length := len(origData) + unpadding := int(origData[length-1]) + return origData[:(length - unpadding)] +} diff --git a/utils/pcrypto/pcrypto_test.go b/utils/pcrypto/pcrypto_test.go new file mode 100644 index 0000000..73377e3 --- /dev/null +++ b/utils/pcrypto/pcrypto_test.go @@ -0,0 +1,47 @@ +package pcrypto + +import ( + "crypto/aes" + "fmt" + "testing" +) + +func TestEncrypto(t *testing.T) { + pp := new(Pcrypto) + pp.Init([]byte("Hana")) + res, err := pp.Encrypto([]byte("Just One Test!")) + if err != nil { + t.Error(err) + } + + fmt.Printf("[%x]\n", res) +} + +func TestDecrypto(t *testing.T) { + pp := new(Pcrypto) + pp.Init([]byte("Hana")) + res, err := pp.Encrypto([]byte("Just One Test!")) + if err != nil { + t.Error(err) + } + + res, err = pp.Decrypto(res) + if err != nil { + t.Error(err) + } + + fmt.Printf("[%s]\n", string(res)) +} + +func TestPKCS7Padding(t *testing.T) { + ltt := []byte("Test_PKCS7Padding") + ltt = PKCS7Padding(ltt, aes.BlockSize) + fmt.Printf("[%x]\n", (ltt)) +} + +func TestPKCS7UnPadding(t *testing.T) { + ltt := []byte("Test_PKCS7Padding") + ltt = PKCS7Padding(ltt, aes.BlockSize) + ltt = PKCS7UnPadding(ltt) + fmt.Printf("[%x]\n", ltt) +}