frp/pkg/util/net/websocket.go

69 lines
1.2 KiB
Go
Raw Normal View History

2018-08-03 19:41:54 +08:00
package net
import (
2018-08-10 11:43:08 +08:00
"errors"
2018-08-03 19:41:54 +08:00
"net"
"net/http"
"time"
2018-08-03 19:41:54 +08:00
"golang.org/x/net/websocket"
)
2022-08-29 01:02:53 +08:00
// ErrWebsocketListenerClosed is the error WebsocketListener.Accept reports
// once the listener can no longer deliver connections.
var ErrWebsocketListenerClosed = errors.New("websocket listener closed")

// FrpWebsocketPath is the HTTP path on which NewWebsocketListener serves
// websocket connections.
const FrpWebsocketPath = "/~!frp"
2018-08-03 19:41:54 +08:00
type WebsocketListener struct {
2019-10-12 20:13:12 +08:00
ln net.Listener
acceptCh chan net.Conn
2018-08-10 11:43:08 +08:00
2022-08-29 01:02:53 +08:00
server *http.Server
2018-08-03 19:41:54 +08:00
}
2019-01-31 16:54:46 +08:00
// NewWebsocketListener to handle websocket connections
2018-08-10 11:43:08 +08:00
// ln: tcp listener for websocket connections
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
wl = &WebsocketListener{
2019-10-12 20:13:12 +08:00
acceptCh: make(chan net.Conn),
2018-08-03 19:41:54 +08:00
}
2018-08-10 11:43:08 +08:00
muxer := http.NewServeMux()
muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
notifyCh := make(chan struct{})
conn := WrapCloseNotifyConn(c, func() {
close(notifyCh)
})
2019-10-12 20:13:12 +08:00
wl.acceptCh <- conn
2018-08-10 11:43:08 +08:00
<-notifyCh
2018-08-03 19:41:54 +08:00
}))
2018-08-10 11:43:08 +08:00
wl.server = &http.Server{
Addr: ln.Addr().String(),
Handler: muxer,
ReadHeaderTimeout: 60 * time.Second,
2018-08-03 19:41:54 +08:00
}
2018-08-10 11:43:08 +08:00
2022-08-29 01:02:53 +08:00
go func() {
_ = wl.server.Serve(ln)
}()
2018-08-03 19:41:54 +08:00
return
}
2019-10-12 20:13:12 +08:00
func (p *WebsocketListener) Accept() (net.Conn, error) {
c, ok := <-p.acceptCh
2018-08-10 11:43:08 +08:00
if !ok {
return nil, ErrWebsocketListenerClosed
}
2018-08-03 19:41:54 +08:00
return c, nil
}
func (p *WebsocketListener) Close() error {
2018-08-10 11:43:08 +08:00
return p.server.Close()
2018-08-03 19:41:54 +08:00
}
2019-10-12 20:13:12 +08:00
func (p *WebsocketListener) Addr() net.Addr {
return p.ln.Addr()
}