Skip to content

Commit

Permalink
Temporary object pools. Fixes and tuning (#47)
Browse files Browse the repository at this point in the history
* Bytewise crc32

* Bump go-astikit to 0.30. Make crc32 generator. Remove old crc32 calculation func and corresponding tests/benchmarks.

* Replace OpenFile with Create in crc32 generator. Some minor changes

* Add pooling for packet slices and raw data payload. Replace map[uint16] with map[uint32] see runtime/map_fast32.go . Cut out mutexes. Make DemuxerData slices of known size. Bump GO to 1.19. Fix BenchmarkParsePSIData and BenchmarkDemuxer_NextData. Copy FirstPacket without payload to DemuxerData.

* Move pools to separate file. Rollback to GO 1.13

* Comments and naming

* Some formatting and esContexts map[uint32]

* Remove packetSlice pool. Wrap tempPayload in object to reduce allocations.

* Mark packetPool and programMap methods as Unlocked

* Naming and comments

* Naming

---------

Co-authored-by: Danil Korymov <[email protected]>
  • Loading branch information
k-danil and k-danil authored Feb 15, 2023
1 parent f2825ee commit d8a24c5
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 97 deletions.
43 changes: 28 additions & 15 deletions data.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package astits
import (
"encoding/binary"
"fmt"

"github.com/asticode/go-astikit"
)

Expand Down Expand Up @@ -54,19 +53,28 @@ func parseData(ps []*Packet, prs PacketsParser, pm *programMap) (ds []*DemuxerDa
l += len(p.Payload)
}

// Get the slice for payload from pool
payload := bytesPool.get(l)
defer bytesPool.put(payload)

// Append payload
var payload = make([]byte, l)
var c int
for _, p := range ps {
c += copy(payload[c:], p.Payload)
c += copy(payload.s[c:], p.Payload)
}

// Create reader
i := astikit.NewBytesIterator(payload)
i := astikit.NewBytesIterator(payload.s)

// Parse PID
pid := ps[0].Header.PID

// Copy first packet headers, so we can safely deallocate original payload
fp := &Packet{
Header: ps[0].Header,
AdaptationField: ps[0].AdaptationField,
}

// Parse payload
if pid == PIDCAT {
// Information in a CAT payload is private and dependent on the CA system. Use the PacketsParser
Expand All @@ -80,8 +88,8 @@ func parseData(ps []*Packet, prs PacketsParser, pm *programMap) (ds []*DemuxerDa
}

// Append data
ds = psiData.toData(ps[0], pid)
} else if isPESPayload(payload) {
ds = psiData.toData(fp, pid)
} else if isPESPayload(payload.s) {
// Parse PES data
var pesData *PESData
if pesData, err = parsePESData(i); err != nil {
Expand All @@ -90,19 +98,21 @@ func parseData(ps []*Packet, prs PacketsParser, pm *programMap) (ds []*DemuxerDa
}

// Append data
ds = append(ds, &DemuxerData{
FirstPacket: ps[0],
PES: pesData,
PID: pid,
})
ds = []*DemuxerData{
{
FirstPacket: fp,
PES: pesData,
PID: pid,
},
}
}
return
}

// isPSIPayload checks whether the payload is a PSI one
func isPSIPayload(pid uint16, pm *programMap) bool {
return pid == PIDPAT || // PAT
pm.exists(pid) || // PMT
pm.existsUnlocked(pid) || // PMT
((pid >= 0x10 && pid <= 0x14) || (pid >= 0x1e && pid <= 0x1f)) //DVB
}

Expand All @@ -125,15 +135,18 @@ func isPSIComplete(ps []*Packet) bool {
l += len(p.Payload)
}

// Get the slice for payload from pool
payload := bytesPool.get(l)
defer bytesPool.put(payload)

// Append payload
var payload = make([]byte, l)
var o int
for _, p := range ps {
o += copy(payload[o:], p.Payload)
o += copy(payload.s[o:], p.Payload)
}

// Create reader
i := astikit.NewBytesIterator(payload)
i := astikit.NewBytesIterator(payload.s)

// Get next byte
b, err := i.NextByte()
Expand Down
1 change: 1 addition & 0 deletions data_psi.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ func parsePSISectionSyntaxData(i *astikit.BytesIterator, h *PSISectionHeader, sh
// toData parses the PSI tables and returns a set of DemuxerData
func (d *PSIData) toData(firstPacket *Packet, pid uint16) (ds []*DemuxerData) {
// Loop through sections
ds = make([]*DemuxerData, 0, len(d.Sections))
for _, s := range d.Sections {
// No data
if s.Syntax == nil || s.Syntax.Data == nil {
Expand Down
3 changes: 2 additions & 1 deletion data_psi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,9 @@ func TestWritePSIData(t *testing.T) {
}

func BenchmarkParsePSIData(b *testing.B) {
pb := psiBytes()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
parsePSIData(astikit.NewBytesIterator(psiBytes()))
parsePSIData(astikit.NewBytesIterator(pb))
}
}
16 changes: 12 additions & 4 deletions data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@ func TestParseData(t *testing.T) {
}
ds, err = parseData(ps, nil, pm)
assert.NoError(t, err)
assert.Equal(t, []*DemuxerData{{FirstPacket: ps[0], PES: pesWithHeader(), PID: uint16(256)}}, ds)
assert.Equal(t, []*DemuxerData{
{
FirstPacket: &Packet{Header: ps[0].Header, AdaptationField: ps[0].AdaptationField},
PES: pesWithHeader(),
PID: uint16(256),
}}, ds)

// PSI
pm.set(uint16(256), uint16(1))
pm.setUnlocked(uint16(256), uint16(1))
p = psiBytes()
ps = []*Packet{
{
Expand All @@ -61,7 +66,10 @@ func TestParseData(t *testing.T) {
}
ds, err = parseData(ps, nil, pm)
assert.NoError(t, err)
assert.Equal(t, psi.toData(ps[0], uint16(256)), ds)
assert.Equal(t, psi.toData(
&Packet{Header: ps[0].Header, AdaptationField: ps[0].AdaptationField},
uint16(256),
), ds)
}

func TestIsPSIPayload(t *testing.T) {
Expand All @@ -73,7 +81,7 @@ func TestIsPSIPayload(t *testing.T) {
}
}
assert.Equal(t, []int{0, 16, 17, 18, 19, 20, 30, 31}, pids)
pm.set(uint16(1), uint16(0))
pm.setUnlocked(uint16(1), uint16(0))
assert.True(t, isPSIPayload(uint16(1), pm))
}

Expand Down
6 changes: 3 additions & 3 deletions demuxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (dmx *Demuxer) NextData() (d *DemuxerData, err error) {
if err == ErrNoMorePackets {
for {
// Dump packet pool
if ps = dmx.packetPool.dump(); len(ps) == 0 {
if ps = dmx.packetPool.dumpUnlocked(); len(ps) == 0 {
break
}

Expand All @@ -165,7 +165,7 @@ func (dmx *Demuxer) NextData() (d *DemuxerData, err error) {
}

// Add packet to the pool
if ps = dmx.packetPool.add(p); len(ps) == 0 {
if ps = dmx.packetPool.addUnlocked(p); len(ps) == 0 {
continue
}

Expand Down Expand Up @@ -195,7 +195,7 @@ func (dmx *Demuxer) updateData(ds []*DemuxerData) (d *DemuxerData) {
for _, pgm := range v.PAT.Programs {
// Program number 0 is reserved to NIT
if pgm.ProgramNumber > 0 {
dmx.programMap.set(pgm.ProgramMapID, pgm.ProgramNumber)
dmx.programMap.setUnlocked(pgm.ProgramMapID, pgm.ProgramNumber)
}
}
}
Expand Down
12 changes: 7 additions & 5 deletions demuxer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,11 @@ func TestDemuxerNextData(t *testing.T) {
ds = append(ds, d)
}
}
assert.Equal(t, psi.toData(p, PIDPAT), ds)
assert.Equal(t, map[uint16]uint16{0x3: 0x2, 0x5: 0x4}, dmx.programMap.p)
assert.Equal(t, psi.toData(
&Packet{Header: p.Header, AdaptationField: p.AdaptationField},
PIDPAT,
), ds)
assert.Equal(t, map[uint32]uint16{0x3: 0x2, 0x5: 0x4}, dmx.programMap.p)

// No more packets
_, err = dmx.NextData()
Expand Down Expand Up @@ -158,7 +161,7 @@ func TestDemuxerNextDataPATPMT(t *testing.T) {
func TestDemuxerRewind(t *testing.T) {
r := bytes.NewReader([]byte("content"))
dmx := NewDemuxer(context.Background(), r)
dmx.packetPool.add(&Packet{Header: PacketHeader{PID: 1}})
dmx.packetPool.addUnlocked(&Packet{Header: PacketHeader{PID: 1}})
dmx.dataBuffer = append(dmx.dataBuffer, &DemuxerData{})
b := make([]byte, 2)
_, err := r.Read(b)
Expand All @@ -184,11 +187,10 @@ func BenchmarkDemuxer_NextData(b *testing.B) {
w.Write(b2)

r := bytes.NewReader(buf.Bytes())
dmx := NewDemuxer(context.Background(), r)

for i := 0; i < b.N; i++ {
r.Seek(0, io.SeekStart)
dmx := NewDemuxer(context.Background(), r)

for _, s := range psi.Sections {
if !s.Header.TableID.isUnknown() {
dmx.NextData()
Expand Down
15 changes: 8 additions & 7 deletions muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ type Muxer struct {
buf bytes.Buffer
bufWriter *astikit.BitsWriter

esContexts map[uint16]*esContext
// We use map[uint32] instead map[uint16] as go runtime provide optimized hash functions for (u)int32/64 keys
esContexts map[uint32]*esContext
tablesRetransmitCounter int
}

Expand Down Expand Up @@ -90,14 +91,14 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
patCC: newWrappingCounter(0b1111),
pmtCC: newWrappingCounter(0b1111),

esContexts: map[uint16]*esContext{},
esContexts: map[uint32]*esContext{},
}

m.bufWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf})
m.bitsWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: m.w})

// TODO multiple programs support
m.pm.set(pmtStartPID, programNumberStart)
m.pm.setUnlocked(pmtStartPID, programNumberStart)
m.pmUpdated = true

for _, opt := range opts {
Expand Down Expand Up @@ -125,7 +126,7 @@ func (m *Muxer) AddElementaryStream(es PMTElementaryStream) error {

m.pmt.ElementaryStreams = append(m.pmt.ElementaryStreams, &es)

m.esContexts[es.ElementaryPID] = newEsContext(&es)
m.esContexts[uint32(es.ElementaryPID)] = newEsContext(&es)
// invalidate pmt cache
m.pmtBytes.Reset()
m.pmtUpdated = true
Expand All @@ -146,7 +147,7 @@ func (m *Muxer) RemoveElementaryStream(pid uint16) error {
}

m.pmt.ElementaryStreams = append(m.pmt.ElementaryStreams[:foundIdx], m.pmt.ElementaryStreams[foundIdx+1:]...)
delete(m.esContexts, pid)
delete(m.esContexts, uint32(pid))
m.pmtBytes.Reset()
m.pmtUpdated = true
return nil
Expand All @@ -162,7 +163,7 @@ func (m *Muxer) SetPCRPID(pid uint16) {
// Currently only PES packets are supported
// Be aware that after successful call WriteData will set d.AdaptationField.StuffingLength value to zero
func (m *Muxer) WriteData(d *MuxerData) (int, error) {
ctx, ok := m.esContexts[d.PID]
ctx, ok := m.esContexts[uint32(d.PID)]
if !ok {
return 0, ErrPIDNotFound
}
Expand Down Expand Up @@ -320,7 +321,7 @@ func (m *Muxer) WriteTables() (int, error) {
}

func (m *Muxer) generatePAT() error {
d := m.pm.toPATData()
d := m.pm.toPATDataUnlocked()

versionNumber := m.patVersion.get()
if m.pmUpdated {
Expand Down
39 changes: 18 additions & 21 deletions packet_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package astits

import (
"sort"
"sync"
)

// packetAccumulator keeps track of packets for a single PID and decides when to flush them
Expand All @@ -26,7 +25,12 @@ func (b *packetAccumulator) add(p *Packet) (ps []*Packet) {

// Empty buffer if we detect a discontinuity
if hasDiscontinuity(mps, p) {
mps = make([]*Packet, 0, cap(mps))
// Reset current slice or make new
if cap(mps) > 0 {
mps = mps[:0]
} else {
mps = make([]*Packet, 0, 10)
}
}

// Throw away packet if it's the same as the previous one
Expand All @@ -44,7 +48,7 @@ func (b *packetAccumulator) add(p *Packet) (ps []*Packet) {

// Check if PSI payload is complete
if b.programMap != nil &&
(b.pid == PIDPAT || b.programMap.exists(b.pid)) &&
(b.pid == PIDPAT || b.programMap.existsUnlocked(b.pid)) &&
isPSIComplete(mps) {
ps = mps
mps = nil
Expand All @@ -56,24 +60,23 @@ func (b *packetAccumulator) add(p *Packet) (ps []*Packet) {

// packetPool represents a queue of packets for each PID in the stream
type packetPool struct {
b map[uint16]*packetAccumulator // Indexed by PID
m *sync.Mutex
// We use map[uint32] instead map[uint16] as go runtime provide optimized hash functions for (u)int32/64 keys
b map[uint32]*packetAccumulator // Indexed by PID

programMap *programMap
}

// newPacketPool creates a new packet pool with an optional parser and programMap
func newPacketPool(programMap *programMap) *packetPool {
return &packetPool{
b: make(map[uint16]*packetAccumulator),
m: &sync.Mutex{},
b: make(map[uint32]*packetAccumulator),

programMap: programMap,
}
}

// add adds a new packet to the pool
func (b *packetPool) add(p *Packet) (ps []*Packet) {
// addUnlocked adds a new packet to the pool
func (b *packetPool) addUnlocked(p *Packet) (ps []*Packet) {
// Throw away packet if error indicator
if p.Header.TransportErrorIndicator {
return
Expand All @@ -85,33 +88,27 @@ func (b *packetPool) add(p *Packet) (ps []*Packet) {
return
}

// Lock
b.m.Lock()
defer b.m.Unlock()

// Make sure accumulator exists
acc, ok := b.b[p.Header.PID]
acc, ok := b.b[uint32(p.Header.PID)]
if !ok {
acc = newPacketAccumulator(p.Header.PID, b.programMap)
b.b[p.Header.PID] = acc
b.b[uint32(p.Header.PID)] = acc
}

// Add to the accumulator
return acc.add(p)
}

// dump dumps the packet pool by looking for the first item with packets inside
func (b *packetPool) dump() (ps []*Packet) {
b.m.Lock()
defer b.m.Unlock()
// dumpUnlocked dumps the packet pool by looking for the first item with packets inside
func (b *packetPool) dumpUnlocked() (ps []*Packet) {
var keys []int
for k := range b.b {
keys = append(keys, int(k))
}
sort.Ints(keys)
for _, k := range keys {
ps = b.b[uint16(k)].q
delete(b.b, uint16(k))
ps = b.b[uint32(k)].q
delete(b.b, uint32(k))
if len(ps) > 0 {
return
}
Expand Down
Loading

0 comments on commit d8a24c5

Please sign in to comment.