diff --git a/models/proto/tcp/tcp.go b/models/proto/tcp/tcp.go index e1e22e2..e718714 100644 --- a/models/proto/tcp/tcp.go +++ b/models/proto/tcp/tcp.go @@ -27,23 +27,29 @@ func WithEncryption(rwc io.ReadWriteCloser, key []byte) (io.ReadWriteCloser, err if err != nil { return nil, err } - return WrapReadWriteCloser(crypto.NewReader(rwc, key), w), nil + return WrapReadWriteCloser(crypto.NewReader(rwc, key), w, func() error { + return rwc.Close() + }), nil } func WithCompression(rwc io.ReadWriteCloser) io.ReadWriteCloser { - return WrapReadWriteCloser(snappy.NewReader(rwc), snappy.NewWriter(rwc)) -} - -func WrapReadWriteCloser(r io.Reader, w io.Writer) io.ReadWriteCloser { - return &ReadWriteCloser{ - r: r, - w: w, - } + return WrapReadWriteCloser(snappy.NewReader(rwc), snappy.NewWriter(rwc), func() error { + return rwc.Close() + }) } type ReadWriteCloser struct { - r io.Reader - w io.Writer + r io.Reader + w io.Writer + closeFn func() error +} + +func WrapReadWriteCloser(r io.Reader, w io.Writer, closeFn func() error) io.ReadWriteCloser { + return &ReadWriteCloser{ + r: r, + w: w, + closeFn: closeFn, + } } func (rwc *ReadWriteCloser) Read(p []byte) (n int, err error) { @@ -69,5 +75,12 @@ func (rwc *ReadWriteCloser) Close() (errRet error) { errRet = err } } + + if rwc.closeFn != nil { + err = rwc.closeFn() + if err != nil { + errRet = err + } + } return }