frp/vendor/github.com/fatedier/kcp-go/fec.go

309 lines
7.7 KiB
Go
Raw Normal View History

2017-10-24 22:53:20 +08:00
package kcp
import (
"encoding/binary"
"sync/atomic"
2019-03-17 17:09:54 +08:00
"github.com/klauspost/reedsolomon"
2017-10-24 22:53:20 +08:00
)
// FEC wire-format constants.
const (
	// fecHeaderSize is the FEC header length: a 4-byte seqid followed by a
	// 2-byte shard type tag.
	fecHeaderSize = 6
	// fecHeaderSizePlus2 is the FEC header plus the 2-byte data-size field
	// that precedes every payload.
	fecHeaderSizePlus2 = 2 + fecHeaderSize
	// typeData tags a data shard.
	typeData = 0xf1
	// typeParity tags a Reed-Solomon parity shard.
	typeParity = 0xf2
)
// fecPacket is a raw FEC packet laid out as: 4-byte little-endian seqid,
// 2-byte little-endian shard type tag, then the shard payload.
type fecPacket []byte

// seqid returns the FEC sequence id held in the first 4 bytes.
func (p fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(p) }

// flag returns the shard type tag held in bytes 4 and 5.
func (p fecPacket) flag() uint16 { return binary.LittleEndian.Uint16(p[4:]) }

// data returns everything after the 6-byte FEC header; the result aliases p.
func (p fecPacket) data() []byte { return p[6:] }
// fecDecoder for decoding incoming packets
type fecDecoder struct {
rxlimit int // queue size limit
dataShards int
parityShards int
shardSize int
rx []fecPacket // ordered receive queue
2017-10-24 22:53:20 +08:00
2019-08-03 18:49:55 +08:00
// caches
decodeCache [][]byte
flagCache []bool
2019-03-17 17:09:54 +08:00
2019-08-03 18:49:55 +08:00
// zeros
zeros []byte
// RS decoder
codec reedsolomon.Encoder
}
2017-10-24 22:53:20 +08:00
func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
if dataShards <= 0 || parityShards <= 0 {
return nil
}
if rxlimit < dataShards+parityShards {
return nil
}
2019-03-17 17:09:54 +08:00
dec := new(fecDecoder)
dec.rxlimit = rxlimit
dec.dataShards = dataShards
dec.parityShards = parityShards
dec.shardSize = dataShards + parityShards
codec, err := reedsolomon.New(dataShards, parityShards)
2017-10-24 22:53:20 +08:00
if err != nil {
return nil
}
2019-03-17 17:09:54 +08:00
dec.codec = codec
dec.decodeCache = make([][]byte, dec.shardSize)
dec.flagCache = make([]bool, dec.shardSize)
dec.zeros = make([]byte, mtuLimit)
return dec
2017-10-24 22:53:20 +08:00
}
// decode a fec packet
2019-08-03 18:49:55 +08:00
func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) {
2017-10-24 22:53:20 +08:00
// insertion
n := len(dec.rx) - 1
insertIdx := 0
for i := n; i >= 0; i-- {
2019-08-03 18:49:55 +08:00
if in.seqid() == dec.rx[i].seqid() { // de-duplicate
2017-10-24 22:53:20 +08:00
return nil
2019-08-03 18:49:55 +08:00
} else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion
2017-10-24 22:53:20 +08:00
insertIdx = i + 1
break
}
}
2019-08-03 18:49:55 +08:00
// make a copy
pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)])
copy(pkt, in)
2017-10-24 22:53:20 +08:00
// insert into ordered rx queue
if insertIdx == n+1 {
dec.rx = append(dec.rx, pkt)
} else {
dec.rx = append(dec.rx, fecPacket{})
copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
dec.rx[insertIdx] = pkt
}
// shard range for current packet
2019-08-03 18:49:55 +08:00
shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize)
2017-10-24 22:53:20 +08:00
shardEnd := shardBegin + uint32(dec.shardSize) - 1
// max search range in ordered queue for current shard
2019-08-03 18:49:55 +08:00
searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize))
2017-10-24 22:53:20 +08:00
if searchBegin < 0 {
searchBegin = 0
}
searchEnd := searchBegin + dec.shardSize - 1
if searchEnd >= len(dec.rx) {
searchEnd = len(dec.rx) - 1
}
// re-construct datashards
if searchEnd-searchBegin+1 >= dec.dataShards {
var numshard, numDataShard, first, maxlen int
2019-03-17 17:09:54 +08:00
// zero caches
2017-10-24 22:53:20 +08:00
shards := dec.decodeCache
shardsflag := dec.flagCache
for k := range dec.decodeCache {
shards[k] = nil
shardsflag[k] = false
}
// shard assembly
for i := searchBegin; i <= searchEnd; i++ {
2019-08-03 18:49:55 +08:00
seqid := dec.rx[i].seqid()
2017-10-24 22:53:20 +08:00
if _itimediff(seqid, shardEnd) > 0 {
break
} else if _itimediff(seqid, shardBegin) >= 0 {
2019-08-03 18:49:55 +08:00
shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data()
2017-10-24 22:53:20 +08:00
shardsflag[seqid%uint32(dec.shardSize)] = true
numshard++
2019-08-03 18:49:55 +08:00
if dec.rx[i].flag() == typeData {
2017-10-24 22:53:20 +08:00
numDataShard++
}
if numshard == 1 {
first = i
}
2019-08-03 18:49:55 +08:00
if len(dec.rx[i].data()) > maxlen {
maxlen = len(dec.rx[i].data())
2017-10-24 22:53:20 +08:00
}
}
}
if numDataShard == dec.dataShards {
2019-03-17 17:09:54 +08:00
// case 1: no loss on data shards
2017-10-24 22:53:20 +08:00
dec.rx = dec.freeRange(first, numshard, dec.rx)
} else if numshard >= dec.dataShards {
2019-03-17 17:09:54 +08:00
// case 2: loss on data shards, but it's recoverable from parity shards
2017-10-24 22:53:20 +08:00
for k := range shards {
if shards[k] != nil {
dlen := len(shards[k])
shards[k] = shards[k][:maxlen]
2019-03-17 17:09:54 +08:00
copy(shards[k][dlen:], dec.zeros)
2019-08-03 18:49:55 +08:00
} else {
shards[k] = xmitBuf.Get().([]byte)[:0]
2017-10-24 22:53:20 +08:00
}
}
if err := dec.codec.ReconstructData(shards); err == nil {
for k := range shards[:dec.dataShards] {
if !shardsflag[k] {
2019-08-03 18:49:55 +08:00
// recovered data should be recycled
2017-10-24 22:53:20 +08:00
recovered = append(recovered, shards[k])
}
}
}
dec.rx = dec.freeRange(first, numshard, dec.rx)
}
}
// keep rxlimit
if len(dec.rx) > dec.rxlimit {
2019-08-03 18:49:55 +08:00
if dec.rx[0].flag() == typeData { // track the unrecoverable data
2017-10-24 22:53:20 +08:00
atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
}
dec.rx = dec.freeRange(0, 1, dec.rx)
}
return
}
2019-08-03 18:49:55 +08:00
// free a range of fecPacket
2017-10-24 22:53:20 +08:00
func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket {
2019-03-17 17:09:54 +08:00
for i := first; i < first+n; i++ { // recycle buffer
2019-08-03 18:49:55 +08:00
xmitBuf.Put([]byte(q[i]))
2017-10-24 22:53:20 +08:00
}
2019-08-03 18:49:55 +08:00
if first == 0 && n < cap(q)/2 {
return q[n:]
2017-10-24 22:53:20 +08:00
}
2019-08-03 18:49:55 +08:00
copy(q[first:], q[first+n:])
2017-10-24 22:53:20 +08:00
return q[:len(q)-n]
}
type (
// fecEncoder for encoding outgoing packets
fecEncoder struct {
dataShards int
parityShards int
shardSize int
paws uint32 // Protect Against Wrapped Sequence numbers
next uint32 // next seqid
shardCount int // count the number of datashards collected
2019-03-17 17:09:54 +08:00
maxSize int // track maximum data length in datashard
2017-10-24 22:53:20 +08:00
headerOffset int // FEC header offset
payloadOffset int // FEC payload offset
// caches
shardCache [][]byte
encodeCache [][]byte
2019-03-17 17:09:54 +08:00
// zeros
zeros []byte
2017-10-24 22:53:20 +08:00
// RS encoder
codec reedsolomon.Encoder
}
)
func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
if dataShards <= 0 || parityShards <= 0 {
return nil
}
2019-03-17 17:09:54 +08:00
enc := new(fecEncoder)
enc.dataShards = dataShards
enc.parityShards = parityShards
enc.shardSize = dataShards + parityShards
2019-08-03 18:49:55 +08:00
enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize)
2019-03-17 17:09:54 +08:00
enc.headerOffset = offset
enc.payloadOffset = enc.headerOffset + fecHeaderSize
codec, err := reedsolomon.New(dataShards, parityShards)
2017-10-24 22:53:20 +08:00
if err != nil {
return nil
}
2019-03-17 17:09:54 +08:00
enc.codec = codec
2017-10-24 22:53:20 +08:00
// caches
2019-03-17 17:09:54 +08:00
enc.encodeCache = make([][]byte, enc.shardSize)
enc.shardCache = make([][]byte, enc.shardSize)
for k := range enc.shardCache {
enc.shardCache[k] = make([]byte, mtuLimit)
2017-10-24 22:53:20 +08:00
}
2019-03-17 17:09:54 +08:00
enc.zeros = make([]byte, mtuLimit)
return enc
2017-10-24 22:53:20 +08:00
}
2019-03-17 17:09:54 +08:00
// encodes the packet, outputs parity shards if we have collected quorum datashards
// notice: the contents of 'ps' will be re-written in successive calling
2017-10-24 22:53:20 +08:00
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
2019-08-03 18:49:55 +08:00
// The header format:
// | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) |
// |<-headerOffset |<-payloadOffset
2017-10-24 22:53:20 +08:00
enc.markData(b[enc.headerOffset:])
binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))
2019-08-03 18:49:55 +08:00
// copy data from payloadOffset to fec shard cache
2017-10-24 22:53:20 +08:00
sz := len(b)
enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
2019-08-03 18:49:55 +08:00
copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:])
2017-10-24 22:53:20 +08:00
enc.shardCount++
2019-03-17 17:09:54 +08:00
// track max datashard length
2017-10-24 22:53:20 +08:00
if sz > enc.maxSize {
enc.maxSize = sz
}
2019-03-17 17:09:54 +08:00
// Generation of Reed-Solomon Erasure Code
2017-10-24 22:53:20 +08:00
if enc.shardCount == enc.dataShards {
2019-03-17 17:09:54 +08:00
// fill '0' into the tail of each datashard
2017-10-24 22:53:20 +08:00
for i := 0; i < enc.dataShards; i++ {
shard := enc.shardCache[i]
slen := len(shard)
2019-03-17 17:09:54 +08:00
copy(shard[slen:enc.maxSize], enc.zeros)
2017-10-24 22:53:20 +08:00
}
// construct equal-sized slice with stripped header
cache := enc.encodeCache
for k := range cache {
cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
}
2019-03-17 17:09:54 +08:00
// encoding
2017-10-24 22:53:20 +08:00
if err := enc.codec.Encode(cache); err == nil {
ps = enc.shardCache[enc.dataShards:]
for k := range ps {
2019-08-03 18:49:55 +08:00
enc.markParity(ps[k][enc.headerOffset:])
2017-10-24 22:53:20 +08:00
ps[k] = ps[k][:enc.maxSize]
}
}
2019-03-17 17:09:54 +08:00
// counters resetting
2017-10-24 22:53:20 +08:00
enc.shardCount = 0
enc.maxSize = 0
}
return
}
func (enc *fecEncoder) markData(data []byte) {
binary.LittleEndian.PutUint32(data, enc.next)
binary.LittleEndian.PutUint16(data[4:], typeData)
enc.next++
}
2019-08-03 18:49:55 +08:00
func (enc *fecEncoder) markParity(data []byte) {
2017-10-24 22:53:20 +08:00
binary.LittleEndian.PutUint32(data, enc.next)
2019-08-03 18:49:55 +08:00
binary.LittleEndian.PutUint16(data[4:], typeParity)
// sequence wrap will only happen at parity shard
2017-10-24 22:53:20 +08:00
enc.next = (enc.next + 1) % enc.paws
}