frp/utils/net/websocket.go

104 lines
1.9 KiB
Go
Raw Normal View History

2018-08-03 19:41:54 +08:00
package net
import (
2018-08-10 11:43:08 +08:00
"errors"
2018-08-03 19:41:54 +08:00
"fmt"
"net"
"net/http"
"net/url"
"time"
"golang.org/x/net/websocket"
)
2018-08-10 11:43:08 +08:00
// ErrWebsocketListenerClosed is returned by (*WebsocketListener).Accept once
// its connection channel has been closed.
var ErrWebsocketListenerClosed = errors.New("websocket listener closed")

// FrpWebsocketPath is the HTTP path the server registers its websocket
// handler on and the path clients dial.
const FrpWebsocketPath = "/~!frp"
2018-08-03 19:41:54 +08:00
// WebsocketListener accepts frp connections that are tunneled over websocket.
// It provides Accept, Close and Addr methods in the shape of net.Listener;
// upgraded connections are handed from the HTTP handler to Accept through
// acceptCh.
type WebsocketListener struct {
	// server is the HTTP server that performs the websocket upgrades.
	server *http.Server
	// NOTE(review): httpMutex is not referenced anywhere in this file, and
	// despite its name it holds a ServeMux, not a mutex.
	httpMutex *http.ServeMux

	// ln is the underlying TCP listener whose address Addr reports.
	ln net.Listener
	// acceptCh carries upgraded connections to Accept.
	acceptCh chan net.Conn
}
2019-01-31 16:54:46 +08:00
// NewWebsocketListener to handle websocket connections
2018-08-10 11:43:08 +08:00
// ln: tcp listener for websocket connections
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
wl = &WebsocketListener{
2019-10-12 20:13:12 +08:00
acceptCh: make(chan net.Conn),
2018-08-03 19:41:54 +08:00
}
2018-08-10 11:43:08 +08:00
muxer := http.NewServeMux()
muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
notifyCh := make(chan struct{})
conn := WrapCloseNotifyConn(c, func() {
close(notifyCh)
})
2019-10-12 20:13:12 +08:00
wl.acceptCh <- conn
2018-08-10 11:43:08 +08:00
<-notifyCh
2018-08-03 19:41:54 +08:00
}))
2018-08-10 11:43:08 +08:00
wl.server = &http.Server{
Addr: ln.Addr().String(),
Handler: muxer,
2018-08-03 19:41:54 +08:00
}
2018-08-10 11:43:08 +08:00
go wl.server.Serve(ln)
2018-08-03 19:41:54 +08:00
return
}
2018-08-10 11:43:08 +08:00
func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
2018-08-03 19:41:54 +08:00
if err != nil {
2018-08-10 11:43:08 +08:00
return nil, err
2018-08-03 19:41:54 +08:00
}
2018-08-10 11:43:08 +08:00
l := NewWebsocketListener(tcpLn)
return l, nil
2018-08-03 19:41:54 +08:00
}
2019-10-12 20:13:12 +08:00
func (p *WebsocketListener) Accept() (net.Conn, error) {
c, ok := <-p.acceptCh
2018-08-10 11:43:08 +08:00
if !ok {
return nil, ErrWebsocketListenerClosed
}
2018-08-03 19:41:54 +08:00
return c, nil
}
func (p *WebsocketListener) Close() error {
2018-08-10 11:43:08 +08:00
return p.server.Close()
2018-08-03 19:41:54 +08:00
}
2019-10-12 20:13:12 +08:00
// Addr reports the network address of the underlying TCP listener.
//
// NOTE(review): this panics if p.ln is nil, so the constructor must populate
// the ln field.
func (p *WebsocketListener) Addr() net.Addr {
	ln := p.ln
	return ln.Addr()
}
2018-08-10 11:43:08 +08:00
// addr: domain:port
2019-10-12 20:13:12 +08:00
func ConnectWebsocketServer(addr string) (net.Conn, error) {
2018-08-10 11:43:08 +08:00
addr = "ws://" + addr + FrpWebsocketPath
2018-08-03 19:41:54 +08:00
uri, err := url.Parse(addr)
if err != nil {
2018-08-10 11:43:08 +08:00
return nil, err
2018-08-03 19:41:54 +08:00
}
origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
2018-08-10 11:43:08 +08:00
return nil, err
2018-08-03 19:41:54 +08:00
}
cfg.Dialer = &net.Dialer{
2018-08-10 11:43:08 +08:00
Timeout: 10 * time.Second,
2018-08-03 19:41:54 +08:00
}
conn, err := websocket.DialConfig(cfg)
if err != nil {
2018-08-10 11:43:08 +08:00
return nil, err
2018-08-03 19:41:54 +08:00
}
2019-10-12 20:13:12 +08:00
return conn, nil
2018-08-03 19:41:54 +08:00
}