diff --git a/client/control.go b/client/control.go index 9342238..ba37a2d 100644 --- a/client/control.go +++ b/client/control.go @@ -21,8 +21,6 @@ import ( "sync" "time" - "github.com/xtaci/smux" - "github.com/fatedier/frp/g" "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/msg" @@ -32,6 +30,8 @@ import ( "github.com/fatedier/frp/utils/shutdown" "github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/version" + + fmux "github.com/hashicorp/yamux" ) const ( @@ -51,7 +51,7 @@ type Control struct { conn frpNet.Conn // tcp stream multiplexing, if enabled - session *smux.Session + session *fmux.Session // put a message in this channel to send it over control connection to server sendCh chan (msg.Message) @@ -198,7 +198,7 @@ func (ctl *Control) login() (err error) { }() if g.GlbClientCfg.TcpMux { - session, errRet := smux.Client(conn, nil) + session, errRet := fmux.Client(conn, nil) if errRet != nil { return errRet } diff --git a/glide.lock b/glide.lock index 84313f0..48015ce 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 367ad1f2515b51db9d04d5620fd88843fb6faabf303fe3103b896ef7a3f5a126 -updated: 2018-04-23T02:33:52.913905+08:00 +hash: e2a62cbc49d9da8ff95682f5c0b7731a7047afdd139acddb691c51ea98f726e1 +updated: 2018-04-25T02:41:38.15698+08:00 imports: - name: github.com/armon/go-socks5 version: e75332964ef517daa070d7c38a9466a0d687e0a5 @@ -17,6 +17,8 @@ imports: version: 5979233c5d6225d4a8e438cdd0b411888449ddab - name: github.com/gorilla/websocket version: ea4d1f681babbce9545c9c5f3d5194a789c89f5b +- name: github.com/hashicorp/yamux + version: 2658be15c5f05e76244154714161f17e3e77de2e - name: github.com/inconshreveable/mousetrap version: 76626ae9c91c4f2a10f34cad8ce83ea42c93bb75 - name: github.com/julienschmidt/httprouter @@ -38,7 +40,7 @@ imports: - name: github.com/rodaine/table version: 212a2ad1c462ed4d5b5511ea2b480a573281dbbd - name: github.com/spf13/cobra - version: 615425954c3b0d9485a7027d4d451fdcdfdee84e + version: a1f051bc3eba734da4772d60e2d677f47cf93ef4 - name: github.com/spf13/pflag version: 583c0c0531f06d5278b7d917446061adc344b5cd - name: github.com/stretchr/testify @@ -57,8 +59,6 @@ imports: - sm4 - name: github.com/vaughan0/go-ini version: a98ad7ee00ec53921f08832bc06ecf7fd600e6a1 -- name: github.com/xtaci/smux - version: 2de5471dfcbc029f5fe1392b83fe784127c4943e - name: golang.org/x/crypto version: e1a4589e7d3ea14a3352255d04b6f1a418845e5e subpackages: diff --git a/glide.yaml b/glide.yaml index 2cc1535..2765c56 100644 --- a/glide.yaml +++ b/glide.yaml @@ -46,8 +46,6 @@ import: - sm4 - package: github.com/vaughan0/go-ini version: a98ad7ee00ec53921f08832bc06ecf7fd600e6a1 -- package: github.com/xtaci/smux - version: 2de5471dfcbc029f5fe1392b83fe784127c4943e - package: golang.org/x/crypto version: e1a4589e7d3ea14a3352255d04b6f1a418845e5e subpackages: @@ -71,3 +69,6 @@ import: version: v1.0.0 - package: github.com/gorilla/websocket version: v1.2.0 +- package: github.com/hashicorp/yamux +- package: github.com/spf13/cobra + version: v0.0.2 diff --git a/server/service.go b/server/service.go index 48f49e0..dc6f4ff 100644 --- a/server/service.go +++ b/server/service.go @@ -29,7 +29,7 @@ import ( "github.com/fatedier/frp/utils/version" "github.com/fatedier/frp/utils/vhost" - "github.com/xtaci/smux" + fmux "github.com/hashicorp/yamux" ) const ( @@ -234,7 +234,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) { } if g.GlbServerCfg.TcpMux { - session, err := smux.Server(frpConn, nil) + session, err := fmux.Server(frpConn, nil) if err != nil { log.Warn("Failed to create mux connection: %v", err) frpConn.Close() diff --git a/vendor/github.com/xtaci/smux/.gitignore b/vendor/github.com/hashicorp/yamux/.gitignore similarity index 97% rename from vendor/github.com/xtaci/smux/.gitignore rename to vendor/github.com/hashicorp/yamux/.gitignore index daf913b..8365624 100644 --- a/vendor/github.com/xtaci/smux/.gitignore +++ b/vendor/github.com/hashicorp/yamux/.gitignore @@ -21,4 +21,3 @@ _testmain.go *.exe *.test -*.prof diff --git a/vendor/github.com/hashicorp/yamux/LICENSE b/vendor/github.com/hashicorp/yamux/LICENSE new file mode 100644 index 0000000..f0e5c79 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/LICENSE @@ -0,0 +1,362 @@ +Mozilla Public License, version 2.0 + +1. Definitions + +1.1. "Contributor" + + means each individual or legal entity that creates, contributes to the + creation of, or owns Covered Software. + +1.2. "Contributor Version" + + means the combination of the Contributions of others (if any) used by a + Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + + means Source Code Form to which the initial Contributor has attached the + notice in Exhibit A, the Executable Form of such Source Code Form, and + Modifications of such Source Code Form, in each case including portions + thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + a. that the initial Contributor has attached the notice described in + Exhibit B to the Covered Software; or + + b. that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the terms of + a Secondary License. + +1.6. "Executable Form" + + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + + means a work that combines Covered Software with other material, in a + separate file or files, that is not Covered Software. + +1.8. "License" + + means this document. + +1.9. "Licensable" + + means having the right to grant, to the maximum extent possible, whether + at the time of the initial grant or subsequently, any and all of the + rights conveyed by this License. + +1.10. "Modifications" + + means any of the following: + + a. any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered Software; or + + b. any new file in Source Code Form that contains any Covered Software. + +1.11. "Patent Claims" of a Contributor + + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the License, + by the making, using, selling, offering for sale, having made, import, + or transfer of either its Contributions or its Contributor Version. + +1.12. "Secondary License" + + means either the GNU General Public License, Version 2.0, the GNU Lesser + General Public License, Version 2.1, the GNU Affero General Public + License, Version 3.0, or any later versions of those licenses. + +1.13. "Source Code Form" + + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that controls, is + controlled by, or is under common control with You. For purposes of this + definition, "control" means (a) the power, direct or indirect, to cause + the direction or management of such entity, whether by contract or + otherwise, or (b) ownership of more than fifty percent (50%) of the + outstanding shares or beneficial ownership of such entity. + + +2. License Grants and Conditions + +2.1. Grants + + Each Contributor hereby grants You a world-wide, royalty-free, + non-exclusive license: + + a. under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + + b. under Patent Claims of such Contributor to make, use, sell, offer for + sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + + The licenses granted in Section 2.1 with respect to any Contribution + become effective for each Contribution on the date the Contributor first + distributes such Contribution. + +2.3. Limitations on Grant Scope + + The licenses granted in this Section 2 are the only rights granted under + this License. No additional rights or licenses will be implied from the + distribution or licensing of Covered Software under this License. + Notwithstanding Section 2.1(b) above, no patent license is granted by a + Contributor: + + a. for any code that a Contributor has removed from Covered Software; or + + b. for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + + c. under Patent Claims infringed by Covered Software in the absence of + its Contributions. + + This License does not grant any rights in the trademarks, service marks, + or logos of any Contributor (except as may be necessary to comply with + the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + + No Contributor makes additional grants as a result of Your choice to + distribute the Covered Software under a subsequent version of this + License (see Section 10.2) or under the terms of a Secondary License (if + permitted under the terms of Section 3.3). + +2.5. Representation + + Each Contributor represents that the Contributor believes its + Contributions are its original creation(s) or it has sufficient rights to + grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + + This License is not intended to limit any rights You have under + applicable copyright doctrines of fair use, fair dealing, or other + equivalents. + +2.7. Conditions + + Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in + Section 2.1. + + +3. Responsibilities + +3.1. Distribution of Source Form + + All distribution of Covered Software in Source Code Form, including any + Modifications that You create or to which You contribute, must be under + the terms of this License. You must inform recipients that the Source + Code Form of the Covered Software is governed by the terms of this + License, and how they can obtain a copy of this License. You may not + attempt to alter or restrict the recipients' rights in the Source Code + Form. + +3.2. Distribution of Executable Form + + If You distribute Covered Software in Executable Form then: + + a. such Covered Software must also be made available in Source Code Form, + as described in Section 3.1, and You must inform recipients of the + Executable Form how they can obtain a copy of such Source Code Form by + reasonable means in a timely manner, at a charge no more than the cost + of distribution to the recipient; and + + b. You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter the + recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + + You may create and distribute a Larger Work under terms of Your choice, + provided that You also comply with the requirements of this License for + the Covered Software. If the Larger Work is a combination of Covered + Software with a work governed by one or more Secondary Licenses, and the + Covered Software is not Incompatible With Secondary Licenses, this + License permits You to additionally distribute such Covered Software + under the terms of such Secondary License(s), so that the recipient of + the Larger Work may, at their option, further distribute the Covered + Software under the terms of either this License or such Secondary + License(s). + +3.4. Notices + + You may not remove or alter the substance of any license notices + (including copyright notices, patent notices, disclaimers of warranty, or + limitations of liability) contained within the Source Code Form of the + Covered Software, except that You may alter any license notices to the + extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + + You may choose to offer, and to charge a fee for, warranty, support, + indemnity or liability obligations to one or more recipients of Covered + Software. However, You may do so only on Your own behalf, and not on + behalf of any Contributor. You must make it absolutely clear that any + such warranty, support, indemnity, or liability obligation is offered by + You alone, and You hereby agree to indemnify every Contributor for any + liability incurred by such Contributor as a result of warranty, support, + indemnity or liability terms You offer. You may include additional + disclaimers of warranty and limitations of liability specific to any + jurisdiction. + +4. Inability to Comply Due to Statute or Regulation + + If it is impossible for You to comply with any of the terms of this License + with respect to some or all of the Covered Software due to statute, + judicial order, or regulation then You must: (a) comply with the terms of + this License to the maximum extent possible; and (b) describe the + limitations and the code they affect. Such description must be placed in a + text file included with all distributions of the Covered Software under + this License. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it. + +5. Termination + +5.1. The rights granted under this License will terminate automatically if You + fail to comply with any of its terms. However, if You become compliant, + then the rights granted under this License from a particular Contributor + are reinstated (a) provisionally, unless and until such Contributor + explicitly and finally terminates Your grants, and (b) on an ongoing + basis, if such Contributor fails to notify You of the non-compliance by + some reasonable means prior to 60 days after You have come back into + compliance. Moreover, Your grants from a particular Contributor are + reinstated on an ongoing basis if such Contributor notifies You of the + non-compliance by some reasonable means, this is the first time You have + received notice of non-compliance with this License from such + Contributor, and You become compliant prior to 30 days after Your receipt + of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent + infringement claim (excluding declaratory judgment actions, + counter-claims, and cross-claims) alleging that a Contributor Version + directly or indirectly infringes any patent, then the rights granted to + You by any and all Contributors for the Covered Software under Section + 2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user + license agreements (excluding distributors and resellers) which have been + validly granted by You or Your distributors under this License prior to + termination shall survive termination. + +6. Disclaimer of Warranty + + Covered Software is provided under this License on an "as is" basis, + without warranty of any kind, either expressed, implied, or statutory, + including, without limitation, warranties that the Covered Software is free + of defects, merchantable, fit for a particular purpose or non-infringing. + The entire risk as to the quality and performance of the Covered Software + is with You. Should any Covered Software prove defective in any respect, + You (not any Contributor) assume the cost of any necessary servicing, + repair, or correction. This disclaimer of warranty constitutes an essential + part of this License. No use of any Covered Software is authorized under + this License except under this disclaimer. + +7. Limitation of Liability + + Under no circumstances and under no legal theory, whether tort (including + negligence), contract, or otherwise, shall any Contributor, or anyone who + distributes Covered Software as permitted above, be liable to You for any + direct, indirect, special, incidental, or consequential damages of any + character including, without limitation, damages for lost profits, loss of + goodwill, work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses, even if such party shall have been + informed of the possibility of such damages. This limitation of liability + shall not apply to liability for death or personal injury resulting from + such party's negligence to the extent applicable law prohibits such + limitation. Some jurisdictions do not allow the exclusion or limitation of + incidental or consequential damages, so this exclusion and limitation may + not apply to You. + +8. Litigation + + Any litigation relating to this License may be brought only in the courts + of a jurisdiction where the defendant maintains its principal place of + business and such litigation shall be governed by laws of that + jurisdiction, without reference to its conflict-of-law provisions. Nothing + in this Section shall prevent a party's ability to bring cross-claims or + counter-claims. + +9. Miscellaneous + + This License represents the complete agreement concerning the subject + matter hereof. If any provision of this License is held to be + unenforceable, such provision shall be reformed only to the extent + necessary to make it enforceable. Any law or regulation which provides that + the language of a contract shall be construed against the drafter shall not + be used to construe this License against a Contributor. + + +10. Versions of the License + +10.1. New Versions + + Mozilla Foundation is the license steward. Except as provided in Section + 10.3, no one other than the license steward has the right to modify or + publish new versions of this License. Each version will be given a + distinguishing version number. + +10.2. Effect of New Versions + + You may distribute the Covered Software under the terms of the version + of the License under which You originally received the Covered Software, + or under the terms of any subsequent version published by the license + steward. + +10.3. Modified Versions + + If you create software not governed by this License, and you want to + create a new license for such software, you may create and use a + modified version of this License if you rename the license and remove + any references to the name of the license steward (except to note that + such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary + Licenses If You choose to distribute Source Code Form that is + Incompatible With Secondary Licenses under the terms of this version of + the License, the notice described in Exhibit B of this License must be + attached. + +Exhibit A - Source Code Form License Notice + + This Source Code Form is subject to the + terms of the Mozilla Public License, v. + 2.0. If a copy of the MPL was not + distributed with this file, You can + obtain one at + http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular file, +then You may include the notice in a location (such as a LICENSE file in a +relevant directory) where a recipient would be likely to look for such a +notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice + + This Source Code Form is "Incompatible + With Secondary Licenses", as defined by + the Mozilla Public License, v. 2.0. \ No newline at end of file diff --git a/vendor/github.com/hashicorp/yamux/README.md b/vendor/github.com/hashicorp/yamux/README.md new file mode 100644 index 0000000..d4db7fc --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/README.md @@ -0,0 +1,86 @@ +# Yamux + +Yamux (Yet another Multiplexer) is a multiplexing library for Golang. +It relies on an underlying connection to provide reliability +and ordering, such as TCP or Unix domain sockets, and provides +stream-oriented multiplexing. It is inspired by SPDY but is not +interoperable with it. + +Yamux features include: + +* Bi-directional streams + * Streams can be opened by either client or server + * Useful for NAT traversal + * Server-side push support +* Flow control + * Avoid starvation + * Back-pressure to prevent overwhelming a receiver +* Keep Alives + * Enables persistent connections over a load balancer +* Efficient + * Enables thousands of logical streams with low overhead + +## Documentation + +For complete documentation, see the associated [Godoc](http://godoc.org/github.com/hashicorp/yamux). + +## Specification + +The full specification for Yamux is provided in the `spec.md` file. +It can be used as a guide to implementors of interoperable libraries. + +## Usage + +Using Yamux is remarkably simple: + +```go + +func client() { + // Get a TCP connection + conn, err := net.Dial(...) + if err != nil { + panic(err) + } + + // Setup client side of yamux + session, err := yamux.Client(conn, nil) + if err != nil { + panic(err) + } + + // Open a new stream + stream, err := session.Open() + if err != nil { + panic(err) + } + + // Stream implements net.Conn + stream.Write([]byte("ping")) +} + +func server() { + // Accept a TCP connection + conn, err := listener.Accept() + if err != nil { + panic(err) + } + + // Setup server side of yamux + session, err := yamux.Server(conn, nil) + if err != nil { + panic(err) + } + + // Accept a stream + stream, err := session.Accept() + if err != nil { + panic(err) + } + + // Listen for a message + buf := make([]byte, 4) + stream.Read(buf) +} + +``` + diff --git a/vendor/github.com/hashicorp/yamux/addr.go b/vendor/github.com/hashicorp/yamux/addr.go new file mode 100644 index 0000000..be6ebca --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/addr.go @@ -0,0 +1,60 @@ +package yamux + +import ( + "fmt" + "net" +) + +// hasAddr is used to get the address from the underlying connection +type hasAddr interface { + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +// yamuxAddr is used when we cannot get the underlying address +type yamuxAddr struct { + Addr string +} + +func (*yamuxAddr) Network() string { + return "yamux" +} + +func (y *yamuxAddr) String() string { + return fmt.Sprintf("yamux:%s", y.Addr) +} + +// Addr is used to get the address of the listener. +func (s *Session) Addr() net.Addr { + return s.LocalAddr() +} + +// LocalAddr is used to get the local address of the +// underlying connection. +func (s *Session) LocalAddr() net.Addr { + addr, ok := s.conn.(hasAddr) + if !ok { + return &yamuxAddr{"local"} + } + return addr.LocalAddr() +} + +// RemoteAddr is used to get the address of remote end +// of the underlying connection +func (s *Session) RemoteAddr() net.Addr { + addr, ok := s.conn.(hasAddr) + if !ok { + return &yamuxAddr{"remote"} + } + return addr.RemoteAddr() +} + +// LocalAddr returns the local address +func (s *Stream) LocalAddr() net.Addr { + return s.session.LocalAddr() +} + +// LocalAddr returns the remote address +func (s *Stream) RemoteAddr() net.Addr { + return s.session.RemoteAddr() +} diff --git a/vendor/github.com/hashicorp/yamux/bench_test.go b/vendor/github.com/hashicorp/yamux/bench_test.go new file mode 100644 index 0000000..5fc1c55 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/bench_test.go @@ -0,0 +1,123 @@ +package yamux + +import ( + "testing" +) + +func BenchmarkPing(b *testing.B) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + for i := 0; i < b.N; i++ { + rtt, err := client.Ping() + if err != nil { + b.Fatalf("err: %v", err) + } + if rtt == 0 { + b.Fatalf("bad: %v", rtt) + } + } +} + +func BenchmarkAccept(b *testing.B) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + go func() { + for i := 0; i < b.N; i++ { + stream, err := server.AcceptStream() + if err != nil { + return + } + stream.Close() + } + }() + + for i := 0; i < b.N; i++ { + stream, err := client.Open() + if err != nil { + b.Fatalf("err: %v", err) + } + stream.Close() + } +} + +func BenchmarkSendRecv(b *testing.B) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + sendBuf := make([]byte, 512) + recvBuf := make([]byte, 512) + + doneCh := make(chan struct{}) + go func() { + stream, err := server.AcceptStream() + if err != nil { + return + } + defer stream.Close() + for i := 0; i < b.N; i++ { + if _, err := stream.Read(recvBuf); err != nil { + b.Fatalf("err: %v", err) + } + } + close(doneCh) + }() + + stream, err := client.Open() + if err != nil { + b.Fatalf("err: %v", err) + } + defer stream.Close() + for i := 0; i < b.N; i++ { + if _, err := stream.Write(sendBuf); err != nil { + b.Fatalf("err: %v", err) + } + } + <-doneCh +} + +func BenchmarkSendRecvLarge(b *testing.B) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + const sendSize = 512 * 1024 * 1024 + const recvSize = 4 * 1024 + + sendBuf := make([]byte, sendSize) + recvBuf := make([]byte, recvSize) + + b.ResetTimer() + recvDone := make(chan struct{}) + + go func() { + stream, err := server.AcceptStream() + if err != nil { + return + } + defer stream.Close() + for i := 0; i < b.N; i++ { + for j := 0; j < sendSize/recvSize; j++ { + if _, err := stream.Read(recvBuf); err != nil { + b.Fatalf("err: %v", err) + } + } + } + close(recvDone) + }() + + stream, err := client.Open() + if err != nil { + b.Fatalf("err: %v", err) + } + defer stream.Close() + for i := 0; i < b.N; i++ { + if _, err := stream.Write(sendBuf); err != nil { + b.Fatalf("err: %v", err) + } + } + <-recvDone +} diff --git a/vendor/github.com/hashicorp/yamux/const.go b/vendor/github.com/hashicorp/yamux/const.go new file mode 100644 index 0000000..4f52938 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/const.go @@ -0,0 +1,157 @@ +package yamux + +import ( + "encoding/binary" + "fmt" +) + +var ( + // ErrInvalidVersion means we received a frame with an + // invalid version + ErrInvalidVersion = fmt.Errorf("invalid protocol version") + + // ErrInvalidMsgType means we received a frame with an + // invalid message type + ErrInvalidMsgType = fmt.Errorf("invalid msg type") + + // ErrSessionShutdown is used if there is a shutdown during + // an operation + ErrSessionShutdown = fmt.Errorf("session shutdown") + + // ErrStreamsExhausted is returned if we have no more + // stream ids to issue + ErrStreamsExhausted = fmt.Errorf("streams exhausted") + + // ErrDuplicateStream is used if a duplicate stream is + // opened inbound + ErrDuplicateStream = fmt.Errorf("duplicate stream initiated") + + // ErrReceiveWindowExceeded indicates the window was exceeded + ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded") + + // ErrTimeout is used when we reach an IO deadline + ErrTimeout = fmt.Errorf("i/o deadline reached") + + // ErrStreamClosed is returned when using a closed stream + ErrStreamClosed = fmt.Errorf("stream closed") + + // ErrUnexpectedFlag is set when we get an unexpected flag + ErrUnexpectedFlag = fmt.Errorf("unexpected flag") + + // ErrRemoteGoAway is used when we get a go away from the other side + ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections") + + // ErrConnectionReset is sent if a stream is reset. This can happen + // if the backlog is exceeded, or if there was a remote GoAway. + ErrConnectionReset = fmt.Errorf("connection reset") + + // ErrConnectionWriteTimeout indicates that we hit the "safety valve" + // timeout writing to the underlying stream connection. + ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout") + + // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close + ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout") +) + +const ( + // protoVersion is the only version we support + protoVersion uint8 = 0 +) + +const ( + // Data is used for data frames. They are followed + // by length bytes worth of payload. + typeData uint8 = iota + + // WindowUpdate is used to change the window of + // a given stream. The length indicates the delta + // update to the window. + typeWindowUpdate + + // Ping is sent as a keep-alive or to measure + // the RTT. The StreamID and Length value are echoed + // back in the response. + typePing + + // GoAway is sent to terminate a session. The StreamID + // should be 0 and the length is an error code. + typeGoAway +) + +const ( + // SYN is sent to signal a new stream. May + // be sent with a data payload + flagSYN uint16 = 1 << iota + + // ACK is sent to acknowledge a new stream. May + // be sent with a data payload + flagACK + + // FIN is sent to half-close the given stream. + // May be sent with a data payload. + flagFIN + + // RST is used to hard close a given stream. + flagRST +) + +const ( + // initialStreamWindow is the initial stream window size + initialStreamWindow uint32 = 256 * 1024 +) + +const ( + // goAwayNormal is sent on a normal termination + goAwayNormal uint32 = iota + + // goAwayProtoErr sent on a protocol error + goAwayProtoErr + + // goAwayInternalErr sent on an internal error + goAwayInternalErr +) + +const ( + sizeOfVersion = 1 + sizeOfType = 1 + sizeOfFlags = 2 + sizeOfStreamID = 4 + sizeOfLength = 4 + headerSize = sizeOfVersion + sizeOfType + sizeOfFlags + + sizeOfStreamID + sizeOfLength +) + +type header []byte + +func (h header) Version() uint8 { + return h[0] +} + +func (h header) MsgType() uint8 { + return h[1] +} + +func (h header) Flags() uint16 { + return binary.BigEndian.Uint16(h[2:4]) +} + +func (h header) StreamID() uint32 { + return binary.BigEndian.Uint32(h[4:8]) +} + +func (h header) Length() uint32 { + return binary.BigEndian.Uint32(h[8:12]) +} + +func (h header) String() string { + return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d", + h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length()) +} + +func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) { + h[0] = protoVersion + h[1] = msgType + binary.BigEndian.PutUint16(h[2:4], flags) + binary.BigEndian.PutUint32(h[4:8], streamID) + binary.BigEndian.PutUint32(h[8:12], length) +} diff --git a/vendor/github.com/hashicorp/yamux/const_test.go b/vendor/github.com/hashicorp/yamux/const_test.go new file mode 100644 index 0000000..153da18 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/const_test.go @@ -0,0 +1,72 @@ +package yamux + +import ( + "testing" +) + +func TestConst(t *testing.T) { + if protoVersion != 0 { + t.Fatalf("bad: %v", protoVersion) + } + + if typeData != 0 { + t.Fatalf("bad: %v", typeData) + } + if typeWindowUpdate != 1 { + t.Fatalf("bad: %v", typeWindowUpdate) + } + if typePing != 2 { + t.Fatalf("bad: %v", typePing) + } + if typeGoAway != 3 { + t.Fatalf("bad: %v", typeGoAway) + } + + if flagSYN != 1 { + t.Fatalf("bad: %v", flagSYN) + } + if flagACK != 2 { + t.Fatalf("bad: %v", flagACK) + } + if flagFIN != 4 { + t.Fatalf("bad: %v", flagFIN) + } + if flagRST != 8 { + t.Fatalf("bad: %v", flagRST) + } + + if goAwayNormal != 0 { + t.Fatalf("bad: %v", goAwayNormal) + } + if goAwayProtoErr != 1 { + t.Fatalf("bad: %v", goAwayProtoErr) + } + if goAwayInternalErr != 2 { + t.Fatalf("bad: %v", goAwayInternalErr) + } + + if headerSize != 12 { + t.Fatalf("bad header size") + } +} + +func TestEncodeDecode(t *testing.T) { + hdr := header(make([]byte, headerSize)) + hdr.encode(typeWindowUpdate, flagACK|flagRST, 1234, 4321) + + if hdr.Version() != protoVersion { + t.Fatalf("bad: %v", hdr) + } + if hdr.MsgType() != typeWindowUpdate { + t.Fatalf("bad: %v", hdr) + } + if hdr.Flags() != flagACK|flagRST { + t.Fatalf("bad: %v", hdr) + } + if hdr.StreamID() != 1234 { + t.Fatalf("bad: %v", hdr) + } + if hdr.Length() != 4321 { + t.Fatalf("bad: %v", hdr) + } +} diff --git a/vendor/github.com/hashicorp/yamux/mux.go b/vendor/github.com/hashicorp/yamux/mux.go new file mode 100644 index 0000000..7abc7c7 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/mux.go @@ -0,0 +1,87 @@ +package yamux + +import ( + "fmt" + "io" + "os" + "time" +) + +// Config is used to tune the Yamux session +type Config struct { + // AcceptBacklog is used to limit how many streams may be + // waiting an accept. + AcceptBacklog int + + // EnableKeepalive is used to do a period keep alive + // messages using a ping. + EnableKeepAlive bool + + // KeepAliveInterval is how often to perform the keep alive + KeepAliveInterval time.Duration + + // ConnectionWriteTimeout is meant to be a "safety valve" timeout after + // we which will suspect a problem with the underlying connection and + // close it. This is only applied to writes, where's there's generally + // an expectation that things will move along quickly. + ConnectionWriteTimeout time.Duration + + // MaxStreamWindowSize is used to control the maximum + // window size that we allow for a stream. + MaxStreamWindowSize uint32 + + // LogOutput is used to control the log destination + LogOutput io.Writer +} + +// DefaultConfig is used to return a default configuration +func DefaultConfig() *Config { + return &Config{ + AcceptBacklog: 256, + EnableKeepAlive: true, + KeepAliveInterval: 30 * time.Second, + ConnectionWriteTimeout: 10 * time.Second, + MaxStreamWindowSize: initialStreamWindow, + LogOutput: os.Stderr, + } +} + +// VerifyConfig is used to verify the sanity of configuration +func VerifyConfig(config *Config) error { + if config.AcceptBacklog <= 0 { + return fmt.Errorf("backlog must be positive") + } + if config.KeepAliveInterval == 0 { + return fmt.Errorf("keep-alive interval must be positive") + } + if config.MaxStreamWindowSize < initialStreamWindow { + return fmt.Errorf("MaxStreamWindowSize must be larger than %d", initialStreamWindow) + } + return nil +} + +// Server is used to initialize a new server-side connection. +// There must be at most one server-side connection. If a nil config is +// provided, the DefaultConfiguration will be used. +func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) { + if config == nil { + config = DefaultConfig() + } + if err := VerifyConfig(config); err != nil { + return nil, err + } + return newSession(config, conn, false), nil +} + +// Client is used to initialize a new client-side connection. +// There must be at most one client-side connection. +func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) { + if config == nil { + config = DefaultConfig() + } + + if err := VerifyConfig(config); err != nil { + return nil, err + } + return newSession(config, conn, true), nil +} diff --git a/vendor/github.com/hashicorp/yamux/session.go b/vendor/github.com/hashicorp/yamux/session.go new file mode 100644 index 0000000..d8446fa --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/session.go @@ -0,0 +1,646 @@ +package yamux + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Session is used to wrap a reliable ordered connection and to +// multiplex it into multiple streams. +type Session struct { + // remoteGoAway indicates the remote side does + // not want futher connections. Must be first for alignment. + remoteGoAway int32 + + // localGoAway indicates that we should stop + // accepting futher connections. Must be first for alignment. + localGoAway int32 + + // nextStreamID is the next stream we should + // send. This depends if we are a client/server. + nextStreamID uint32 + + // config holds our configuration + config *Config + + // logger is used for our logs + logger *log.Logger + + // conn is the underlying connection + conn io.ReadWriteCloser + + // bufRead is a buffered reader + bufRead *bufio.Reader + + // pings is used to track inflight pings + pings map[uint32]chan struct{} + pingID uint32 + pingLock sync.Mutex + + // streams maps a stream id to a stream, and inflight has an entry + // for any outgoing stream that has not yet been established. Both are + // protected by streamLock. + streams map[uint32]*Stream + inflight map[uint32]struct{} + streamLock sync.Mutex + + // synCh acts like a semaphore. It is sized to the AcceptBacklog which + // is assumed to be symmetric between the client and server. This allows + // the client to avoid exceeding the backlog and instead blocks the open. + synCh chan struct{} + + // acceptCh is used to pass ready streams to the client + acceptCh chan *Stream + + // sendCh is used to mark a stream as ready to send, + // or to send a header out directly. + sendCh chan sendReady + + // recvDoneCh is closed when recv() exits to avoid a race + // between stream registration and stream shutdown + recvDoneCh chan struct{} + + // shutdown is used to safely close a session + shutdown bool + shutdownErr error + shutdownCh chan struct{} + shutdownLock sync.Mutex +} + +// sendReady is used to either mark a stream as ready +// or to directly send a header +type sendReady struct { + Hdr []byte + Body io.Reader + Err chan error +} + +// newSession is used to construct a new session +func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { + s := &Session{ + config: config, + logger: log.New(config.LogOutput, "", log.LstdFlags), + conn: conn, + bufRead: bufio.NewReader(conn), + pings: make(map[uint32]chan struct{}), + streams: make(map[uint32]*Stream), + inflight: make(map[uint32]struct{}), + synCh: make(chan struct{}, config.AcceptBacklog), + acceptCh: make(chan *Stream, config.AcceptBacklog), + sendCh: make(chan sendReady, 64), + recvDoneCh: make(chan struct{}), + shutdownCh: make(chan struct{}), + } + if client { + s.nextStreamID = 1 + } else { + s.nextStreamID = 2 + } + go s.recv() + go s.send() + if config.EnableKeepAlive { + go s.keepalive() + } + return s +} + +// IsClosed does a safe check to see if we have shutdown +func (s *Session) IsClosed() bool { + select { + case <-s.shutdownCh: + return true + default: + return false + } +} + +// CloseChan returns a read-only channel which is closed as +// soon as the session is closed. +func (s *Session) CloseChan() <-chan struct{} { + return s.shutdownCh +} + +// NumStreams returns the number of currently open streams +func (s *Session) NumStreams() int { + s.streamLock.Lock() + num := len(s.streams) + s.streamLock.Unlock() + return num +} + +// Open is used to create a new stream as a net.Conn +func (s *Session) Open() (net.Conn, error) { + conn, err := s.OpenStream() + if err != nil { + return nil, err + } + return conn, nil +} + +// OpenStream is used to create a new stream +func (s *Session) OpenStream() (*Stream, error) { + if s.IsClosed() { + return nil, ErrSessionShutdown + } + if atomic.LoadInt32(&s.remoteGoAway) == 1 { + return nil, ErrRemoteGoAway + } + + // Block if we have too many inflight SYNs + select { + case s.synCh <- struct{}{}: + case <-s.shutdownCh: + return nil, ErrSessionShutdown + } + +GET_ID: + // Get an ID, and check for stream exhaustion + id := atomic.LoadUint32(&s.nextStreamID) + if id >= math.MaxUint32-1 { + return nil, ErrStreamsExhausted + } + if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { + goto GET_ID + } + + // Register the stream + stream := newStream(s, id, streamInit) + s.streamLock.Lock() + s.streams[id] = stream + s.inflight[id] = struct{}{} + s.streamLock.Unlock() + + // Send the window update to create + if err := stream.sendWindowUpdate(); err != nil { + select { + case <-s.synCh: + default: + s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") + } + return nil, err + } + return stream, nil +} + +// Accept is used to block until the next available stream +// is ready to be accepted. +func (s *Session) Accept() (net.Conn, error) { + conn, err := s.AcceptStream() + if err != nil { + return nil, err + } + return conn, err +} + +// AcceptStream is used to block until the next available stream +// is ready to be accepted. +func (s *Session) AcceptStream() (*Stream, error) { + select { + case stream := <-s.acceptCh: + if err := stream.sendWindowUpdate(); err != nil { + return nil, err + } + return stream, nil + case <-s.shutdownCh: + return nil, s.shutdownErr + } +} + +// Close is used to close the session and all streams. +// Attempts to send a GoAway before closing the connection. +func (s *Session) Close() error { + s.shutdownLock.Lock() + defer s.shutdownLock.Unlock() + + if s.shutdown { + return nil + } + s.shutdown = true + if s.shutdownErr == nil { + s.shutdownErr = ErrSessionShutdown + } + close(s.shutdownCh) + s.conn.Close() + <-s.recvDoneCh + + s.streamLock.Lock() + defer s.streamLock.Unlock() + for _, stream := range s.streams { + stream.forceClose() + } + return nil +} + +// exitErr is used to handle an error that is causing the +// session to terminate. +func (s *Session) exitErr(err error) { + s.shutdownLock.Lock() + if s.shutdownErr == nil { + s.shutdownErr = err + } + s.shutdownLock.Unlock() + s.Close() +} + +// GoAway can be used to prevent accepting further +// connections. It does not close the underlying conn. +func (s *Session) GoAway() error { + return s.waitForSend(s.goAway(goAwayNormal), nil) +} + +// goAway is used to send a goAway message +func (s *Session) goAway(reason uint32) header { + atomic.SwapInt32(&s.localGoAway, 1) + hdr := header(make([]byte, headerSize)) + hdr.encode(typeGoAway, 0, 0, reason) + return hdr +} + +// Ping is used to measure the RTT response time +func (s *Session) Ping() (time.Duration, error) { + // Get a channel for the ping + ch := make(chan struct{}) + + // Get a new ping id, mark as pending + s.pingLock.Lock() + id := s.pingID + s.pingID++ + s.pings[id] = ch + s.pingLock.Unlock() + + // Send the ping request + hdr := header(make([]byte, headerSize)) + hdr.encode(typePing, flagSYN, 0, id) + if err := s.waitForSend(hdr, nil); err != nil { + return 0, err + } + + // Wait for a response + start := time.Now() + select { + case <-ch: + case <-time.After(s.config.ConnectionWriteTimeout): + s.pingLock.Lock() + delete(s.pings, id) // Ignore it if a response comes later. + s.pingLock.Unlock() + return 0, ErrTimeout + case <-s.shutdownCh: + return 0, ErrSessionShutdown + } + + // Compute the RTT + return time.Now().Sub(start), nil +} + +// keepalive is a long running goroutine that periodically does +// a ping to keep the connection alive. +func (s *Session) keepalive() { + for { + select { + case <-time.After(s.config.KeepAliveInterval): + _, err := s.Ping() + if err != nil { + s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) + s.exitErr(ErrKeepAliveTimeout) + return + } + case <-s.shutdownCh: + return + } + } +} + +// waitForSendErr waits to send a header, checking for a potential shutdown +func (s *Session) waitForSend(hdr header, body io.Reader) error { + errCh := make(chan error, 1) + return s.waitForSendErr(hdr, body, errCh) +} + +// waitForSendErr waits to send a header with optional data, checking for a +// potential shutdown. Since there's the expectation that sends can happen +// in a timely manner, we enforce the connection write timeout here. +func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + timer.Stop() + select { + case <-timer.C: + default: + } + timerPool.Put(t) + }() + + ready := sendReady{Hdr: hdr, Body: body, Err: errCh} + select { + case s.sendCh <- ready: + case <-s.shutdownCh: + return ErrSessionShutdown + case <-timer.C: + return ErrConnectionWriteTimeout + } + + select { + case err := <-errCh: + return err + case <-s.shutdownCh: + return ErrSessionShutdown + case <-timer.C: + return ErrConnectionWriteTimeout + } +} + +// sendNoWait does a send without waiting. Since there's the expectation that +// the send happens right here, we enforce the connection write timeout if we +// can't queue the header to be sent. +func (s *Session) sendNoWait(hdr header) error { + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + timer.Stop() + select { + case <-timer.C: + default: + } + timerPool.Put(t) + }() + + select { + case s.sendCh <- sendReady{Hdr: hdr}: + return nil + case <-s.shutdownCh: + return ErrSessionShutdown + case <-timer.C: + return ErrConnectionWriteTimeout + } +} + +// send is a long running goroutine that sends data +func (s *Session) send() { + for { + select { + case ready := <-s.sendCh: + // Send a header if ready + if ready.Hdr != nil { + sent := 0 + for sent < len(ready.Hdr) { + n, err := s.conn.Write(ready.Hdr[sent:]) + if err != nil { + s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) + asyncSendErr(ready.Err, err) + s.exitErr(err) + return + } + sent += n + } + } + + // Send data from a body if given + if ready.Body != nil { + _, err := io.Copy(s.conn, ready.Body) + if err != nil { + s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) + asyncSendErr(ready.Err, err) + s.exitErr(err) + return + } + } + + // No error, successful send + asyncSendErr(ready.Err, nil) + case <-s.shutdownCh: + return + } + } +} + +// recv is a long running goroutine that accepts new data +func (s *Session) recv() { + if err := s.recvLoop(); err != nil { + s.exitErr(err) + } +} + +// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type +var ( + handlers = []func(*Session, header) error{ + typeData: (*Session).handleStreamMessage, + typeWindowUpdate: (*Session).handleStreamMessage, + typePing: (*Session).handlePing, + typeGoAway: (*Session).handleGoAway, + } +) + +// recvLoop continues to receive data until a fatal error is encountered +func (s *Session) recvLoop() error { + defer close(s.recvDoneCh) + hdr := header(make([]byte, headerSize)) + for { + // Read the header + if _, err := io.ReadFull(s.bufRead, hdr); err != nil { + if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { + s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) + } + return err + } + + // Verify the version + if hdr.Version() != protoVersion { + s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version()) + return ErrInvalidVersion + } + + mt := hdr.MsgType() + if mt < typeData || mt > typeGoAway { + return ErrInvalidMsgType + } + + if err := handlers[mt](s, hdr); err != nil { + return err + } + } +} + +// handleStreamMessage handles either a data or window update frame +func (s *Session) handleStreamMessage(hdr header) error { + // Check for a new stream creation + id := hdr.StreamID() + flags := hdr.Flags() + if flags&flagSYN == flagSYN { + if err := s.incomingStream(id); err != nil { + return err + } + } + + // Get the stream + s.streamLock.Lock() + stream := s.streams[id] + s.streamLock.Unlock() + + // If we do not have a stream, likely we sent a RST + if stream == nil { + // Drain any data on the wire + if hdr.MsgType() == typeData && hdr.Length() > 0 { + s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) + if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil { + s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) + return nil + } + } else { + s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) + } + return nil + } + + // Check if this is a window update + if hdr.MsgType() == typeWindowUpdate { + if err := stream.incrSendWindow(hdr, flags); err != nil { + if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { + s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) + } + return err + } + return nil + } + + // Read the new data + if err := stream.readData(hdr, flags, s.bufRead); err != nil { + if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { + s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) + } + return err + } + return nil +} + +// handlePing is invokde for a typePing frame +func (s *Session) handlePing(hdr header) error { + flags := hdr.Flags() + pingID := hdr.Length() + + // Check if this is a query, respond back in a separate context so we + // don't interfere with the receiving thread blocking for the write. + if flags&flagSYN == flagSYN { + go func() { + hdr := header(make([]byte, headerSize)) + hdr.encode(typePing, flagACK, 0, pingID) + if err := s.sendNoWait(hdr); err != nil { + s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) + } + }() + return nil + } + + // Handle a response + s.pingLock.Lock() + ch := s.pings[pingID] + if ch != nil { + delete(s.pings, pingID) + close(ch) + } + s.pingLock.Unlock() + return nil +} + +// handleGoAway is invokde for a typeGoAway frame +func (s *Session) handleGoAway(hdr header) error { + code := hdr.Length() + switch code { + case goAwayNormal: + atomic.SwapInt32(&s.remoteGoAway, 1) + case goAwayProtoErr: + s.logger.Printf("[ERR] yamux: received protocol error go away") + return fmt.Errorf("yamux protocol error") + case goAwayInternalErr: + s.logger.Printf("[ERR] yamux: received internal error go away") + return fmt.Errorf("remote yamux internal error") + default: + s.logger.Printf("[ERR] yamux: received unexpected go away") + return fmt.Errorf("unexpected go away received") + } + return nil +} + +// incomingStream is used to create a new incoming stream +func (s *Session) incomingStream(id uint32) error { + // Reject immediately if we are doing a go away + if atomic.LoadInt32(&s.localGoAway) == 1 { + hdr := header(make([]byte, headerSize)) + hdr.encode(typeWindowUpdate, flagRST, id, 0) + return s.sendNoWait(hdr) + } + + // Allocate a new stream + stream := newStream(s, id, streamSYNReceived) + + s.streamLock.Lock() + defer s.streamLock.Unlock() + + // Check if stream already exists + if _, ok := s.streams[id]; ok { + s.logger.Printf("[ERR] yamux: duplicate stream declared") + if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { + s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) + } + return ErrDuplicateStream + } + + // Register the stream + s.streams[id] = stream + + // Check if we've exceeded the backlog + select { + case s.acceptCh <- stream: + return nil + default: + // Backlog exceeded! RST the stream + s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") + delete(s.streams, id) + stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) + return s.sendNoWait(stream.sendHdr) + } +} + +// closeStream is used to close a stream once both sides have +// issued a close. If there was an in-flight SYN and the stream +// was not yet established, then this will give the credit back. +func (s *Session) closeStream(id uint32) { + s.streamLock.Lock() + if _, ok := s.inflight[id]; ok { + select { + case <-s.synCh: + default: + s.logger.Printf("[ERR] yamux: SYN tracking out of sync") + } + } + delete(s.streams, id) + s.streamLock.Unlock() +} + +// establishStream is used to mark a stream that was in the +// SYN Sent state as established. +func (s *Session) establishStream(id uint32) { + s.streamLock.Lock() + if _, ok := s.inflight[id]; ok { + delete(s.inflight, id) + } else { + s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)") + } + select { + case <-s.synCh: + default: + s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") + } + s.streamLock.Unlock() +} diff --git a/vendor/github.com/hashicorp/yamux/session_test.go b/vendor/github.com/hashicorp/yamux/session_test.go new file mode 100644 index 0000000..1645e2b --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/session_test.go @@ -0,0 +1,1256 @@ +package yamux + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "reflect" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +type logCapture struct{ bytes.Buffer } + +func (l *logCapture) logs() []string { + return strings.Split(strings.TrimSpace(l.String()), "\n") +} + +func (l *logCapture) match(expect []string) bool { + return reflect.DeepEqual(l.logs(), expect) +} + +func captureLogs(s *Session) *logCapture { + buf := new(logCapture) + s.logger = log.New(buf, "", 0) + return buf +} + +type pipeConn struct { + reader *io.PipeReader + writer *io.PipeWriter + writeBlocker sync.Mutex +} + +func (p *pipeConn) Read(b []byte) (int, error) { + return p.reader.Read(b) +} + +func (p *pipeConn) Write(b []byte) (int, error) { + p.writeBlocker.Lock() + defer p.writeBlocker.Unlock() + return p.writer.Write(b) +} + +func (p *pipeConn) Close() error { + p.reader.Close() + return p.writer.Close() +} + +func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) { + read1, write1 := io.Pipe() + read2, write2 := io.Pipe() + conn1 := &pipeConn{reader: read1, writer: write2} + conn2 := &pipeConn{reader: read2, writer: write1} + return conn1, conn2 +} + +func testConf() *Config { + conf := DefaultConfig() + conf.AcceptBacklog = 64 + conf.KeepAliveInterval = 100 * time.Millisecond + conf.ConnectionWriteTimeout = 250 * time.Millisecond + return conf +} + +func testConfNoKeepAlive() *Config { + conf := testConf() + conf.EnableKeepAlive = false + return conf +} + +func testClientServer() (*Session, *Session) { + return testClientServerConfig(testConf()) +} + +func testClientServerConfig(conf *Config) (*Session, *Session) { + conn1, conn2 := testConn() + client, _ := Client(conn1, conf) + server, _ := Server(conn2, conf) + return client, server +} + +func TestPing(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + rtt, err := client.Ping() + if err != nil { + t.Fatalf("err: %v", err) + } + if rtt == 0 { + t.Fatalf("bad: %v", rtt) + } + + rtt, err = server.Ping() + if err != nil { + t.Fatalf("err: %v", err) + } + if rtt == 0 { + t.Fatalf("bad: %v", rtt) + } +} + +func TestPing_Timeout(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + // Prevent the client from responding + clientConn := client.conn.(*pipeConn) + clientConn.writeBlocker.Lock() + + errCh := make(chan error, 1) + go func() { + _, err := server.Ping() // Ping via the server session + errCh <- err + }() + + select { + case err := <-errCh: + if err != ErrTimeout { + t.Fatalf("err: %v", err) + } + case <-time.After(client.config.ConnectionWriteTimeout * 2): + t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout) + } + + // Verify that we recover, even if we gave up + clientConn.writeBlocker.Unlock() + + go func() { + _, err := server.Ping() // Ping via the server session + errCh <- err + }() + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("err: %v", err) + } + case <-time.After(client.config.ConnectionWriteTimeout): + t.Fatalf("timeout") + } +} + +func TestCloseBeforeAck(t *testing.T) { + cfg := testConf() + cfg.AcceptBacklog = 8 + client, server := testClientServerConfig(cfg) + + defer client.Close() + defer server.Close() + + for i := 0; i < 8; i++ { + s, err := client.OpenStream() + if err != nil { + t.Fatal(err) + } + s.Close() + } + + for i := 0; i < 8; i++ { + s, err := server.AcceptStream() + if err != nil { + t.Fatal(err) + } + s.Close() + } + + done := make(chan struct{}) + go func() { + defer close(done) + s, err := client.OpenStream() + if err != nil { + t.Fatal(err) + } + s.Close() + }() + + select { + case <-done: + case <-time.After(time.Second * 5): + t.Fatal("timed out trying to open stream") + } +} + +func TestAccept(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + if client.NumStreams() != 0 { + t.Fatalf("bad") + } + if server.NumStreams() != 0 { + t.Fatalf("bad") + } + + wg := &sync.WaitGroup{} + wg.Add(4) + + go func() { + defer wg.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + if id := stream.StreamID(); id != 1 { + t.Fatalf("bad: %v", id) + } + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + go func() { + defer wg.Done() + stream, err := client.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + if id := stream.StreamID(); id != 2 { + t.Fatalf("bad: %v", id) + } + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + go func() { + defer wg.Done() + stream, err := server.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + if id := stream.StreamID(); id != 2 { + t.Fatalf("bad: %v", id) + } + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + go func() { + defer wg.Done() + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + if id := stream.StreamID(); id != 1 { + t.Fatalf("bad: %v", id) + } + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + doneCh := make(chan struct{}) + go func() { + wg.Wait() + close(doneCh) + }() + + select { + case <-doneCh: + case <-time.After(time.Second): + panic("timeout") + } +} + +func TestNonNilInterface(t *testing.T) { + _, server := testClientServer() + server.Close() + + conn, err := server.Accept() + if err != nil && conn != nil { + t.Error("bad: accept should return a connection of nil value") + } + + conn, err = server.Open() + if err != nil && conn != nil { + t.Error("bad: open should return a connection of nil value") + } +} + +func TestSendData_Small(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + + if server.NumStreams() != 1 { + t.Fatalf("bad") + } + + buf := make([]byte, 4) + for i := 0; i < 1000; i++ { + n, err := stream.Read(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 4 { + t.Fatalf("short read: %d", n) + } + if string(buf) != "test" { + t.Fatalf("bad: %s", buf) + } + } + + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + go func() { + defer wg.Done() + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + + if client.NumStreams() != 1 { + t.Fatalf("bad") + } + + for i := 0; i < 1000; i++ { + n, err := stream.Write([]byte("test")) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 4 { + t.Fatalf("short write %d", n) + } + } + + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + doneCh := make(chan struct{}) + go func() { + wg.Wait() + close(doneCh) + }() + select { + case <-doneCh: + case <-time.After(time.Second): + panic("timeout") + } + + if client.NumStreams() != 0 { + t.Fatalf("bad") + } + if server.NumStreams() != 0 { + t.Fatalf("bad") + } +} + +func TestSendData_Large(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + const ( + sendSize = 250 * 1024 * 1024 + recvSize = 4 * 1024 + ) + + data := make([]byte, sendSize) + for idx := range data { + data[idx] = byte(idx % 256) + } + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + var sz int + buf := make([]byte, recvSize) + for i := 0; i < sendSize/recvSize; i++ { + n, err := stream.Read(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != recvSize { + t.Fatalf("short read: %d", n) + } + sz += n + for idx := range buf { + if buf[idx] != byte(idx%256) { + t.Fatalf("bad: %v %v %v", i, idx, buf[idx]) + } + } + } + + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + + t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz) + }() + + go func() { + defer wg.Done() + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + + n, err := stream.Write(data) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != len(data) { + t.Fatalf("short write %d", n) + } + + if err := stream.Close(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + doneCh := make(chan struct{}) + go func() { + wg.Wait() + close(doneCh) + }() + select { + case <-doneCh: + case <-time.After(5 * time.Second): + panic("timeout") + } +} + +func TestGoAway(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + if err := server.GoAway(); err != nil { + t.Fatalf("err: %v", err) + } + + _, err := client.Open() + if err != ErrRemoteGoAway { + t.Fatalf("err: %v", err) + } +} + +func TestManyStreams(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + wg := &sync.WaitGroup{} + + acceptor := func(i int) { + defer wg.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + buf := make([]byte, 512) + for { + n, err := stream.Read(buf) + if err == io.EOF { + return + } + if err != nil { + t.Fatalf("err: %v", err) + } + if n == 0 { + t.Fatalf("err: %v", err) + } + } + } + sender := func(i int) { + defer wg.Done() + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + msg := fmt.Sprintf("%08d", i) + for i := 0; i < 1000; i++ { + n, err := stream.Write([]byte(msg)) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != len(msg) { + t.Fatalf("short write %d", n) + } + } + } + + for i := 0; i < 50; i++ { + wg.Add(2) + go acceptor(i) + go sender(i) + } + + wg.Wait() +} + +func TestManyStreams_PingPong(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + wg := &sync.WaitGroup{} + + ping := []byte("ping") + pong := []byte("pong") + + acceptor := func(i int) { + defer wg.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + buf := make([]byte, 4) + for { + // Read the 'ping' + n, err := stream.Read(buf) + if err == io.EOF { + return + } + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 4 { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(buf, ping) { + t.Fatalf("bad: %s", buf) + } + + // Shrink the internal buffer! + stream.Shrink() + + // Write out the 'pong' + n, err = stream.Write(pong) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 4 { + t.Fatalf("err: %v", err) + } + } + } + sender := func(i int) { + defer wg.Done() + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + buf := make([]byte, 4) + for i := 0; i < 1000; i++ { + // Send the 'ping' + n, err := stream.Write(ping) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 4 { + t.Fatalf("short write %d", n) + } + + // Read the 'pong' + n, err = stream.Read(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 4 { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(buf, pong) { + t.Fatalf("bad: %s", buf) + } + + // Shrink the buffer + stream.Shrink() + } + } + + for i := 0; i < 50; i++ { + wg.Add(2) + go acceptor(i) + go sender(i) + } + + wg.Wait() +} + +func TestHalfClose(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + if _, err = stream.Write([]byte("a")); err != nil { + t.Fatalf("err: %v", err) + } + + stream2, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + stream2.Close() // Half close + + buf := make([]byte, 4) + n, err := stream2.Read(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 1 { + t.Fatalf("bad: %v", n) + } + + // Send more + if _, err = stream.Write([]byte("bcd")); err != nil { + t.Fatalf("err: %v", err) + } + stream.Close() + + // Read after close + n, err = stream2.Read(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 3 { + t.Fatalf("bad: %v", n) + } + + // EOF after close + n, err = stream2.Read(buf) + if err != io.EOF { + t.Fatalf("err: %v", err) + } + if n != 0 { + t.Fatalf("bad: %v", n) + } +} + +func TestReadDeadline(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + stream2, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream2.Close() + + if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil { + t.Fatalf("err: %v", err) + } + + buf := make([]byte, 4) + if _, err := stream.Read(buf); err != ErrTimeout { + t.Fatalf("err: %v", err) + } +} + +func TestWriteDeadline(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + stream2, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream2.Close() + + if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil { + t.Fatalf("err: %v", err) + } + + buf := make([]byte, 512) + for i := 0; i < int(initialStreamWindow); i++ { + _, err := stream.Write(buf) + if err != nil && err == ErrTimeout { + return + } else if err != nil { + t.Fatalf("err: %v", err) + } + } + t.Fatalf("Expected timeout") +} + +func TestBacklogExceeded(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + // Fill the backlog + max := client.config.AcceptBacklog + for i := 0; i < max; i++ { + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + if _, err := stream.Write([]byte("foo")); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Attempt to open a new stream + errCh := make(chan error, 1) + go func() { + _, err := client.Open() + errCh <- err + }() + + // Shutdown the server + go func() { + time.Sleep(10 * time.Millisecond) + server.Close() + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatalf("open should fail") + } + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +func TestKeepAlive(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + time.Sleep(200 * time.Millisecond) + + // Ping value should increase + client.pingLock.Lock() + defer client.pingLock.Unlock() + if client.pingID == 0 { + t.Fatalf("should ping") + } + + server.pingLock.Lock() + defer server.pingLock.Unlock() + if server.pingID == 0 { + t.Fatalf("should ping") + } +} + +func TestKeepAlive_Timeout(t *testing.T) { + conn1, conn2 := testConn() + + clientConf := testConf() + clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes + clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom + client, _ := Client(conn1, clientConf) + defer client.Close() + + server, _ := Server(conn2, testConf()) + defer server.Close() + + _ = captureLogs(client) // Client logs aren't part of the test + serverLogs := captureLogs(server) + + errCh := make(chan error, 1) + go func() { + _, err := server.Accept() // Wait until server closes + errCh <- err + }() + + // Prevent the client from responding + clientConn := client.conn.(*pipeConn) + clientConn.writeBlocker.Lock() + + select { + case err := <-errCh: + if err != ErrKeepAliveTimeout { + t.Fatalf("unexpected error: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for timeout") + } + + if !server.IsClosed() { + t.Fatalf("server should have closed") + } + + if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) { + t.Fatalf("server log incorect: %v", serverLogs.logs()) + } +} + +func TestLargeWindow(t *testing.T) { + conf := DefaultConfig() + conf.MaxStreamWindowSize *= 2 + + client, server := testClientServerConfig(conf) + defer client.Close() + defer server.Close() + + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + stream2, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream2.Close() + + stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + buf := make([]byte, conf.MaxStreamWindowSize) + n, err := stream.Write(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != len(buf) { + t.Fatalf("short write: %d", n) + } +} + +type UnlimitedReader struct{} + +func (u *UnlimitedReader) Read(p []byte) (int, error) { + runtime.Gosched() + return len(p), nil +} + +func TestSendData_VeryLarge(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + var n int64 = 1 * 1024 * 1024 * 1024 + var workers int = 16 + + wg := &sync.WaitGroup{} + wg.Add(workers * 2) + + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + buf := make([]byte, 4) + _, err = stream.Read(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(buf, []byte{0, 1, 2, 3}) { + t.Fatalf("bad header") + } + + recv, err := io.Copy(ioutil.Discard, stream) + if err != nil { + t.Fatalf("err: %v", err) + } + if recv != n { + t.Fatalf("bad: %v", recv) + } + }() + } + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + _, err = stream.Write([]byte{0, 1, 2, 3}) + if err != nil { + t.Fatalf("err: %v", err) + } + + unlimited := &UnlimitedReader{} + sent, err := io.Copy(stream, io.LimitReader(unlimited, n)) + if err != nil { + t.Fatalf("err: %v", err) + } + if sent != n { + t.Fatalf("bad: %v", sent) + } + }() + } + + doneCh := make(chan struct{}) + go func() { + wg.Wait() + close(doneCh) + }() + select { + case <-doneCh: + case <-time.After(20 * time.Second): + panic("timeout") + } +} + +func TestBacklogExceeded_Accept(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + max := 5 * client.config.AcceptBacklog + go func() { + for i := 0; i < max; i++ { + stream, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + } + }() + + // Fill the backlog + for i := 0; i < max; i++ { + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + if _, err := stream.Write([]byte("foo")); err != nil { + t.Fatalf("err: %v", err) + } + } +} + +func TestSession_WindowUpdateWriteDuringRead(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + var wg sync.WaitGroup + wg.Add(2) + + // Choose a huge flood size that we know will result in a window update. + flood := int64(client.config.MaxStreamWindowSize) - 1 + + // The server will accept a new stream and then flood data to it. + go func() { + defer wg.Done() + + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + n, err := stream.Write(make([]byte, flood)) + if err != nil { + t.Fatalf("err: %v", err) + } + if int64(n) != flood { + t.Fatalf("short write: %d", n) + } + }() + + // The client will open a stream, block outbound writes, and then + // listen to the flood from the server, which should time out since + // it won't be able to send the window update. + go func() { + defer wg.Done() + + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + conn := client.conn.(*pipeConn) + conn.writeBlocker.Lock() + + _, err = stream.Read(make([]byte, flood)) + if err != ErrConnectionWriteTimeout { + t.Fatalf("err: %v", err) + } + }() + + wg.Wait() +} + +func TestSession_PartialReadWindowUpdate(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + var wg sync.WaitGroup + wg.Add(1) + + // Choose a huge flood size that we know will result in a window update. + flood := int64(client.config.MaxStreamWindowSize) + var wr *Stream + + // The server will accept a new stream and then flood data to it. + go func() { + defer wg.Done() + + var err error + wr, err = server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer wr.Close() + + if wr.sendWindow != client.config.MaxStreamWindowSize { + t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow) + } + + n, err := wr.Write(make([]byte, flood)) + if err != nil { + t.Fatalf("err: %v", err) + } + if int64(n) != flood { + t.Fatalf("short write: %d", n) + } + if wr.sendWindow != 0 { + t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow) + } + }() + + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + wg.Wait() + + _, err = stream.Read(make([]byte, flood/2+1)) + + if exp := uint32(flood/2 + 1); wr.sendWindow != exp { + t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow) + } +} + +func TestSession_sendNoWait_Timeout(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + }() + + // The client will open the stream and then block outbound writes, we'll + // probe sendNoWait once it gets into that state. + go func() { + defer wg.Done() + + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + conn := client.conn.(*pipeConn) + conn.writeBlocker.Lock() + + hdr := header(make([]byte, headerSize)) + hdr.encode(typePing, flagACK, 0, 0) + for { + err = client.sendNoWait(hdr) + if err == nil { + continue + } else if err == ErrConnectionWriteTimeout { + break + } else { + t.Fatalf("err: %v", err) + } + } + }() + + wg.Wait() +} + +func TestSession_PingOfDeath(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + var wg sync.WaitGroup + wg.Add(2) + + var doPingOfDeath sync.Mutex + doPingOfDeath.Lock() + + // This is used later to block outbound writes. + conn := server.conn.(*pipeConn) + + // The server will accept a stream, block outbound writes, and then + // flood its send channel so that no more headers can be queued. + go func() { + defer wg.Done() + + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + conn.writeBlocker.Lock() + for { + hdr := header(make([]byte, headerSize)) + hdr.encode(typePing, 0, 0, 0) + err = server.sendNoWait(hdr) + if err == nil { + continue + } else if err == ErrConnectionWriteTimeout { + break + } else { + t.Fatalf("err: %v", err) + } + } + + doPingOfDeath.Unlock() + }() + + // The client will open a stream and then send the server a ping once it + // can no longer write. This makes sure the server doesn't deadlock reads + // while trying to reply to the ping with no ability to write. + go func() { + defer wg.Done() + + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + // This ping will never unblock because the ping id will never + // show up in a response. + doPingOfDeath.Lock() + go func() { client.Ping() }() + + // Wait for a while to make sure the previous ping times out, + // then turn writes back on and make sure a ping works again. + time.Sleep(2 * server.config.ConnectionWriteTimeout) + conn.writeBlocker.Unlock() + if _, err = client.Ping(); err != nil { + t.Fatalf("err: %v", err) + } + }() + + wg.Wait() +} + +func TestSession_ConnectionWriteTimeout(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + + stream, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + }() + + // The client will open the stream and then block outbound writes, we'll + // tee up a write and make sure it eventually times out. + go func() { + defer wg.Done() + + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + conn := client.conn.(*pipeConn) + conn.writeBlocker.Lock() + + // Since the write goroutine is blocked then this will return a + // timeout since it can't get feedback about whether the write + // worked. + n, err := stream.Write([]byte("hello")) + if err != ErrConnectionWriteTimeout { + t.Fatalf("err: %v", err) + } + if n != 0 { + t.Fatalf("lied about writes: %d", n) + } + }() + + wg.Wait() +} diff --git a/vendor/github.com/hashicorp/yamux/spec.md b/vendor/github.com/hashicorp/yamux/spec.md new file mode 100644 index 0000000..183d797 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/spec.md @@ -0,0 +1,140 @@ +# Specification + +We use this document to detail the internal specification of Yamux. +This is used both as a guide for implementing Yamux, but also for +alternative interoperable libraries to be built. + +# Framing + +Yamux uses a streaming connection underneath, but imposes a message +framing so that it can be shared between many logical streams. Each +frame contains a header like: + +* Version (8 bits) +* Type (8 bits) +* Flags (16 bits) +* StreamID (32 bits) +* Length (32 bits) + +This means that each header has a 12 byte overhead. +All fields are encoded in network order (big endian). +Each field is described below: + +## Version Field + +The version field is used for future backward compatibility. At the +current time, the field is always set to 0, to indicate the initial +version. + +## Type Field + +The type field is used to switch the frame message type. The following +message types are supported: + +* 0x0 Data - Used to transmit data. May transmit zero length payloads + depending on the flags. + +* 0x1 Window Update - Used to updated the senders receive window size. + This is used to implement per-session flow control. + +* 0x2 Ping - Used to measure RTT. It can also be used to heart-beat + and do keep-alives over TCP. + +* 0x3 Go Away - Used to close a session. + +## Flag Field + +The flags field is used to provide additional information related +to the message type. The following flags are supported: + +* 0x1 SYN - Signals the start of a new stream. May be sent with a data or + window update message. Also sent with a ping to indicate outbound. + +* 0x2 ACK - Acknowledges the start of a new stream. May be sent with a data + or window update message. Also sent with a ping to indicate response. + +* 0x4 FIN - Performs a half-close of a stream. May be sent with a data + message or window update. + +* 0x8 RST - Reset a stream immediately. May be sent with a data or + window update message. + +## StreamID Field + +The StreamID field is used to identify the logical stream the frame +is addressing. The client side should use odd ID's, and the server even. +This prevents any collisions. Additionally, the 0 ID is reserved to represent +the session. + +Both Ping and Go Away messages should always use the 0 StreamID. + +## Length Field + +The meaning of the length field depends on the message type: + +* Data - provides the length of bytes following the header +* Window update - provides a delta update to the window size +* Ping - Contains an opaque value, echoed back +* Go Away - Contains an error code + +# Message Flow + +There is no explicit connection setup, as Yamux relies on an underlying +transport to be provided. However, there is a distinction between client +and server side of the connection. + +## Opening a stream + +To open a stream, an initial data or window update frame is sent +with a new StreamID. The SYN flag should be set to signal a new stream. + +The receiver must then reply with either a data or window update frame +with the StreamID along with the ACK flag to accept the stream or with +the RST flag to reject the stream. + +Because we are relying on the reliable stream underneath, a connection +can begin sending data once the SYN flag is sent. The corresponding +ACK does not need to be received. This is particularly well suited +for an RPC system where a client wants to open a stream and immediately +fire a request without waiting for the RTT of the ACK. + +This does introduce the possibility of a connection being rejected +after data has been sent already. This is a slight semantic difference +from TCP, where the conection cannot be refused after it is opened. +Clients should be prepared to handle this by checking for an error +that indicates a RST was received. + +## Closing a stream + +To close a stream, either side sends a data or window update frame +along with the FIN flag. This does a half-close indicating the sender +will send no further data. + +Once both sides have closed the connection, the stream is closed. + +Alternatively, if an error occurs, the RST flag can be used to +hard close a stream immediately. + +## Flow Control + +When Yamux is initially starts each stream with a 256KB window size. +There is no window size for the session. + +To prevent the streams from stalling, window update frames should be +sent regularly. Yamux can be configured to provide a larger limit for +windows sizes. Both sides assume the initial 256KB window, but can +immediately send a window update as part of the SYN/ACK indicating a +larger window. + +Both sides should track the number of bytes sent in Data frames +only, as only they are tracked as part of the window size. + +## Session termination + +When a session is being terminated, the Go Away message should +be sent. The Length should be set to one of the following to +provide an error code: + +* 0x0 Normal termination +* 0x1 Protocol error +* 0x2 Internal error diff --git a/vendor/github.com/hashicorp/yamux/stream.go b/vendor/github.com/hashicorp/yamux/stream.go new file mode 100644 index 0000000..aa23919 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/stream.go @@ -0,0 +1,470 @@ +package yamux + +import ( + "bytes" + "io" + "sync" + "sync/atomic" + "time" +) + +type streamState int + +const ( + streamInit streamState = iota + streamSYNSent + streamSYNReceived + streamEstablished + streamLocalClose + streamRemoteClose + streamClosed + streamReset +) + +// Stream is used to represent a logical stream +// within a session. +type Stream struct { + recvWindow uint32 + sendWindow uint32 + + id uint32 + session *Session + + state streamState + stateLock sync.Mutex + + recvBuf *bytes.Buffer + recvLock sync.Mutex + + controlHdr header + controlErr chan error + controlHdrLock sync.Mutex + + sendHdr header + sendErr chan error + sendLock sync.Mutex + + recvNotifyCh chan struct{} + sendNotifyCh chan struct{} + + readDeadline atomic.Value // time.Time + writeDeadline atomic.Value // time.Time +} + +// newStream is used to construct a new stream within +// a given session for an ID +func newStream(session *Session, id uint32, state streamState) *Stream { + s := &Stream{ + id: id, + session: session, + state: state, + controlHdr: header(make([]byte, headerSize)), + controlErr: make(chan error, 1), + sendHdr: header(make([]byte, headerSize)), + sendErr: make(chan error, 1), + recvWindow: initialStreamWindow, + sendWindow: initialStreamWindow, + recvNotifyCh: make(chan struct{}, 1), + sendNotifyCh: make(chan struct{}, 1), + } + s.readDeadline.Store(time.Time{}) + s.writeDeadline.Store(time.Time{}) + return s +} + +// Session returns the associated stream session +func (s *Stream) Session() *Session { + return s.session +} + +// StreamID returns the ID of this stream +func (s *Stream) StreamID() uint32 { + return s.id +} + +// Read is used to read from the stream +func (s *Stream) Read(b []byte) (n int, err error) { + defer asyncNotify(s.recvNotifyCh) +START: + s.stateLock.Lock() + switch s.state { + case streamLocalClose: + fallthrough + case streamRemoteClose: + fallthrough + case streamClosed: + s.recvLock.Lock() + if s.recvBuf == nil || s.recvBuf.Len() == 0 { + s.recvLock.Unlock() + s.stateLock.Unlock() + return 0, io.EOF + } + s.recvLock.Unlock() + case streamReset: + s.stateLock.Unlock() + return 0, ErrConnectionReset + } + s.stateLock.Unlock() + + // If there is no data available, block + s.recvLock.Lock() + if s.recvBuf == nil || s.recvBuf.Len() == 0 { + s.recvLock.Unlock() + goto WAIT + } + + // Read any bytes + n, _ = s.recvBuf.Read(b) + s.recvLock.Unlock() + + // Send a window update potentially + err = s.sendWindowUpdate() + return n, err + +WAIT: + var timeout <-chan time.Time + var timer *time.Timer + readDeadline := s.readDeadline.Load().(time.Time) + if !readDeadline.IsZero() { + delay := readDeadline.Sub(time.Now()) + timer = time.NewTimer(delay) + timeout = timer.C + } + select { + case <-s.recvNotifyCh: + if timer != nil { + timer.Stop() + } + goto START + case <-timeout: + return 0, ErrTimeout + } +} + +// Write is used to write to the stream +func (s *Stream) Write(b []byte) (n int, err error) { + s.sendLock.Lock() + defer s.sendLock.Unlock() + total := 0 + for total < len(b) { + n, err := s.write(b[total:]) + total += n + if err != nil { + return total, err + } + } + return total, nil +} + +// write is used to write to the stream, may return on +// a short write. +func (s *Stream) write(b []byte) (n int, err error) { + var flags uint16 + var max uint32 + var body io.Reader +START: + s.stateLock.Lock() + switch s.state { + case streamLocalClose: + fallthrough + case streamClosed: + s.stateLock.Unlock() + return 0, ErrStreamClosed + case streamReset: + s.stateLock.Unlock() + return 0, ErrConnectionReset + } + s.stateLock.Unlock() + + // If there is no data available, block + window := atomic.LoadUint32(&s.sendWindow) + if window == 0 { + goto WAIT + } + + // Determine the flags if any + flags = s.sendFlags() + + // Send up to our send window + max = min(window, uint32(len(b))) + body = bytes.NewReader(b[:max]) + + // Send the header + s.sendHdr.encode(typeData, flags, s.id, max) + if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { + return 0, err + } + + // Reduce our send window + atomic.AddUint32(&s.sendWindow, ^uint32(max-1)) + + // Unlock + return int(max), err + +WAIT: + var timeout <-chan time.Time + writeDeadline := s.writeDeadline.Load().(time.Time) + if !writeDeadline.IsZero() { + delay := writeDeadline.Sub(time.Now()) + timeout = time.After(delay) + } + select { + case <-s.sendNotifyCh: + goto START + case <-timeout: + return 0, ErrTimeout + } + return 0, nil +} + +// sendFlags determines any flags that are appropriate +// based on the current stream state +func (s *Stream) sendFlags() uint16 { + s.stateLock.Lock() + defer s.stateLock.Unlock() + var flags uint16 + switch s.state { + case streamInit: + flags |= flagSYN + s.state = streamSYNSent + case streamSYNReceived: + flags |= flagACK + s.state = streamEstablished + } + return flags +} + +// sendWindowUpdate potentially sends a window update enabling +// further writes to take place. Must be invoked with the lock. +func (s *Stream) sendWindowUpdate() error { + s.controlHdrLock.Lock() + defer s.controlHdrLock.Unlock() + + // Determine the delta update + max := s.session.config.MaxStreamWindowSize + var bufLen uint32 + s.recvLock.Lock() + if s.recvBuf != nil { + bufLen = uint32(s.recvBuf.Len()) + } + delta := (max - bufLen) - s.recvWindow + + // Determine the flags if any + flags := s.sendFlags() + + // Check if we can omit the update + if delta < (max/2) && flags == 0 { + s.recvLock.Unlock() + return nil + } + + // Update our window + s.recvWindow += delta + s.recvLock.Unlock() + + // Send the header + s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) + if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + return err + } + return nil +} + +// sendClose is used to send a FIN +func (s *Stream) sendClose() error { + s.controlHdrLock.Lock() + defer s.controlHdrLock.Unlock() + + flags := s.sendFlags() + flags |= flagFIN + s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) + if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + return err + } + return nil +} + +// Close is used to close the stream +func (s *Stream) Close() error { + closeStream := false + s.stateLock.Lock() + switch s.state { + // Opened means we need to signal a close + case streamSYNSent: + fallthrough + case streamSYNReceived: + fallthrough + case streamEstablished: + s.state = streamLocalClose + goto SEND_CLOSE + + case streamLocalClose: + case streamRemoteClose: + s.state = streamClosed + closeStream = true + goto SEND_CLOSE + + case streamClosed: + case streamReset: + default: + panic("unhandled state") + } + s.stateLock.Unlock() + return nil +SEND_CLOSE: + s.stateLock.Unlock() + s.sendClose() + s.notifyWaiting() + if closeStream { + s.session.closeStream(s.id) + } + return nil +} + +// forceClose is used for when the session is exiting +func (s *Stream) forceClose() { + s.stateLock.Lock() + s.state = streamClosed + s.stateLock.Unlock() + s.notifyWaiting() +} + +// processFlags is used to update the state of the stream +// based on set flags, if any. Lock must be held +func (s *Stream) processFlags(flags uint16) error { + // Close the stream without holding the state lock + closeStream := false + defer func() { + if closeStream { + s.session.closeStream(s.id) + } + }() + + s.stateLock.Lock() + defer s.stateLock.Unlock() + if flags&flagACK == flagACK { + if s.state == streamSYNSent { + s.state = streamEstablished + } + s.session.establishStream(s.id) + } + if flags&flagFIN == flagFIN { + switch s.state { + case streamSYNSent: + fallthrough + case streamSYNReceived: + fallthrough + case streamEstablished: + s.state = streamRemoteClose + s.notifyWaiting() + case streamLocalClose: + s.state = streamClosed + closeStream = true + s.notifyWaiting() + default: + s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state) + return ErrUnexpectedFlag + } + } + if flags&flagRST == flagRST { + s.state = streamReset + closeStream = true + s.notifyWaiting() + } + return nil +} + +// notifyWaiting notifies all the waiting channels +func (s *Stream) notifyWaiting() { + asyncNotify(s.recvNotifyCh) + asyncNotify(s.sendNotifyCh) +} + +// incrSendWindow updates the size of our send window +func (s *Stream) incrSendWindow(hdr header, flags uint16) error { + if err := s.processFlags(flags); err != nil { + return err + } + + // Increase window, unblock a sender + atomic.AddUint32(&s.sendWindow, hdr.Length()) + asyncNotify(s.sendNotifyCh) + return nil +} + +// readData is used to handle a data frame +func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { + if err := s.processFlags(flags); err != nil { + return err + } + + // Check that our recv window is not exceeded + length := hdr.Length() + if length == 0 { + return nil + } + + // Wrap in a limited reader + conn = &io.LimitedReader{R: conn, N: int64(length)} + + // Copy into buffer + s.recvLock.Lock() + + if length > s.recvWindow { + s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) + return ErrRecvWindowExceeded + } + + if s.recvBuf == nil { + // Allocate the receive buffer just-in-time to fit the full data frame. + // This way we can read in the whole packet without further allocations. + s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) + } + if _, err := io.Copy(s.recvBuf, conn); err != nil { + s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) + s.recvLock.Unlock() + return err + } + + // Decrement the receive window + s.recvWindow -= length + s.recvLock.Unlock() + + // Unblock any readers + asyncNotify(s.recvNotifyCh) + return nil +} + +// SetDeadline sets the read and write deadlines +func (s *Stream) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + if err := s.SetWriteDeadline(t); err != nil { + return err + } + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (s *Stream) SetReadDeadline(t time.Time) error { + s.readDeadline.Store(t) + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls +func (s *Stream) SetWriteDeadline(t time.Time) error { + s.writeDeadline.Store(t) + return nil +} + +// Shrink is used to compact the amount of buffers utilized +// This is useful when using Yamux in a connection pool to reduce +// the idle memory utilization. +func (s *Stream) Shrink() { + s.recvLock.Lock() + if s.recvBuf != nil && s.recvBuf.Len() == 0 { + s.recvBuf = nil + } + s.recvLock.Unlock() +} diff --git a/vendor/github.com/hashicorp/yamux/util.go b/vendor/github.com/hashicorp/yamux/util.go new file mode 100644 index 0000000..8a73e92 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/util.go @@ -0,0 +1,43 @@ +package yamux + +import ( + "sync" + "time" +) + +var ( + timerPool = &sync.Pool{ + New: func() interface{} { + timer := time.NewTimer(time.Hour * 1e6) + timer.Stop() + return timer + }, + } +) + +// asyncSendErr is used to try an async send of an error +func asyncSendErr(ch chan error, err error) { + if ch == nil { + return + } + select { + case ch <- err: + default: + } +} + +// asyncNotify is used to signal a waiting goroutine +func asyncNotify(ch chan struct{}) { + select { + case ch <- struct{}{}: + default: + } +} + +// min computes the minimum of two values +func min(a, b uint32) uint32 { + if a < b { + return a + } + return b +} diff --git a/vendor/github.com/hashicorp/yamux/util_test.go b/vendor/github.com/hashicorp/yamux/util_test.go new file mode 100644 index 0000000..dd14623 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/util_test.go @@ -0,0 +1,50 @@ +package yamux + +import ( + "testing" +) + +func TestAsyncSendErr(t *testing.T) { + ch := make(chan error) + asyncSendErr(ch, ErrTimeout) + select { + case <-ch: + t.Fatalf("should not get") + default: + } + + ch = make(chan error, 1) + asyncSendErr(ch, ErrTimeout) + select { + case <-ch: + default: + t.Fatalf("should get") + } +} + +func TestAsyncNotify(t *testing.T) { + ch := make(chan struct{}) + asyncNotify(ch) + select { + case <-ch: + t.Fatalf("should not get") + default: + } + + ch = make(chan struct{}, 1) + asyncNotify(ch) + select { + case <-ch: + default: + t.Fatalf("should get") + } +} + +func TestMin(t *testing.T) { + if min(1, 2) != 1 { + t.Fatalf("bad") + } + if min(2, 1) != 1 { + t.Fatalf("bad") + } +} diff --git a/vendor/github.com/spf13/cobra/cobra/cmd/init.go b/vendor/github.com/spf13/cobra/cobra/cmd/init.go index d65e6c8..2441370 100644 --- a/vendor/github.com/spf13/cobra/cobra/cmd/init.go +++ b/vendor/github.com/spf13/cobra/cobra/cmd/init.go @@ -65,7 +65,7 @@ Init will not use an existing directory with contents.`, initializeProject(project) fmt.Fprintln(cmd.OutOrStdout(), `Your Cobra application is ready at -`+project.AbsPath()+` +`+project.AbsPath()+`. Give it a try by going there and running `+"`go run main.go`."+` Add commands to it by running `+"`cobra add [cmdname]`.") diff --git a/vendor/github.com/spf13/cobra/command.go b/vendor/github.com/spf13/cobra/command.go index 34d1bf3..15b8112 100644 --- a/vendor/github.com/spf13/cobra/command.go +++ b/vendor/github.com/spf13/cobra/command.go @@ -27,9 +27,6 @@ import ( flag "github.com/spf13/pflag" ) -// FParseErrWhitelist configures Flag parse errors to be ignored -type FParseErrWhitelist flag.ParseErrorsWhitelist - // Command is just that, a command for your application. // E.g. 'go run ...' - 'run' is the command. Cobra requires // you to define the usage and description as part of your command @@ -140,9 +137,6 @@ type Command struct { // TraverseChildren parses flags on all parents before executing child command. TraverseChildren bool - //FParseErrWhitelist flag parse errors to be ignored - FParseErrWhitelist FParseErrWhitelist - // commands is the list of commands supported by this program. commands []*Command // parent is a parent command for this command. @@ -1469,10 +1463,6 @@ func (c *Command) ParseFlags(args []string) error { } beforeErrorBufLen := c.flagErrorBuf.Len() c.mergePersistentFlags() - - //do it here after merging all flags and just before parse - c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) - err := c.Flags().Parse(args) // Print warnings if they occurred (e.g. deprecated flag messages). if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { diff --git a/vendor/github.com/spf13/cobra/command_test.go b/vendor/github.com/spf13/cobra/command_test.go index ccee031..d874a9a 100644 --- a/vendor/github.com/spf13/cobra/command_test.go +++ b/vendor/github.com/spf13/cobra/command_test.go @@ -1626,108 +1626,3 @@ func TestCalledAs(t *testing.T) { t.Run(name, tc.test) } } - -func TestFParseErrWhitelistBackwardCompatibility(t *testing.T) { - c := &Command{Use: "c", Run: emptyRun} - c.Flags().BoolP("boola", "a", false, "a boolean flag") - - output, err := executeCommand(c, "c", "-a", "--unknown", "flag") - if err == nil { - t.Error("expected unknown flag error") - } - checkStringContains(t, output, "unknown flag: --unknown") -} - -func TestFParseErrWhitelistSameCommand(t *testing.T) { - c := &Command{ - Use: "c", - Run: emptyRun, - FParseErrWhitelist: FParseErrWhitelist{ - UnknownFlags: true, - }, - } - c.Flags().BoolP("boola", "a", false, "a boolean flag") - - _, err := executeCommand(c, "c", "-a", "--unknown", "flag") - if err != nil { - t.Error("unexpected error: ", err) - } -} - -func TestFParseErrWhitelistParentCommand(t *testing.T) { - root := &Command{ - Use: "root", - Run: emptyRun, - FParseErrWhitelist: FParseErrWhitelist{ - UnknownFlags: true, - }, - } - - c := &Command{ - Use: "child", - Run: emptyRun, - } - c.Flags().BoolP("boola", "a", false, "a boolean flag") - - root.AddCommand(c) - - output, err := executeCommand(root, "child", "-a", "--unknown", "flag") - if err == nil { - t.Error("expected unknown flag error") - } - checkStringContains(t, output, "unknown flag: --unknown") -} - -func TestFParseErrWhitelistChildCommand(t *testing.T) { - root := &Command{ - Use: "root", - Run: emptyRun, - } - - c := &Command{ - Use: "child", - Run: emptyRun, - FParseErrWhitelist: FParseErrWhitelist{ - UnknownFlags: true, - }, - } - c.Flags().BoolP("boola", "a", false, "a boolean flag") - - root.AddCommand(c) - - _, err := executeCommand(root, "child", "-a", "--unknown", "flag") - if err != nil { - t.Error("unexpected error: ", err.Error()) - } -} - -func TestFParseErrWhitelistSiblingCommand(t *testing.T) { - root := &Command{ - Use: "root", - Run: emptyRun, - } - - c := &Command{ - Use: "child", - Run: emptyRun, - FParseErrWhitelist: FParseErrWhitelist{ - UnknownFlags: true, - }, - } - c.Flags().BoolP("boola", "a", false, "a boolean flag") - - s := &Command{ - Use: "sibling", - Run: emptyRun, - } - s.Flags().BoolP("boolb", "b", false, "a boolean flag") - - root.AddCommand(c) - root.AddCommand(s) - - output, err := executeCommand(root, "sibling", "-b", "--unknown", "flag") - if err == nil { - t.Error("expected unknown flag error") - } - checkStringContains(t, output, "unknown flag: --unknown") -} diff --git a/vendor/github.com/xtaci/smux/.travis.yml b/vendor/github.com/xtaci/smux/.travis.yml deleted file mode 100644 index 1ad083c..0000000 --- a/vendor/github.com/xtaci/smux/.travis.yml +++ /dev/null @@ -1,15 +0,0 @@ -language: go -go: - - tip - -before_install: - - go get -t -v ./... - -install: - - go get github.com/xtaci/smux - -script: - - go test -coverprofile=coverage.txt -covermode=atomic -bench . - -after_success: - - bash <(curl -s https://codecov.io/bash) diff --git a/vendor/github.com/xtaci/smux/LICENSE b/vendor/github.com/xtaci/smux/LICENSE deleted file mode 100644 index eed41ac..0000000 --- a/vendor/github.com/xtaci/smux/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2016-2017 Daniel Fu - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/xtaci/smux/README.md b/vendor/github.com/xtaci/smux/README.md deleted file mode 100644 index c01c346..0000000 --- a/vendor/github.com/xtaci/smux/README.md +++ /dev/null @@ -1,99 +0,0 @@ -smux - -[![GoDoc][1]][2] [![MIT licensed][3]][4] [![Build Status][5]][6] [![Go Report Card][7]][8] [![Coverage Statusd][9]][10] - -smux - -[1]: https://godoc.org/github.com/xtaci/smux?status.svg -[2]: https://godoc.org/github.com/xtaci/smux -[3]: https://img.shields.io/badge/license-MIT-blue.svg -[4]: LICENSE -[5]: https://travis-ci.org/xtaci/smux.svg?branch=master -[6]: https://travis-ci.org/xtaci/smux -[7]: https://goreportcard.com/badge/github.com/xtaci/smux -[8]: https://goreportcard.com/report/github.com/xtaci/smux -[9]: https://codecov.io/gh/xtaci/smux/branch/master/graph/badge.svg -[10]: https://codecov.io/gh/xtaci/smux - -## Introduction - -Smux ( **S**imple **MU**ltiple**X**ing) is a multiplexing library for Golang. It relies on an underlying connection to provide reliability and ordering, such as TCP or [KCP](https://github.com/xtaci/kcp-go), and provides stream-oriented multiplexing. The original intention of this library is to power the connection management for [kcp-go](https://github.com/xtaci/kcp-go). - -## Features - -1. Tiny, less than 600 LOC. -2. ***Token bucket*** controlled receiving, which provides smoother bandwidth graph(see picture below). -3. Session-wide receive buffer, shared among streams, tightly controlled overall memory usage. -4. Minimized header(8Bytes), maximized payload. -5. Well-tested on millions of devices in [kcptun](https://github.com/xtaci/kcptun). - -![smooth bandwidth curve](curve.jpg) - -## Documentation - -For complete documentation, see the associated [Godoc](https://godoc.org/github.com/xtaci/smux). - -## Specification - -``` -VERSION(1B) | CMD(1B) | LENGTH(2B) | STREAMID(4B) | DATA(LENGTH) -``` - -## Usage - -The API of smux are mostly taken from [yamux](https://github.com/hashicorp/yamux) - -```go - -func client() { - // Get a TCP connection - conn, err := net.Dial(...) - if err != nil { - panic(err) - } - - // Setup client side of smux - session, err := smux.Client(conn, nil) - if err != nil { - panic(err) - } - - // Open a new stream - stream, err := session.OpenStream() - if err != nil { - panic(err) - } - - // Stream implements io.ReadWriteCloser - stream.Write([]byte("ping")) -} - -func server() { - // Accept a TCP connection - conn, err := listener.Accept() - if err != nil { - panic(err) - } - - // Setup server side of smux - session, err := smux.Server(conn, nil) - if err != nil { - panic(err) - } - - // Accept a stream - stream, err := session.AcceptStream() - if err != nil { - panic(err) - } - - // Listen for a message - buf := make([]byte, 4) - stream.Read(buf) -} - -``` - -## Status - -Stable diff --git a/vendor/github.com/xtaci/smux/curve.jpg b/vendor/github.com/xtaci/smux/curve.jpg deleted file mode 100644 index 3fc4863..0000000 Binary files a/vendor/github.com/xtaci/smux/curve.jpg and /dev/null differ diff --git a/vendor/github.com/xtaci/smux/frame.go b/vendor/github.com/xtaci/smux/frame.go deleted file mode 100644 index 36062d7..0000000 --- a/vendor/github.com/xtaci/smux/frame.go +++ /dev/null @@ -1,60 +0,0 @@ -package smux - -import ( - "encoding/binary" - "fmt" -) - -const ( - version = 1 -) - -const ( // cmds - cmdSYN byte = iota // stream open - cmdFIN // stream close, a.k.a EOF mark - cmdPSH // data push - cmdNOP // no operation -) - -const ( - sizeOfVer = 1 - sizeOfCmd = 1 - sizeOfLength = 2 - sizeOfSid = 4 - headerSize = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength -) - -// Frame defines a packet from or to be multiplexed into a single connection -type Frame struct { - ver byte - cmd byte - sid uint32 - data []byte -} - -func newFrame(cmd byte, sid uint32) Frame { - return Frame{ver: version, cmd: cmd, sid: sid} -} - -type rawHeader []byte - -func (h rawHeader) Version() byte { - return h[0] -} - -func (h rawHeader) Cmd() byte { - return h[1] -} - -func (h rawHeader) Length() uint16 { - return binary.LittleEndian.Uint16(h[2:]) -} - -func (h rawHeader) StreamID() uint32 { - return binary.LittleEndian.Uint32(h[4:]) -} - -func (h rawHeader) String() string { - return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d", - h.Version(), h.Cmd(), h.StreamID(), h.Length()) -} diff --git a/vendor/github.com/xtaci/smux/mux.go b/vendor/github.com/xtaci/smux/mux.go deleted file mode 100644 index afcf58b..0000000 --- a/vendor/github.com/xtaci/smux/mux.go +++ /dev/null @@ -1,80 +0,0 @@ -package smux - -import ( - "fmt" - "io" - "time" - - "github.com/pkg/errors" -) - -// Config is used to tune the Smux session -type Config struct { - // KeepAliveInterval is how often to send a NOP command to the remote - KeepAliveInterval time.Duration - - // KeepAliveTimeout is how long the session - // will be closed if no data has arrived - KeepAliveTimeout time.Duration - - // MaxFrameSize is used to control the maximum - // frame size to sent to the remote - MaxFrameSize int - - // MaxReceiveBuffer is used to control the maximum - // number of data in the buffer pool - MaxReceiveBuffer int -} - -// DefaultConfig is used to return a default configuration -func DefaultConfig() *Config { - return &Config{ - KeepAliveInterval: 10 * time.Second, - KeepAliveTimeout: 30 * time.Second, - MaxFrameSize: 4096, - MaxReceiveBuffer: 4194304, - } -} - -// VerifyConfig is used to verify the sanity of configuration -func VerifyConfig(config *Config) error { - if config.KeepAliveInterval == 0 { - return errors.New("keep-alive interval must be positive") - } - if config.KeepAliveTimeout < config.KeepAliveInterval { - return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval") - } - if config.MaxFrameSize <= 0 { - return errors.New("max frame size must be positive") - } - if config.MaxFrameSize > 65535 { - return errors.New("max frame size must not be larger than 65535") - } - if config.MaxReceiveBuffer <= 0 { - return errors.New("max receive buffer must be positive") - } - return nil -} - -// Server is used to initialize a new server-side connection. -func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) { - if config == nil { - config = DefaultConfig() - } - if err := VerifyConfig(config); err != nil { - return nil, err - } - return newSession(config, conn, false), nil -} - -// Client is used to initialize a new client-side connection. -func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) { - if config == nil { - config = DefaultConfig() - } - - if err := VerifyConfig(config); err != nil { - return nil, err - } - return newSession(config, conn, true), nil -} diff --git a/vendor/github.com/xtaci/smux/mux.jpg b/vendor/github.com/xtaci/smux/mux.jpg deleted file mode 100644 index dde2e11..0000000 Binary files a/vendor/github.com/xtaci/smux/mux.jpg and /dev/null differ diff --git a/vendor/github.com/xtaci/smux/mux_test.go b/vendor/github.com/xtaci/smux/mux_test.go deleted file mode 100644 index 638e67c..0000000 --- a/vendor/github.com/xtaci/smux/mux_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package smux - -import ( - "bytes" - "testing" -) - -type buffer struct { - bytes.Buffer -} - -func (b *buffer) Close() error { - b.Buffer.Reset() - return nil -} - -func TestConfig(t *testing.T) { - VerifyConfig(DefaultConfig()) - - config := DefaultConfig() - config.KeepAliveInterval = 0 - err := VerifyConfig(config) - t.Log(err) - if err == nil { - t.Fatal(err) - } - - config = DefaultConfig() - config.KeepAliveInterval = 10 - config.KeepAliveTimeout = 5 - err = VerifyConfig(config) - t.Log(err) - if err == nil { - t.Fatal(err) - } - - config = DefaultConfig() - config.MaxFrameSize = 0 - err = VerifyConfig(config) - t.Log(err) - if err == nil { - t.Fatal(err) - } - - config = DefaultConfig() - config.MaxFrameSize = 65536 - err = VerifyConfig(config) - t.Log(err) - if err == nil { - t.Fatal(err) - } - - config = DefaultConfig() - config.MaxReceiveBuffer = 0 - err = VerifyConfig(config) - t.Log(err) - if err == nil { - t.Fatal(err) - } - - var bts buffer - if _, err := Server(&bts, config); err == nil { - t.Fatal("server started with wrong config") - } - - if _, err := Client(&bts, config); err == nil { - t.Fatal("client started with wrong config") - } -} diff --git a/vendor/github.com/xtaci/smux/session.go b/vendor/github.com/xtaci/smux/session.go deleted file mode 100644 index 12fc4cb..0000000 --- a/vendor/github.com/xtaci/smux/session.go +++ /dev/null @@ -1,353 +0,0 @@ -package smux - -import ( - "encoding/binary" - "io" - "sync" - "sync/atomic" - "time" - - "github.com/pkg/errors" -) - -const ( - defaultAcceptBacklog = 1024 -) - -const ( - errBrokenPipe = "broken pipe" - errInvalidProtocol = "invalid protocol version" - errGoAway = "stream id overflows, should start a new connection" -) - -type writeRequest struct { - frame Frame - result chan writeResult -} - -type writeResult struct { - n int - err error -} - -// Session defines a multiplexed connection for streams -type Session struct { - conn io.ReadWriteCloser - - config *Config - nextStreamID uint32 // next stream identifier - nextStreamIDLock sync.Mutex - - bucket int32 // token bucket - bucketNotify chan struct{} // used for waiting for tokens - - streams map[uint32]*Stream // all streams in this session - streamLock sync.Mutex // locks streams - - die chan struct{} // flag session has died - dieLock sync.Mutex - chAccepts chan *Stream - - dataReady int32 // flag data has arrived - - goAway int32 // flag id exhausted - - deadline atomic.Value - - writes chan writeRequest -} - -func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { - s := new(Session) - s.die = make(chan struct{}) - s.conn = conn - s.config = config - s.streams = make(map[uint32]*Stream) - s.chAccepts = make(chan *Stream, defaultAcceptBacklog) - s.bucket = int32(config.MaxReceiveBuffer) - s.bucketNotify = make(chan struct{}, 1) - s.writes = make(chan writeRequest) - - if client { - s.nextStreamID = 1 - } else { - s.nextStreamID = 0 - } - go s.recvLoop() - go s.sendLoop() - go s.keepalive() - return s -} - -// OpenStream is used to create a new stream -func (s *Session) OpenStream() (*Stream, error) { - if s.IsClosed() { - return nil, errors.New(errBrokenPipe) - } - - // generate stream id - s.nextStreamIDLock.Lock() - if s.goAway > 0 { - s.nextStreamIDLock.Unlock() - return nil, errors.New(errGoAway) - } - - s.nextStreamID += 2 - sid := s.nextStreamID - if sid == sid%2 { // stream-id overflows - s.goAway = 1 - s.nextStreamIDLock.Unlock() - return nil, errors.New(errGoAway) - } - s.nextStreamIDLock.Unlock() - - stream := newStream(sid, s.config.MaxFrameSize, s) - - if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { - return nil, errors.Wrap(err, "writeFrame") - } - - s.streamLock.Lock() - s.streams[sid] = stream - s.streamLock.Unlock() - return stream, nil -} - -// AcceptStream is used to block until the next available stream -// is ready to be accepted. -func (s *Session) AcceptStream() (*Stream, error) { - var deadline <-chan time.Time - if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { - timer := time.NewTimer(d.Sub(time.Now())) - defer timer.Stop() - deadline = timer.C - } - select { - case stream := <-s.chAccepts: - return stream, nil - case <-deadline: - return nil, errTimeout - case <-s.die: - return nil, errors.New(errBrokenPipe) - } -} - -// Close is used to close the session and all streams. -func (s *Session) Close() (err error) { - s.dieLock.Lock() - - select { - case <-s.die: - s.dieLock.Unlock() - return errors.New(errBrokenPipe) - default: - close(s.die) - s.dieLock.Unlock() - s.streamLock.Lock() - for k := range s.streams { - s.streams[k].sessionClose() - } - s.streamLock.Unlock() - s.notifyBucket() - return s.conn.Close() - } -} - -// notifyBucket notifies recvLoop that bucket is available -func (s *Session) notifyBucket() { - select { - case s.bucketNotify <- struct{}{}: - default: - } -} - -// IsClosed does a safe check to see if we have shutdown -func (s *Session) IsClosed() bool { - select { - case <-s.die: - return true - default: - return false - } -} - -// NumStreams returns the number of currently open streams -func (s *Session) NumStreams() int { - if s.IsClosed() { - return 0 - } - s.streamLock.Lock() - defer s.streamLock.Unlock() - return len(s.streams) -} - -// SetDeadline sets a deadline used by Accept* calls. -// A zero time value disables the deadline. -func (s *Session) SetDeadline(t time.Time) error { - s.deadline.Store(t) - return nil -} - -// notify the session that a stream has closed -func (s *Session) streamClosed(sid uint32) { - s.streamLock.Lock() - if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket - if atomic.AddInt32(&s.bucket, int32(n)) > 0 { - s.notifyBucket() - } - } - delete(s.streams, sid) - s.streamLock.Unlock() -} - -// returnTokens is called by stream to return token after read -func (s *Session) returnTokens(n int) { - if atomic.AddInt32(&s.bucket, int32(n)) > 0 { - s.notifyBucket() - } -} - -// session read a frame from underlying connection -// it's data is pointed to the input buffer -func (s *Session) readFrame(buffer []byte) (f Frame, err error) { - if _, err := io.ReadFull(s.conn, buffer[:headerSize]); err != nil { - return f, errors.Wrap(err, "readFrame") - } - - dec := rawHeader(buffer) - if dec.Version() != version { - return f, errors.New(errInvalidProtocol) - } - - f.ver = dec.Version() - f.cmd = dec.Cmd() - f.sid = dec.StreamID() - if length := dec.Length(); length > 0 { - if _, err := io.ReadFull(s.conn, buffer[headerSize:headerSize+length]); err != nil { - return f, errors.Wrap(err, "readFrame") - } - f.data = buffer[headerSize : headerSize+length] - } - return f, nil -} - -// recvLoop keeps on reading from underlying connection if tokens are available -func (s *Session) recvLoop() { - buffer := make([]byte, (1<<16)+headerSize) - for { - for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { - <-s.bucketNotify - } - - if f, err := s.readFrame(buffer); err == nil { - atomic.StoreInt32(&s.dataReady, 1) - - switch f.cmd { - case cmdNOP: - case cmdSYN: - s.streamLock.Lock() - if _, ok := s.streams[f.sid]; !ok { - stream := newStream(f.sid, s.config.MaxFrameSize, s) - s.streams[f.sid] = stream - select { - case s.chAccepts <- stream: - case <-s.die: - } - } - s.streamLock.Unlock() - case cmdFIN: - s.streamLock.Lock() - if stream, ok := s.streams[f.sid]; ok { - stream.markRST() - stream.notifyReadEvent() - } - s.streamLock.Unlock() - case cmdPSH: - s.streamLock.Lock() - if stream, ok := s.streams[f.sid]; ok { - atomic.AddInt32(&s.bucket, -int32(len(f.data))) - stream.pushBytes(f.data) - stream.notifyReadEvent() - } - s.streamLock.Unlock() - default: - s.Close() - return - } - } else { - s.Close() - return - } - } -} - -func (s *Session) keepalive() { - tickerPing := time.NewTicker(s.config.KeepAliveInterval) - tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) - defer tickerPing.Stop() - defer tickerTimeout.Stop() - for { - select { - case <-tickerPing.C: - s.writeFrame(newFrame(cmdNOP, 0)) - s.notifyBucket() // force a signal to the recvLoop - case <-tickerTimeout.C: - if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { - s.Close() - return - } - case <-s.die: - return - } - } -} - -func (s *Session) sendLoop() { - buf := make([]byte, (1<<16)+headerSize) - for { - select { - case <-s.die: - return - case request, ok := <-s.writes: - if !ok { - continue - } - buf[0] = request.frame.ver - buf[1] = request.frame.cmd - binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) - binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) - copy(buf[headerSize:], request.frame.data) - n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)]) - - n -= headerSize - if n < 0 { - n = 0 - } - - result := writeResult{ - n: n, - err: err, - } - - request.result <- result - close(request.result) - } - } -} - -// writeFrame writes the frame to the underlying connection -// and returns the number of bytes written if successful -func (s *Session) writeFrame(f Frame) (n int, err error) { - req := writeRequest{ - frame: f, - result: make(chan writeResult, 1), - } - select { - case <-s.die: - return 0, errors.New(errBrokenPipe) - case s.writes <- req: - } - - result := <-req.result - return result.n, result.err -} diff --git a/vendor/github.com/xtaci/smux/session_test.go b/vendor/github.com/xtaci/smux/session_test.go deleted file mode 100644 index 2147d8f..0000000 --- a/vendor/github.com/xtaci/smux/session_test.go +++ /dev/null @@ -1,667 +0,0 @@ -package smux - -import ( - crand "crypto/rand" - "encoding/binary" - "fmt" - "io" - "log" - "math/rand" - "net" - "net/http" - _ "net/http/pprof" - "strings" - "sync" - "testing" - "time" -) - -func init() { - go func() { - log.Println(http.ListenAndServe("localhost:6060", nil)) - }() - log.SetFlags(log.LstdFlags | log.Lshortfile) - ln, err := net.Listen("tcp", "127.0.0.1:19999") - if err != nil { - // handle error - panic(err) - } - go func() { - for { - conn, err := ln.Accept() - if err != nil { - // handle error - } - go handleConnection(conn) - } - }() -} - -func handleConnection(conn net.Conn) { - session, _ := Server(conn, nil) - for { - if stream, err := session.AcceptStream(); err == nil { - go func(s io.ReadWriteCloser) { - buf := make([]byte, 65536) - for { - n, err := s.Read(buf) - if err != nil { - return - } - s.Write(buf[:n]) - } - }(stream) - } else { - return - } - } -} - -func TestEcho(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - const N = 100 - buf := make([]byte, 10) - var sent string - var received string - for i := 0; i < N; i++ { - msg := fmt.Sprintf("hello%v", i) - stream.Write([]byte(msg)) - sent += msg - if n, err := stream.Read(buf); err != nil { - t.Fatal(err) - } else { - received += string(buf[:n]) - } - } - if sent != received { - t.Fatal("data mimatch") - } - session.Close() -} - -func TestSpeed(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - t.Log(stream.LocalAddr(), stream.RemoteAddr()) - - start := time.Now() - var wg sync.WaitGroup - wg.Add(1) - go func() { - buf := make([]byte, 1024*1024) - nrecv := 0 - for { - n, err := stream.Read(buf) - if err != nil { - t.Fatal(err) - break - } else { - nrecv += n - if nrecv == 4096*4096 { - break - } - } - } - stream.Close() - t.Log("time for 16MB rtt", time.Since(start)) - wg.Done() - }() - msg := make([]byte, 8192) - for i := 0; i < 2048; i++ { - stream.Write(msg) - } - wg.Wait() - session.Close() -} - -func TestParallel(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - - par := 1000 - messages := 100 - var wg sync.WaitGroup - wg.Add(par) - for i := 0; i < par; i++ { - stream, _ := session.OpenStream() - go func(s *Stream) { - buf := make([]byte, 20) - for j := 0; j < messages; j++ { - msg := fmt.Sprintf("hello%v", j) - s.Write([]byte(msg)) - if _, err := s.Read(buf); err != nil { - break - } - } - s.Close() - wg.Done() - }(stream) - } - t.Log("created", session.NumStreams(), "streams") - wg.Wait() - session.Close() -} - -func TestCloseThenOpen(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - session.Close() - if _, err := session.OpenStream(); err == nil { - t.Fatal("opened after close") - } -} - -func TestStreamDoubleClose(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - stream.Close() - if err := stream.Close(); err == nil { - t.Log("double close doesn't return error") - } - session.Close() -} - -func TestConcurrentClose(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - numStreams := 100 - streams := make([]*Stream, 0, numStreams) - var wg sync.WaitGroup - wg.Add(numStreams) - for i := 0; i < 100; i++ { - stream, _ := session.OpenStream() - streams = append(streams, stream) - } - for _, s := range streams { - stream := s - go func() { - stream.Close() - wg.Done() - }() - } - session.Close() - wg.Wait() -} - -func TestTinyReadBuffer(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - const N = 100 - tinybuf := make([]byte, 6) - var sent string - var received string - for i := 0; i < N; i++ { - msg := fmt.Sprintf("hello%v", i) - sent += msg - nsent, err := stream.Write([]byte(msg)) - if err != nil { - t.Fatal("cannot write") - } - nrecv := 0 - for nrecv < nsent { - if n, err := stream.Read(tinybuf); err == nil { - nrecv += n - received += string(tinybuf[:n]) - } else { - t.Fatal("cannot read with tiny buffer") - } - } - } - - if sent != received { - t.Fatal("data mimatch") - } - session.Close() -} - -func TestIsClose(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - session.Close() - if session.IsClosed() != true { - t.Fatal("still open after close") - } -} - -func TestKeepAliveTimeout(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:29999") - if err != nil { - // handle error - panic(err) - } - go func() { - ln.Accept() - }() - - cli, err := net.Dial("tcp", "127.0.0.1:29999") - if err != nil { - t.Fatal(err) - } - - config := DefaultConfig() - config.KeepAliveInterval = time.Second - config.KeepAliveTimeout = 2 * time.Second - session, _ := Client(cli, config) - <-time.After(3 * time.Second) - if session.IsClosed() != true { - t.Fatal("keepalive-timeout failed") - } -} - -func TestServerEcho(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:39999") - if err != nil { - // handle error - panic(err) - } - go func() { - if conn, err := ln.Accept(); err == nil { - session, _ := Server(conn, nil) - if stream, err := session.OpenStream(); err == nil { - const N = 100 - buf := make([]byte, 10) - for i := 0; i < N; i++ { - msg := fmt.Sprintf("hello%v", i) - stream.Write([]byte(msg)) - if n, err := stream.Read(buf); err != nil { - t.Fatal(err) - } else if string(buf[:n]) != msg { - t.Fatal(err) - } - } - stream.Close() - } else { - t.Fatal(err) - } - } else { - t.Fatal(err) - } - }() - - cli, err := net.Dial("tcp", "127.0.0.1:39999") - if err != nil { - t.Fatal(err) - } - if session, err := Client(cli, nil); err == nil { - if stream, err := session.AcceptStream(); err == nil { - buf := make([]byte, 65536) - for { - n, err := stream.Read(buf) - if err != nil { - break - } - stream.Write(buf[:n]) - } - } else { - t.Fatal(err) - } - } else { - t.Fatal(err) - } -} - -func TestSendWithoutRecv(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - const N = 100 - for i := 0; i < N; i++ { - msg := fmt.Sprintf("hello%v", i) - stream.Write([]byte(msg)) - } - buf := make([]byte, 1) - if _, err := stream.Read(buf); err != nil { - t.Fatal(err) - } - stream.Close() -} - -func TestWriteAfterClose(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - stream.Close() - if _, err := stream.Write([]byte("write after close")); err == nil { - t.Fatal("write after close failed") - } -} - -func TestReadStreamAfterSessionClose(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - session.Close() - buf := make([]byte, 10) - if _, err := stream.Read(buf); err != nil { - t.Log(err) - } else { - t.Fatal("read stream after session close succeeded") - } -} - -func TestWriteStreamAfterConnectionClose(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - session.conn.Close() - if _, err := stream.Write([]byte("write after connection close")); err == nil { - t.Fatal("write after connection close failed") - } -} - -func TestNumStreamAfterClose(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - if _, err := session.OpenStream(); err == nil { - if session.NumStreams() != 1 { - t.Fatal("wrong number of streams after opened") - } - session.Close() - if session.NumStreams() != 0 { - t.Fatal("wrong number of streams after session closed") - } - } else { - t.Fatal(err) - } - cli.Close() -} - -func TestRandomFrame(t *testing.T) { - // pure random - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - for i := 0; i < 100; i++ { - rnd := make([]byte, rand.Uint32()%1024) - io.ReadFull(crand.Reader, rnd) - session.conn.Write(rnd) - } - cli.Close() - - // double syn - cli, err = net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ = Client(cli, nil) - for i := 0; i < 100; i++ { - f := newFrame(cmdSYN, 1000) - session.writeFrame(f) - } - cli.Close() - - // random cmds - cli, err = net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP} - session, _ = Client(cli, nil) - for i := 0; i < 100; i++ { - f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32()) - session.writeFrame(f) - } - cli.Close() - - // random cmds & sids - cli, err = net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ = Client(cli, nil) - for i := 0; i < 100; i++ { - f := newFrame(byte(rand.Uint32()), rand.Uint32()) - session.writeFrame(f) - } - cli.Close() - - // random version - cli, err = net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ = Client(cli, nil) - for i := 0; i < 100; i++ { - f := newFrame(byte(rand.Uint32()), rand.Uint32()) - f.ver = byte(rand.Uint32()) - session.writeFrame(f) - } - cli.Close() - - // incorrect size - cli, err = net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ = Client(cli, nil) - - f := newFrame(byte(rand.Uint32()), rand.Uint32()) - rnd := make([]byte, rand.Uint32()%1024) - io.ReadFull(crand.Reader, rnd) - f.data = rnd - - buf := make([]byte, headerSize+len(f.data)) - buf[0] = f.ver - buf[1] = f.cmd - binary.LittleEndian.PutUint16(buf[2:], uint16(len(rnd)+1)) /// incorrect size - binary.LittleEndian.PutUint32(buf[4:], f.sid) - copy(buf[headerSize:], f.data) - - session.conn.Write(buf) - t.Log(rawHeader(buf)) - cli.Close() -} - -func TestReadDeadline(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - const N = 100 - buf := make([]byte, 10) - var readErr error - for i := 0; i < N; i++ { - msg := fmt.Sprintf("hello%v", i) - stream.Write([]byte(msg)) - stream.SetReadDeadline(time.Now().Add(-1 * time.Minute)) - if _, readErr = stream.Read(buf); readErr != nil { - break - } - } - if readErr != nil { - if !strings.Contains(readErr.Error(), "i/o timeout") { - t.Fatalf("Wrong error: %v", readErr) - } - } else { - t.Fatal("No error when reading with past deadline") - } - session.Close() -} - -func TestWriteDeadline(t *testing.T) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - t.Fatal(err) - } - session, _ := Client(cli, nil) - stream, _ := session.OpenStream() - buf := make([]byte, 10) - var writeErr error - for { - stream.SetWriteDeadline(time.Now().Add(-1 * time.Minute)) - if _, writeErr = stream.Write(buf); writeErr != nil { - if !strings.Contains(writeErr.Error(), "i/o timeout") { - t.Fatalf("Wrong error: %v", writeErr) - } - break - } - } - session.Close() -} - -func BenchmarkAcceptClose(b *testing.B) { - cli, err := net.Dial("tcp", "127.0.0.1:19999") - if err != nil { - b.Fatal(err) - } - session, _ := Client(cli, nil) - for i := 0; i < b.N; i++ { - if stream, err := session.OpenStream(); err == nil { - stream.Close() - } else { - b.Fatal(err) - } - } -} -func BenchmarkConnSmux(b *testing.B) { - cs, ss, err := getSmuxStreamPair() - if err != nil { - b.Fatal(err) - } - defer cs.Close() - defer ss.Close() - bench(b, cs, ss) -} - -func BenchmarkConnTCP(b *testing.B) { - cs, ss, err := getTCPConnectionPair() - if err != nil { - b.Fatal(err) - } - defer cs.Close() - defer ss.Close() - bench(b, cs, ss) -} - -func getSmuxStreamPair() (*Stream, *Stream, error) { - c1, c2, err := getTCPConnectionPair() - if err != nil { - return nil, nil, err - } - - s, err := Server(c2, nil) - if err != nil { - return nil, nil, err - } - c, err := Client(c1, nil) - if err != nil { - return nil, nil, err - } - var ss *Stream - done := make(chan error) - go func() { - var rerr error - ss, rerr = s.AcceptStream() - done <- rerr - close(done) - }() - cs, err := c.OpenStream() - if err != nil { - return nil, nil, err - } - err = <-done - if err != nil { - return nil, nil, err - } - - return cs, ss, nil -} - -func getTCPConnectionPair() (net.Conn, net.Conn, error) { - lst, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, nil, err - } - - var conn0 net.Conn - var err0 error - done := make(chan struct{}) - go func() { - conn0, err0 = lst.Accept() - close(done) - }() - - conn1, err := net.Dial("tcp", lst.Addr().String()) - if err != nil { - return nil, nil, err - } - - <-done - if err0 != nil { - return nil, nil, err0 - } - return conn0, conn1, nil -} - -func bench(b *testing.B, rd io.Reader, wr io.Writer) { - buf := make([]byte, 128*1024) - buf2 := make([]byte, 128*1024) - b.SetBytes(128 * 1024) - b.ResetTimer() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - count := 0 - for { - n, _ := rd.Read(buf2) - count += n - if count == 128*1024*b.N { - return - } - } - }() - for i := 0; i < b.N; i++ { - wr.Write(buf) - } - wg.Wait() -} diff --git a/vendor/github.com/xtaci/smux/smux.png b/vendor/github.com/xtaci/smux/smux.png deleted file mode 100644 index 26aba3b..0000000 Binary files a/vendor/github.com/xtaci/smux/smux.png and /dev/null differ diff --git a/vendor/github.com/xtaci/smux/stream.go b/vendor/github.com/xtaci/smux/stream.go deleted file mode 100644 index 8b8a52a..0000000 --- a/vendor/github.com/xtaci/smux/stream.go +++ /dev/null @@ -1,261 +0,0 @@ -package smux - -import ( - "bytes" - "io" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/pkg/errors" -) - -// Stream implements net.Conn -type Stream struct { - id uint32 - rstflag int32 - sess *Session - buffer bytes.Buffer - bufferLock sync.Mutex - frameSize int - chReadEvent chan struct{} // notify a read event - die chan struct{} // flag the stream has closed - dieLock sync.Mutex - readDeadline atomic.Value - writeDeadline atomic.Value -} - -// newStream initiates a Stream struct -func newStream(id uint32, frameSize int, sess *Session) *Stream { - s := new(Stream) - s.id = id - s.chReadEvent = make(chan struct{}, 1) - s.frameSize = frameSize - s.sess = sess - s.die = make(chan struct{}) - return s -} - -// ID returns the unique stream ID. -func (s *Stream) ID() uint32 { - return s.id -} - -// Read implements net.Conn -func (s *Stream) Read(b []byte) (n int, err error) { - var deadline <-chan time.Time - if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() { - timer := time.NewTimer(d.Sub(time.Now())) - defer timer.Stop() - deadline = timer.C - } - -READ: - select { - case <-s.die: - return 0, errors.New(errBrokenPipe) - case <-deadline: - return n, errTimeout - default: - } - - s.bufferLock.Lock() - n, err = s.buffer.Read(b) - s.bufferLock.Unlock() - - if n > 0 { - s.sess.returnTokens(n) - return n, nil - } else if atomic.LoadInt32(&s.rstflag) == 1 { - _ = s.Close() - return 0, io.EOF - } - - select { - case <-s.chReadEvent: - goto READ - case <-deadline: - return n, errTimeout - case <-s.die: - return 0, errors.New(errBrokenPipe) - } -} - -// Write implements net.Conn -func (s *Stream) Write(b []byte) (n int, err error) { - var deadline <-chan time.Time - if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() { - timer := time.NewTimer(d.Sub(time.Now())) - defer timer.Stop() - deadline = timer.C - } - - select { - case <-s.die: - return 0, errors.New(errBrokenPipe) - default: - } - - frames := s.split(b, cmdPSH, s.id) - sent := 0 - for k := range frames { - req := writeRequest{ - frame: frames[k], - result: make(chan writeResult, 1), - } - - select { - case s.sess.writes <- req: - case <-s.die: - return sent, errors.New(errBrokenPipe) - case <-deadline: - return sent, errTimeout - } - - select { - case result := <-req.result: - sent += result.n - if result.err != nil { - return sent, result.err - } - case <-s.die: - return sent, errors.New(errBrokenPipe) - case <-deadline: - return sent, errTimeout - } - } - return sent, nil -} - -// Close implements net.Conn -func (s *Stream) Close() error { - s.dieLock.Lock() - - select { - case <-s.die: - s.dieLock.Unlock() - return errors.New(errBrokenPipe) - default: - close(s.die) - s.dieLock.Unlock() - s.sess.streamClosed(s.id) - _, err := s.sess.writeFrame(newFrame(cmdFIN, s.id)) - return err - } -} - -// SetReadDeadline sets the read deadline as defined by -// net.Conn.SetReadDeadline. -// A zero time value disables the deadline. -func (s *Stream) SetReadDeadline(t time.Time) error { - s.readDeadline.Store(t) - return nil -} - -// SetWriteDeadline sets the write deadline as defined by -// net.Conn.SetWriteDeadline. -// A zero time value disables the deadline. -func (s *Stream) SetWriteDeadline(t time.Time) error { - s.writeDeadline.Store(t) - return nil -} - -// SetDeadline sets both read and write deadlines as defined by -// net.Conn.SetDeadline. -// A zero time value disables the deadlines. -func (s *Stream) SetDeadline(t time.Time) error { - if err := s.SetReadDeadline(t); err != nil { - return err - } - if err := s.SetWriteDeadline(t); err != nil { - return err - } - return nil -} - -// session closes the stream -func (s *Stream) sessionClose() { - s.dieLock.Lock() - defer s.dieLock.Unlock() - - select { - case <-s.die: - default: - close(s.die) - } -} - -// LocalAddr satisfies net.Conn interface -func (s *Stream) LocalAddr() net.Addr { - if ts, ok := s.sess.conn.(interface { - LocalAddr() net.Addr - }); ok { - return ts.LocalAddr() - } - return nil -} - -// RemoteAddr satisfies net.Conn interface -func (s *Stream) RemoteAddr() net.Addr { - if ts, ok := s.sess.conn.(interface { - RemoteAddr() net.Addr - }); ok { - return ts.RemoteAddr() - } - return nil -} - -// pushBytes a slice into buffer -func (s *Stream) pushBytes(p []byte) { - s.bufferLock.Lock() - s.buffer.Write(p) - s.bufferLock.Unlock() -} - -// recycleTokens transform remaining bytes to tokens(will truncate buffer) -func (s *Stream) recycleTokens() (n int) { - s.bufferLock.Lock() - n = s.buffer.Len() - s.buffer.Reset() - s.bufferLock.Unlock() - return -} - -// split large byte buffer into smaller frames, reference only -func (s *Stream) split(bts []byte, cmd byte, sid uint32) []Frame { - frames := make([]Frame, 0, len(bts)/s.frameSize+1) - for len(bts) > s.frameSize { - frame := newFrame(cmd, sid) - frame.data = bts[:s.frameSize] - bts = bts[s.frameSize:] - frames = append(frames, frame) - } - if len(bts) > 0 { - frame := newFrame(cmd, sid) - frame.data = bts - frames = append(frames, frame) - } - return frames -} - -// notify read event -func (s *Stream) notifyReadEvent() { - select { - case s.chReadEvent <- struct{}{}: - default: - } -} - -// mark this stream has been reset -func (s *Stream) markRST() { - atomic.StoreInt32(&s.rstflag, 1) -} - -var errTimeout error = &timeoutError{} - -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true }