// Copyright 2018 fatedier, fatedier@gmail.com // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package mux import ( "fmt" "io" "net" "sort" "sync" "time" "github.com/fatedier/golib/errors" gnet "github.com/fatedier/golib/net" ) const ( // DefaultTimeout is the default length of time to wait for bytes we need. DefaultTimeout = 10 * time.Second ) type Mux struct { ln net.Listener defaultLn *listener // sorted by priority lns []*listener maxNeedBytesNum uint32 mu sync.RWMutex } func NewMux(ln net.Listener) (mux *Mux) { mux = &Mux{ ln: ln, lns: make([]*listener, 0), } return } // priority func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener { ln := &listener{ c: make(chan net.Conn), mux: mux, priority: priority, needBytesNum: needBytesNum, matchFn: fn, } mux.mu.Lock() defer mux.mu.Unlock() if needBytesNum > mux.maxNeedBytesNum { mux.maxNeedBytesNum = needBytesNum } newlns := append(mux.copyLns(), ln) sort.Slice(newlns, func(i, j int) bool { if newlns[i].priority == newlns[j].priority { return newlns[i].needBytesNum < newlns[j].needBytesNum } return newlns[i].priority < newlns[j].priority }) mux.lns = newlns return ln } func (mux *Mux) ListenHttp(priority int) net.Listener { return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc) } func (mux *Mux) ListenHttps(priority int) net.Listener { return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc) } func (mux *Mux) DefaultListener() net.Listener { mux.mu.Lock() defer mux.mu.Unlock() if mux.defaultLn == nil { mux.defaultLn = &listener{ c: make(chan net.Conn), mux: mux, } } return mux.defaultLn } func (mux *Mux) release(ln *listener) bool { result := false mux.mu.Lock() defer mux.mu.Unlock() lns := mux.copyLns() for i, l := range lns { if l == ln { lns = append(lns[:i], lns[i+1:]...) result = true break } } mux.lns = lns return result } func (mux *Mux) copyLns() []*listener { lns := make([]*listener, 0, len(mux.lns)) for _, l := range mux.lns { lns = append(lns, l) } return lns } // Serve handles connections from ln and multiplexes then across registered listeners. func (mux *Mux) Serve() error { for { // Wait for the next connection. // If it returns a temporary error then simply retry. // If it returns any other error then exit immediately. conn, err := mux.ln.Accept() if err, ok := err.(interface { Temporary() bool }); ok && err.Temporary() { continue } if err != nil { return err } go mux.handleConn(conn) } } func (mux *Mux) handleConn(conn net.Conn) { mux.mu.RLock() maxNeedBytesNum := mux.maxNeedBytesNum lns := mux.lns defaultLn := mux.defaultLn mux.mu.RUnlock() sharedConn, rd := gnet.NewSharedConnSize(conn, int(maxNeedBytesNum)) data := make([]byte, maxNeedBytesNum) conn.SetReadDeadline(time.Now().Add(DefaultTimeout)) _, err := io.ReadFull(rd, data) if err != nil { conn.Close() return } conn.SetReadDeadline(time.Time{}) for _, ln := range lns { if match := ln.matchFn(data); match { err = errors.PanicToError(func() { ln.c <- sharedConn }) if err != nil { conn.Close() } return } } // No match listeners if defaultLn != nil { err = errors.PanicToError(func() { defaultLn.c <- sharedConn }) if err != nil { conn.Close() } return } // No listeners for this connection, close it. conn.Close() return } type listener struct { mux *Mux priority int needBytesNum uint32 matchFn MatchFunc c chan net.Conn mu sync.RWMutex } // Accept waits for and returns the next connection to the listener. func (ln *listener) Accept() (net.Conn, error) { conn, ok := <-ln.c if !ok { return nil, fmt.Errorf("network connection closed") } return conn, nil } // Close removes this listener from the parent mux and closes the channel. func (ln *listener) Close() error { if ok := ln.mux.release(ln); ok { // Close done to signal to any RLock holders to release their lock. close(ln.c) } return nil } func (ln *listener) Addr() net.Addr { if ln.mux == nil { return nil } ln.mux.mu.RLock() defer ln.mux.mu.RUnlock() if ln.mux.ln == nil { return nil } return ln.mux.ln.Addr() }