add real ip test

This commit is contained in:
fatedier 2021-06-21 19:27:26 +08:00
parent fe4e9b55f3
commit a51e221db3
6 changed files with 176 additions and 22 deletions

View File

@ -1,6 +1,6 @@
version: 2
jobs:
go1.16:
go-version-latest:
docker:
- image: circleci/golang:1.16-node
working_directory: /go/src/github.com/fatedier/frp
@ -8,7 +8,7 @@ jobs:
- checkout
- run: make
- run: make alltest
go1.15:
go-version-last:
docker:
- image: circleci/golang:1.15-node
working_directory: /go/src/github.com/fatedier/frp
@ -21,5 +21,5 @@ workflows:
version: 2
build_and_test:
jobs:
- go1.16
- go1.15
- go-version-latest
- go-version-last

View File

@ -1,20 +1,152 @@
package features
import (
"bufio"
"fmt"
"net"
"net/http"
"github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/test/e2e/framework"
"github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/mock/server/httpserver"
"github.com/fatedier/frp/test/e2e/mock/server/streamserver"
"github.com/fatedier/frp/test/e2e/pkg/request"
"github.com/fatedier/frp/test/e2e/pkg/rpc"
. "github.com/onsi/ginkgo"
pp "github.com/pires/go-proxyproto"
)
var _ = Describe("[Feature: Real IP]", func() {
f := framework.NewDefaultFramework()
It("HTTP X-Forwarded-For", func() {
// TODO
_ = f
vhostHTTPPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
vhost_http_port = %d
`, vhostHTTPPort)
localPort := f.AllocPort()
localServer := httpserver.New(
httpserver.WithBindPort(localPort),
httpserver.WithHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Header.Get("X-Forwarded-For")))
})),
)
f.RunServer("", localServer)
clientConf := consts.DefaultClientConfig
clientConf += fmt.Sprintf(`
[test]
type = http
local_port = %d
custom_domains = normal.example.com
`, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com")
}).
ExpectResp([]byte("127.0.0.1")).
Ensure()
})
It("Proxy Protocol", func() {
// TODO
Describe("Proxy Protocol", func() {
It("TCP", func() {
serverConf := consts.DefaultServerConfig
clientConf := consts.DefaultClientConfig
localPort := f.AllocPort()
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(localPort),
streamserver.WithCustomHandler(func(c net.Conn) {
defer c.Close()
rd := bufio.NewReader(c)
ppHeader, err := pp.Read(rd)
if err != nil {
log.Error("read proxy protocol error: %v", err)
return
}
for {
if _, err := rpc.ReadBytes(rd); err != nil {
return
}
buf := []byte(ppHeader.SourceAddr.String())
rpc.WriteBytes(c, buf)
}
}))
f.RunServer("", localServer)
remotePort := f.AllocPort()
clientConf += fmt.Sprintf(`
[tcp]
type = tcp
local_port = %d
remote_port = %d
proxy_protocol_version = v2
`, localPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool {
log.Trace("ProxyProtocol get SourceAddr: %s", string(resp.Content))
addr, err := net.ResolveTCPAddr("tcp", string(resp.Content))
if err != nil {
return false
}
if addr.IP.String() != "127.0.0.1" {
return false
}
return true
})
})
It("HTTP", func() {
vhostHTTPPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
vhost_http_port = %d
`, vhostHTTPPort)
clientConf := consts.DefaultClientConfig
localPort := f.AllocPort()
var srcAddrRecord string
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(localPort),
streamserver.WithCustomHandler(func(c net.Conn) {
defer c.Close()
rd := bufio.NewReader(c)
ppHeader, err := pp.Read(rd)
if err != nil {
log.Error("read proxy protocol error: %v", err)
return
}
srcAddrRecord = ppHeader.SourceAddr.String()
}))
f.RunServer("", localServer)
clientConf += fmt.Sprintf(`
[test]
type = http
local_port = %d
custom_domains = normal.example.com
proxy_protocol_version = v2
`, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort).RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com")
}).Ensure(framework.ExpectResponseCode(404))
log.Trace("ProxyProtocol get SourceAddr: %s", srcAddrRecord)
addr, err := net.ResolveTCPAddr("tcp", srcAddrRecord)
framework.ExpectNoError(err, srcAddrRecord)
framework.ExpectEqualValues("127.0.0.1", addr.IP.String())
})
})
})

View File

@ -17,7 +17,11 @@ func SpecifiedHTTPBodyHandler(body []byte) http.HandlerFunc {
func ExpectResponseCode(code int) EnsureFunc {
return func(resp *request.Response) bool {
return resp.Code == code
if resp.Code == code {
return true
}
flog.Warn("Expect code %d, but got %d", code, resp.Code)
return false
}
}

View File

@ -1,7 +1,9 @@
package streamserver
import (
"bufio"
"fmt"
"io"
"net"
libnet "github.com/fatedier/frp/pkg/util/net"
@ -22,6 +24,8 @@ type Server struct {
bindPort int
respContent []byte
handler func(net.Conn)
l net.Listener
}
@ -32,6 +36,7 @@ func New(netType Type, options ...Option) *Server {
netType: netType,
bindAddr: "127.0.0.1",
}
s.handler = s.handle
for _, option := range options {
s = option(s)
@ -60,6 +65,13 @@ func WithRespContent(content []byte) Option {
}
}
func WithCustomHandler(handler func(net.Conn)) Option {
return func(s *Server) *Server {
s.handler = handler
return s
}
}
func (s *Server) Run() error {
if err := s.initListener(); err != nil {
return err
@ -71,7 +83,7 @@ func (s *Server) Run() error {
if err != nil {
return
}
go s.handle(c)
go s.handler(c)
}
}()
return nil
@ -101,8 +113,12 @@ func (s *Server) initListener() (err error) {
func (s *Server) handle(c net.Conn) {
defer c.Close()
var reader io.Reader = c
if s.netType == UDP {
reader = bufio.NewReader(c)
}
for {
buf, err := rpc.ReadBytes(c)
buf, err := rpc.ReadBytes(reader)
if err != nil {
return
}

View File

@ -1,6 +1,7 @@
package request
import (
"bufio"
"bytes"
"fmt"
"io"
@ -120,7 +121,7 @@ func (r *Request) Do() (*Response, error) {
addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port))
// for protocol http
if r.protocol == "http" {
return sendHTTPRequest(r.method, fmt.Sprintf("http://%s%s", addr, r.path),
return r.sendHTTPRequest(r.method, fmt.Sprintf("http://%s%s", addr, r.path),
r.host, r.headers, r.proxyURL, r.body)
}
@ -151,7 +152,7 @@ func (r *Request) Do() (*Response, error) {
if r.timeout > 0 {
conn.SetDeadline(time.Now().Add(r.timeout))
}
buf, err := sendRequestByConn(conn, r.body)
buf, err := r.sendRequestByConn(conn, r.body)
if err != nil {
return nil, err
}
@ -164,7 +165,7 @@ type Response struct {
Content []byte
}
func sendHTTPRequest(method, urlstr string, host string, headers map[string]string, proxy string, body []byte) (*Response, error) {
func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers map[string]string, proxy string, body []byte) (*Response, error) {
var inBody io.Reader
if len(body) != 0 {
inBody = bytes.NewReader(body)
@ -210,13 +211,18 @@ func sendHTTPRequest(method, urlstr string, host string, headers map[string]stri
return ret, nil
}
func sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
func (r *Request) sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
_, err := rpc.WriteBytes(c, content)
if err != nil {
return nil, fmt.Errorf("write error: %v", err)
}
buf, err := rpc.ReadBytes(c)
var reader io.Reader = c
if r.protocol == "udp" {
reader = bufio.NewReader(c)
}
buf, err := rpc.ReadBytes(reader)
if err != nil {
return nil, fmt.Errorf("read error: %v", err)
}

View File

@ -1,7 +1,6 @@
package rpc
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
@ -16,15 +15,12 @@ func WriteBytes(w io.Writer, buf []byte) (int, error) {
}
func ReadBytes(r io.Reader) ([]byte, error) {
// To compatible with UDP connection, use bufio reader here to avoid lost conent.
rd := bufio.NewReader(r)
var length int64
if err := binary.Read(rd, binary.BigEndian, &length); err != nil {
if err := binary.Read(r, binary.BigEndian, &length); err != nil {
return nil, err
}
buffer := make([]byte, length)
n, err := io.ReadFull(rd, buffer)
n, err := io.ReadFull(r, buffer)
if err != nil {
return nil, err
}