Skip to content

Commit

Permalink
Optional packet filter (#46)
Browse files Browse the repository at this point in the history
* Bytewise crc32

* Bump go-astikit to 0.30. Make crc32 generator. Remove old crc32 calculation func and corresponding tests/benchmarks.

* Replace OpenFile with Create in crc32 generator. Some minor changes

* Implement optional packet filter

* Revert packet reference creation

* Formatting and tests

* Fix comments

Co-authored-by: Danil Korymov <[email protected]>
  • Loading branch information
k-danil and k-danil authored Jan 23, 2023
1 parent db51df8 commit f2825ee
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 26 deletions.
30 changes: 22 additions & 8 deletions demuxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,28 @@ var (
// http://seidl.cs.vsb.cz/download/dvb/DVB_Poster.pdf
// http://www.etsi.org/deliver/etsi_en/300400_300499/300468/01.13.01_40/en_300468v011301o.pdf
type Demuxer struct {
ctx context.Context
dataBuffer []*DemuxerData
l astikit.CompleteLogger
ctx context.Context
dataBuffer []*DemuxerData
l astikit.CompleteLogger

optPacketSize int
optPacketsParser PacketsParser
packetBuffer *packetBuffer
packetPool *packetPool
programMap *programMap
r io.Reader
optPacketSkipper PacketSkipper

packetBuffer *packetBuffer
packetPool *packetPool
programMap *programMap
r io.Reader
}

// PacketsParser represents an object capable of parsing a set of packets containing a unique payload spanning over those packets
// Use the skip returned argument to indicate whether the default process should still be executed on the set of packets
type PacketsParser func(ps []*Packet) (ds []*DemuxerData, skip bool, err error)

// PacketSkipper represents an object capable of skipping a packet before parsing its payload. Its header and adaptation field is parsed and provided to the object.
// Use this option if you need to filter out unwanted packets from your pipeline. NextPacket() will return the next unskipped packet if any.
type PacketSkipper func(p *Packet) (skip bool)

// NewDemuxer creates a new transport stream based on a reader
func NewDemuxer(ctx context.Context, r io.Reader, opts ...func(*Demuxer)) (d *Demuxer) {
// Init
Expand Down Expand Up @@ -78,6 +85,13 @@ func DemuxerOptPacketsParser(p PacketsParser) func(*Demuxer) {
}
}

// DemuxerOptPacketSkipper returns the option to set the packet skipper
func DemuxerOptPacketSkipper(s PacketSkipper) func(*Demuxer) {
return func(d *Demuxer) {
d.optPacketSkipper = s
}
}

// NextPacket retrieves the next packet
func (dmx *Demuxer) NextPacket() (p *Packet, err error) {
// Check ctx error
Expand All @@ -89,7 +103,7 @@ func (dmx *Demuxer) NextPacket() (p *Packet, err error) {

// Create packet buffer if not exists
if dmx.packetBuffer == nil {
if dmx.packetBuffer, err = newPacketBuffer(dmx.r, dmx.optPacketSize); err != nil {
if dmx.packetBuffer, err = newPacketBuffer(dmx.r, dmx.optPacketSize, dmx.optPacketSkipper); err != nil {
err = fmt.Errorf("astits: creating packet buffer failed: %w", err)
return
}
Expand Down
4 changes: 3 additions & 1 deletion demuxer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ func hexToBytes(in string) []byte {
func TestDemuxerNew(t *testing.T) {
ps := 1
pp := func(ps []*Packet) (ds []*DemuxerData, skip bool, err error) { return }
dmx := NewDemuxer(context.Background(), nil, DemuxerOptPacketSize(ps), DemuxerOptPacketsParser(pp))
sp := func(p *Packet) bool { return true }
dmx := NewDemuxer(context.Background(), nil, DemuxerOptPacketSize(ps), DemuxerOptPacketsParser(pp), DemuxerOptPacketSkipper(sp))
assert.Equal(t, ps, dmx.optPacketSize)
assert.Equal(t, fmt.Sprintf("%p", pp), fmt.Sprintf("%p", dmx.optPacketsParser))
assert.Equal(t, fmt.Sprintf("%p", sp), fmt.Sprintf("%p", dmx.optPacketSkipper))
}

func TestDemuxerNextPacket(t *testing.T) {
Expand Down
10 changes: 9 additions & 1 deletion packet.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package astits

import (
"errors"
"fmt"
"github.com/asticode/go-astikit"
)
Expand All @@ -19,6 +20,8 @@ const (
pcrBytesSize = 6
)

var errSkippedPacket = errors.New("astits: skipped packet")

// Packet represents a packet
// https://en.wikipedia.org/wiki/MPEG_transport_stream
type Packet struct {
Expand Down Expand Up @@ -74,7 +77,7 @@ type PacketAdaptationExtensionField struct {
}

// parsePacket parses a packet
func parsePacket(i *astikit.BytesIterator) (p *Packet, err error) {
func parsePacket(i *astikit.BytesIterator, s PacketSkipper) (p *Packet, err error) {
// Get next byte
var b byte
if b, err = i.NextByte(); err != nil {
Expand Down Expand Up @@ -109,6 +112,11 @@ func parsePacket(i *astikit.BytesIterator) (p *Packet, err error) {
}
}

// Skip packet
if s != nil && s(p) {
return nil, errSkippedPacket
}

// Build payload
if p.Header.HasPayload {
i.Seek(payloadOffset(offsetStart, p.Header, p.AdaptationField))
Expand Down
33 changes: 21 additions & 12 deletions packet_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package astits

import (
"bufio"
"errors"
"fmt"
"io"

Expand All @@ -11,15 +12,17 @@ import (
// packetBuffer represents a packet buffer
type packetBuffer struct {
packetSize int
s PacketSkipper
r io.Reader
packetReadBuffer []byte
}

// newPacketBuffer creates a new packet buffer
func newPacketBuffer(r io.Reader, packetSize int) (pb *packetBuffer, err error) {
func newPacketBuffer(r io.Reader, packetSize int, s PacketSkipper) (pb *packetBuffer, err error) {
// Init
pb = &packetBuffer{
packetSize: packetSize,
s: s,
r: r,
}

Expand Down Expand Up @@ -121,19 +124,25 @@ func (pb *packetBuffer) next() (p *Packet, err error) {
pb.packetReadBuffer = make([]byte, pb.packetSize)
}

if _, err = io.ReadFull(pb.r, pb.packetReadBuffer); err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
err = ErrNoMorePackets
} else {
err = fmt.Errorf("astits: reading %d bytes failed: %w", pb.packetSize, err)
// Loop to make sure we return a packet even if first packets are skipped
for p == nil {
if _, err = io.ReadFull(pb.r, pb.packetReadBuffer); err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
err = ErrNoMorePackets
} else {
err = fmt.Errorf("astits: reading %d bytes failed: %w", pb.packetSize, err)
}
return
}
return
}

// Parse packet
if p, err = parsePacket(astikit.NewBytesIterator(pb.packetReadBuffer)); err != nil {
err = fmt.Errorf("astits: building packet failed: %w", err)
return
// Parse packet
if p, err = parsePacket(astikit.NewBytesIterator(pb.packetReadBuffer), pb.s); err != nil {
if !errors.Is(err, errSkippedPacket) {
err = fmt.Errorf("astits: building packet failed: %w", err)
return
}
}
}

return
}
12 changes: 8 additions & 4 deletions packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,18 @@ func TestParsePacket(t *testing.T) {
buf := &bytes.Buffer{}
w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf})
w.Write(uint16(1)) // Invalid sync byte
_, err := parsePacket(astikit.NewBytesIterator(buf.Bytes()))
_, err := parsePacket(astikit.NewBytesIterator(buf.Bytes()), nil)
assert.EqualError(t, err, ErrPacketMustStartWithASyncByte.Error())

// Valid
b, ep := packet(packetHeader, *packetAdaptationField, []byte("payload"), true)
p, err := parsePacket(astikit.NewBytesIterator(b))
p, err := parsePacket(astikit.NewBytesIterator(b), nil)
assert.NoError(t, err)
assert.Equal(t, p, ep)

// Skip
_, err = parsePacket(astikit.NewBytesIterator(b), func(p *Packet) bool { return true })
assert.EqualError(t, err, errSkippedPacket.Error())
}

func TestPayloadOffset(t *testing.T) {
Expand Down Expand Up @@ -89,7 +93,7 @@ func TestWritePacket_HeaderOnly(t *testing.T) {
// we can't just compare bytes returned by packetShort since they're not completely correct,
// so we just cross-check writePacket with parsePacket
i := astikit.NewBytesIterator(buf.Bytes())
p, err := parsePacket(i)
p, err := parsePacket(i, nil)
assert.NoError(t, err)
assert.Equal(t, ep, p)
}
Expand Down Expand Up @@ -258,6 +262,6 @@ func BenchmarkParsePacket(b *testing.B) {

for i := 0; i < b.N; i++ {
b.ReportAllocs()
parsePacket(astikit.NewBytesIterator(bs))
parsePacket(astikit.NewBytesIterator(bs), nil)
}
}

0 comments on commit f2825ee

Please sign in to comment.