From 5caf397df9bf9fc81e49146d7032f111d9bd5584 Mon Sep 17 00:00:00 2001 From: Ilya Barbashov Date: Wed, 24 Mar 2021 17:44:06 +0300 Subject: [PATCH] Add muxer (#16) * starting * pcr write * go-astikit replacement until PR is merged * packet header, adaptation field, dts write * dependency cleanup * get rid of panic/recover for writing * make writePCR code more adequate * writePacket and adaptation field length calculation * writePacket and adaptation field length calculation * s/188/MpegTsPacketSize/ * writePATSection * Try* -> BitsWriterBatch * descriptors WIP * go-astikit version bump * descriptors WIP * descriptor parsing tests refactored to allow reuse for descriptor writing * descriptors writing tested * descriptors writing refactoring * descriptors done, PMT WIP * write PMT section fix * writePSIData works * WIP * PES WIP * PES functions pass tests * minor fix * muxer: pat & pmt * muxer: more tests * muxer: payload writing * es-split WIP * es-split seems to work * es-split PCR sync; some style fixes * Data.Kind cleanup * comment update * cleanup * cleanup * go-astikit dep replace removed * comment fix * minor fix * flush on pid change removed as it seems to be unnecessary * added some streamtype info funcs * StreamType and PSITableTypeID are special types now * comment cleanup * use PSITableTypeId more instead of comparing strings * comment cleanup * PSITableTypeId -> PSITableTypeID * PESIsVideoStreamId -> PESIsVideoStreamID * PSITableTypeID.String() -> PSITableTypeID.Type() * review fixes: first pack: constants and stuff * packet_buffer read() renamed to peek() * tools are moved to cmd directory * correct prefixes for muxer and demuxer opts * SetPCRPID instead of isPCRPID of AddElementaryStream * test fix and some comments * muxer WritePacket export * `astits.New` renamed to `NewDemuxer` * astikit version bump; writeBytesN removed in favor of one in BitsWriter * WritePayload -> WriteData with MuxerData Co-authored-by: Ilya Barbashov --- cmd/astits-es-split/main.go | 205 ++++++ {astits => cmd/astits-probe}/main.go | 14 +- crc32.go | 25 + data.go | 23 +- data_eit.go | 1 + data_nit.go | 1 + data_pat.go | 20 + data_pat_test.go | 10 + data_pes.go | 339 ++++++++- data_pes_test.go | 534 ++++++++++---- data_pmt.go | 182 ++++- data_pmt_test.go | 9 + data_psi.go | 359 +++++++--- data_psi_test.go | 134 +++- data_sdt.go | 1 + data_test.go | 6 +- data_tot.go | 1 + demuxer.go | 25 +- demuxer_test.go | 26 +- descriptor.go | 725 ++++++++++++++++++- descriptor_test.go | 998 +++++++++++++++++---------- dvb.go | 56 ++ dvb_test.go | 28 + go.mod | 2 +- go.sum | 4 +- muxer.go | 422 +++++++++++ muxer_test.go | 321 +++++++++ packet.go | 243 ++++++- packet_buffer.go | 45 +- packet_buffer_test.go | 2 +- packet_test.go | 117 +++- program_map.go | 25 + program_map_test.go | 2 + wrapping_counter.go | 22 + 34 files changed, 4223 insertions(+), 704 deletions(-) create mode 100644 cmd/astits-es-split/main.go rename {astits => cmd/astits-probe}/main.go (97%) create mode 100644 crc32.go create mode 100644 muxer.go create mode 100644 muxer_test.go create mode 100644 wrapping_counter.go diff --git a/cmd/astits-es-split/main.go b/cmd/astits-es-split/main.go new file mode 100644 index 0000000..420e128 --- /dev/null +++ b/cmd/astits-es-split/main.go @@ -0,0 +1,205 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "github.com/asticode/go-astikit" + "github.com/asticode/go-astits" + "log" + "os" + "path" + "time" +) + +const ( + ioBufSize = 10 * 1024 * 1024 +) + +type muxerOut struct { + f *os.File + w *bufio.Writer +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "Split TS file into multiple files each holding one elementary stream") + fmt.Fprintf(flag.CommandLine.Output(), "%s [FLAGS] INPUT_FILE:\n", os.Args[0]) + flag.PrintDefaults() + } + outDir := flag.String("o", "out", "Output dir, 'out' by default") + inputFile := astikit.FlagCmd() + flag.Parse() + + infile, err := os.Open(inputFile) + if err != nil { + log.Fatalf("%v", err) + } + defer infile.Close() + + _, err = os.Stat(*outDir) + if !os.IsNotExist(err) { + log.Fatalf("can't write to `%s': already exists", *outDir) + } + + if err = os.MkdirAll(*outDir, os.ModePerm); err != nil { + log.Fatalf("%v", err) + } + + demux := astits.NewDemuxer( + context.Background(), + bufio.NewReaderSize(infile, ioBufSize), + ) + + var pat *astits.PATData + // key is program number + pmts := map[uint16]*astits.PMTData{} + gotAllPMTs := false + // key is pid + muxers := map[uint16]*astits.Muxer{} + outfiles := map[uint16]muxerOut{} + + pmtsPrinted := false + + timeStarted := time.Now() + bytesWritten := 0 + + for { + d, err := demux.NextData() + if err != nil { + if err == astits.ErrNoMorePackets { + break + } + log.Fatalf("%v", err) + } + + if d.PAT != nil { + pat = d.PAT + gotAllPMTs = false + continue + } + + if d.PMT != nil { + pmts[d.PMT.ProgramNumber] = d.PMT + + gotAllPMTs = true + for _, p := range pat.Programs { + _, ok := pmts[p.ProgramNumber] + if !ok { + gotAllPMTs = false + break + } + } + + if !gotAllPMTs { + continue + } + + if !pmtsPrinted { + log.Printf("Got all PMTs") + } + for _, pmt := range pmts { + if !pmtsPrinted { + log.Printf("\tProgram %d PCR PID %d", pmt.ProgramNumber, pmt.PCRPID) + } + for _, es := range pmt.ElementaryStreams { + _, ok := muxers[es.ElementaryPID] + if ok { + continue + } + + esFilename := path.Join(*outDir, fmt.Sprintf("%d.ts", es.ElementaryPID)) + outfile, err := os.Create(esFilename) + if err != nil { + log.Fatalf("%v", err) + } + + bufWriter := bufio.NewWriterSize(outfile, ioBufSize) + mux := astits.NewMuxer(context.Background(), bufWriter) + err = mux.AddElementaryStream(*es) + if err != nil { + log.Fatalf("%v", err) + } + mux.SetPCRPID(es.ElementaryPID) + + outfiles[es.ElementaryPID] = muxerOut{ + f: outfile, + w: bufWriter, + } + muxers[es.ElementaryPID] = mux + + if !pmtsPrinted { + log.Printf("\t\tES PID %d type %s", + es.ElementaryPID, es.StreamType.String(), + ) + } + } + } + + pmtsPrinted = true + continue + } + + if !gotAllPMTs { + continue + } + + if d.PES == nil { + continue + } + + pid := d.FirstPacket.Header.PID + mux, ok := muxers[pid] + if !ok { + log.Printf("Got payload for unknown PID %d", pid) + continue + } + + af := d.FirstPacket.AdaptationField + + if af != nil && af.HasPCR { + af.HasPCR = false + } + + var pcr *astits.ClockReference + if d.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorBothPresent { + pcr = d.PES.Header.OptionalHeader.DTS + } else if d.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorOnlyPTS { + pcr = d.PES.Header.OptionalHeader.PTS + } + + if pcr != nil { + if af == nil { + af = &astits.PacketAdaptationField{} + } + af.HasPCR = true + af.PCR = pcr + } + + n, err := mux.WriteData(&astits.MuxerData{ + PID: pid, + AdaptationField: af, + PES: d.PES, + }) + if err != nil { + log.Fatalf("%v", err) + } + + bytesWritten += n + } + + timeDiff := time.Since(timeStarted) + log.Printf("%d bytes written at rate %.02f mb/s", bytesWritten, (float64(bytesWritten)/1024.0/1024.0)/timeDiff.Seconds()) + + for _, f := range outfiles { + if err = f.w.Flush(); err != nil { + log.Printf("Error flushing %s: %v", f.f.Name(), err) + } + if err = f.f.Close(); err != nil { + log.Printf("Error closing %s: %v", f.f.Name(), err) + } + } + + log.Printf("Done") +} diff --git a/astits/main.go b/cmd/astits-probe/main.go similarity index 97% rename from astits/main.go rename to cmd/astits-probe/main.go index 07cd7eb..8591e24 100644 --- a/astits/main.go +++ b/cmd/astits-probe/main.go @@ -63,7 +63,7 @@ func main() { } // Create the demuxer - var dmx = astits.New(ctx, r) + var dmx = astits.NewDemuxer(ctx, r) // Switch on command switch cmd { @@ -219,7 +219,7 @@ func data(dmx *astits.Demuxer) (err error) { } // Loop through data - var d *astits.Data + var d *astits.DemuxerData log.Println("Fetching data...") for { // Get next data @@ -272,7 +272,7 @@ func data(dmx *astits.Demuxer) (err error) { func programs(dmx *astits.Demuxer) (o []*Program, err error) { // Loop through data - var d *astits.Data + var d *astits.DemuxerData var pgmsToProcess = make(map[uint16]bool) var pgms = make(map[uint16]*Program) log.Println("Fetching data...") @@ -347,9 +347,9 @@ type Program struct { // Stream represents a stream type Stream struct { - Descriptors []string `json:"descriptors,omitempty"` - ID uint16 `json:"id,omitempty"` - Type uint8 `json:"type,omitempty"` + Descriptors []string `json:"descriptors,omitempty"` + ID uint16 `json:"id,omitempty"` + Type astits.StreamType `json:"type,omitempty"` } func newProgram(id, mapID uint16) *Program { @@ -359,7 +359,7 @@ func newProgram(id, mapID uint16) *Program { } } -func newStream(id uint16, _type uint8) *Stream { +func newStream(id uint16, _type astits.StreamType) *Stream { return &Stream{ ID: id, Type: _type, diff --git a/crc32.go b/crc32.go new file mode 100644 index 0000000..5a3f601 --- /dev/null +++ b/crc32.go @@ -0,0 +1,25 @@ +package astits + +const ( + crc32Polynomial = uint32(0xffffffff) +) + +// computeCRC32 computes a CRC32 +// https://stackoverflow.com/questions/35034042/how-to-calculate-crc32-in-psi-si-packet +func computeCRC32(bs []byte) uint32 { + return updateCRC32(crc32Polynomial, bs) +} + +func updateCRC32(crc32 uint32, bs []byte) uint32 { + for _, b := range bs { + for i := 0; i < 8; i++ { + if (crc32 >= uint32(0x80000000)) != (b >= uint8(0x80)) { + crc32 = (crc32 << 1) ^ 0x04C11DB7 + } else { + crc32 = crc32 << 1 + } + b <<= 1 + } + } + return crc32 +} diff --git a/data.go b/data.go index 0cadf9a..f06bdab 100644 --- a/data.go +++ b/data.go @@ -8,14 +8,14 @@ import ( // PIDs const ( - PIDPAT = 0x0 // Program Association Table (PAT) contains a directory listing of all Program Map Tables. - PIDCAT = 0x1 // Conditional Access Table (CAT) contains a directory listing of all ITU-T Rec. H.222 entitlement management message streams used by Program Map Tables. - PIDTSDT = 0x2 // Transport Stream Description Table (TSDT) contains descriptors related to the overall transport stream - PIDNull = 0x1fff // Null Packet (used for fixed bandwidth padding) + PIDPAT uint16 = 0x0 // Program Association Table (PAT) contains a directory listing of all Program Map Tables. + PIDCAT uint16 = 0x1 // Conditional Access Table (CAT) contains a directory listing of all ITU-T Rec. H.222 entitlement management message streams used by Program Map Tables. + PIDTSDT uint16 = 0x2 // Transport Stream Description Table (TSDT) contains descriptors related to the overall transport stream + PIDNull uint16 = 0x1fff // Null Packet (used for fixed bandwidth padding) ) -// Data represents a data -type Data struct { +// DemuxerData represents a data parsed by Demuxer +type DemuxerData struct { EIT *EITData FirstPacket *Packet NIT *NITData @@ -27,8 +27,15 @@ type Data struct { TOT *TOTData } +// MuxerData represents a data to be written by Muxer +type MuxerData struct { + PID uint16 + AdaptationField *PacketAdaptationField + PES *PESData +} + // parseData parses a payload spanning over multiple packets and returns a set of data -func parseData(ps []*Packet, prs PacketsParser, pm programMap) (ds []*Data, err error) { +func parseData(ps []*Packet, prs PacketsParser, pm programMap) (ds []*DemuxerData, err error) { // Use custom parser first if prs != nil { var skip bool @@ -82,7 +89,7 @@ func parseData(ps []*Packet, prs PacketsParser, pm programMap) (ds []*Data, err } // Append data - ds = append(ds, &Data{ + ds = append(ds, &DemuxerData{ FirstPacket: ps[0], PES: pesData, PID: pid, diff --git a/data_eit.go b/data_eit.go index a8d4135..a81b9e0 100644 --- a/data_eit.go +++ b/data_eit.go @@ -9,6 +9,7 @@ import ( // EITData represents an EIT data // Page: 36 | Chapter: 5.2.4 | Link: https://www.dvb.org/resources/public/standards/a38_dvb-si_specification.pdf +// (barbashov) the link above can be broken, alternative: https://dvb.org/wp-content/uploads/2019/12/a038_tm1217r37_en300468v1_17_1_-_rev-134_-_si_specification.pdf type EITData struct { Events []*EITDataEvent LastTableID uint8 diff --git a/data_nit.go b/data_nit.go index 33d571f..7fdcb0f 100644 --- a/data_nit.go +++ b/data_nit.go @@ -8,6 +8,7 @@ import ( // NITData represents a NIT data // Page: 29 | Chapter: 5.2.1 | Link: https://www.dvb.org/resources/public/standards/a38_dvb-si_specification.pdf +// (barbashov) the link above can be broken, alternative: https://dvb.org/wp-content/uploads/2019/12/a038_tm1217r37_en300468v1_17_1_-_rev-134_-_si_specification.pdf type NITData struct { NetworkDescriptors []*Descriptor NetworkID uint16 diff --git a/data_pat.go b/data_pat.go index 1c32836..69a50f1 100644 --- a/data_pat.go +++ b/data_pat.go @@ -6,6 +6,10 @@ import ( "github.com/asticode/go-astikit" ) +const ( + patSectionEntryBytesSize = 4 // 16 bits + 3 reserved + 13 bits = 32 bits +) + // PATData represents a PAT data // https://en.wikipedia.org/wiki/Program-specific_information type PATData struct { @@ -41,3 +45,19 @@ func parsePATSection(i *astikit.BytesIterator, offsetSectionsEnd int, tableIDExt } return } + +func calcPATSectionLength(d *PATData) uint16 { + return uint16(4 * len(d.Programs)) +} + +func writePATSection(w *astikit.BitsWriter, d *PATData) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + for _, p := range d.Programs { + b.Write(p.ProgramNumber) + b.WriteN(uint8(0xff), 3) + b.WriteN(p.ProgramMapID, 13) + } + + return len(d.Programs) * patSectionEntryBytesSize, b.Err() +} diff --git a/data_pat_test.go b/data_pat_test.go index 21a5db4..db89cb2 100644 --- a/data_pat_test.go +++ b/data_pat_test.go @@ -34,3 +34,13 @@ func TestParsePATSection(t *testing.T) { assert.Equal(t, d, pat) assert.NoError(t, err) } + +func TestWritePatSection(t *testing.T) { + bw := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: bw}) + n, err := writePATSection(w, pat) + assert.NoError(t, err) + assert.Equal(t, n, 8) + assert.Equal(t, n, bw.Len()) + assert.Equal(t, patBytes(), bw.Bytes()) +} diff --git a/data_pes.go b/data_pes.go index 858d911..f0bdc7d 100644 --- a/data_pes.go +++ b/data_pes.go @@ -36,6 +36,13 @@ const ( TrickModeControlSlowReverse = 4 ) +const ( + pesHeaderLength = 6 + ptsOrDTSByteLength = 5 + escrLength = 6 + dsmTrickModeLength = 1 +) + // PESData represents a PES data // https://en.wikipedia.org/wiki/Packetized_elementary_stream // http://dvd.sourceforge.net/dvdinfo/pes-hdr.html @@ -102,6 +109,11 @@ type DSMTrickMode struct { TrickModeControl uint8 } +func (h *PESHeader) IsVideoStream() bool { + return h.StreamID == 0xe0 || + h.StreamID == 0xfd +} + // parsePESData parses a PES data func parsePESData(i *astikit.BytesIterator) (d *PESData, err error) { // Create data @@ -328,6 +340,7 @@ func parsePESOptionalHeader(i *astikit.BytesIterator) (h *PESOptionalHeader, dat err = fmt.Errorf("astits: fetching next byte failed: %w", err) return } + // TODO it's only a length of pack_header, should read it all. now it's wrong h.PackField = uint8(b) } @@ -357,12 +370,11 @@ func parsePESOptionalHeader(i *astikit.BytesIterator) (h *PESOptionalHeader, dat // Extension 2 if h.HasExtension2 { // Length - var bs []byte - if bs, err = i.NextBytes(2); err != nil { + if b, err = i.NextByte(); err != nil { err = fmt.Errorf("astits: fetching next bytes failed: %w", err) return } - h.Extension2Length = uint8(bs[0]) & 0x7f + h.Extension2Length = uint8(b) & 0x7f // Data if h.Extension2Data, err = i.NextBytes(int(h.Extension2Length)); err != nil { @@ -412,3 +424,324 @@ func parseESCR(i *astikit.BytesIterator) (cr *ClockReference, err error) { cr = newClockReference(int64(escr>>9), int64(escr&0x1ff)) return } + +// will count how many total bytes and payload bytes will be written when writePESData is called with the same arguments +// should be used by the caller of writePESData to determine AF stuffing size needed to be applied +// since the length of video PES packets are often zero, we can't just stuff it with 0xff-s at the end +func calcPESDataLength(h *PESHeader, payloadLeft []byte, isPayloadStart bool, bytesAvailable int) (totalBytes, payloadBytes int) { + totalBytes += pesHeaderLength + if isPayloadStart { + totalBytes += int(calcPESOptionalHeaderLength(h.OptionalHeader)) + } + bytesAvailable -= totalBytes + + if len(payloadLeft) < bytesAvailable { + payloadBytes = len(payloadLeft) + } else { + payloadBytes = bytesAvailable + } + + return +} + +// first packet will contain PES header with optional PES header and payload, if possible +// all consequential packets will contain just payload +// for the last packet caller must add AF with stuffing, see calcPESDataLength +func writePESData(w *astikit.BitsWriter, h *PESHeader, payloadLeft []byte, isPayloadStart bool, bytesAvailable int) (totalBytesWritten, payloadBytesWritten int, err error) { + if isPayloadStart { + var n int + n, err = writePESHeader(w, h, len(payloadLeft)) + if err != nil { + return + } + totalBytesWritten += n + } + + payloadBytesWritten = bytesAvailable - totalBytesWritten + if payloadBytesWritten > len(payloadLeft) { + payloadBytesWritten = len(payloadLeft) + } + + err = w.Write(payloadLeft[:payloadBytesWritten]) + if err != nil { + return + } + + totalBytesWritten += payloadBytesWritten + return +} + +func writePESHeader(w *astikit.BitsWriter, h *PESHeader, payloadSize int) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(uint32(0x000001), 24) // packet_start_code_prefix + b.Write(h.StreamID) + + pesPacketLength := 0 + + if !h.IsVideoStream() { + pesPacketLength = payloadSize + if hasPESOptionalHeader(h.StreamID) { + pesPacketLength += int(calcPESOptionalHeaderLength(h.OptionalHeader)) + } + if pesPacketLength > 0xffff { + pesPacketLength = 0 + } + } + + b.Write(uint16(pesPacketLength)) + + bytesWritten := pesHeaderLength + + if hasPESOptionalHeader(h.StreamID) { + n, err := writePESOptionalHeader(w, h.OptionalHeader) + if err != nil { + return 0, err + } + bytesWritten += n + } + + return bytesWritten, b.Err() +} + +func calcPESOptionalHeaderLength(h *PESOptionalHeader) uint8 { + if h == nil { + return 0 + } + return 3 + calcPESOptionalHeaderDataLength(h) +} + +func calcPESOptionalHeaderDataLength(h *PESOptionalHeader) (length uint8) { + if h.PTSDTSIndicator == PTSDTSIndicatorOnlyPTS { + length += ptsOrDTSByteLength + } else if h.PTSDTSIndicator == PTSDTSIndicatorBothPresent { + length += 2 * ptsOrDTSByteLength + } + + if h.HasESCR { + length += escrLength + } + + if h.HasESRate { + length += 3 + } + + if h.HasDSMTrickMode { + length += dsmTrickModeLength + } + + if h.HasAdditionalCopyInfo { + length++ + } + + if h.HasCRC { + //length += 4 // TODO + } + + if h.HasExtension { + length++ + + if h.HasPrivateData { + length += 16 + } + + if h.HasPackHeaderField { + // TODO + } + + if h.HasProgramPacketSequenceCounter { + length += 2 + } + + if h.HasPSTDBuffer { + length += 2 + } + + if h.HasExtension2 { + length += 1 + uint8(len(h.Extension2Data)) + } + } + + return +} + +func writePESOptionalHeader(w *astikit.BitsWriter, h *PESOptionalHeader) (int, error) { + if h == nil { + return 0, nil + } + + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(uint8(0b10), 2) // marker bits + b.WriteN(h.ScramblingControl, 2) + b.Write(h.Priority) + b.Write(h.DataAlignmentIndicator) + b.Write(h.IsCopyrighted) + b.Write(h.IsOriginal) + + b.WriteN(h.PTSDTSIndicator, 2) + b.Write(h.HasESCR) + b.Write(h.HasESRate) + b.Write(h.HasDSMTrickMode) + b.Write(h.HasAdditionalCopyInfo) + b.Write(false) // CRC of previous PES packet. not supported yet + //b.Write(h.HasCRC) + b.Write(h.HasExtension) + + pesOptionalHeaderDataLength := calcPESOptionalHeaderDataLength(h) + b.Write(pesOptionalHeaderDataLength) + + bytesWritten := 3 + + if h.PTSDTSIndicator == PTSDTSIndicatorOnlyPTS { + n, err := writePTSOrDTS(w, 0b0010, h.PTS) + if err != nil { + return 0, err + } + bytesWritten += n + } + + if h.PTSDTSIndicator == PTSDTSIndicatorBothPresent { + n, err := writePTSOrDTS(w, 0b0011, h.PTS) + if err != nil { + return 0, err + } + bytesWritten += n + + n, err = writePTSOrDTS(w, 0b0001, h.DTS) + if err != nil { + return 0, err + } + bytesWritten += n + } + + if h.HasESCR { + n, err := writeESCR(w, h.ESCR) + if err != nil { + return 0, err + } + bytesWritten += n + } + + if h.HasESRate { + b.Write(true) + b.WriteN(h.ESRate, 22) + b.Write(true) + bytesWritten += 3 + } + + if h.HasDSMTrickMode { + n, err := writeDSMTrickMode(w, h.DSMTrickMode) + if err != nil { + return 0, err + } + bytesWritten += n + } + + if h.HasAdditionalCopyInfo { + b.Write(true) // marker_bit + b.WriteN(h.AdditionalCopyInfo, 7) + bytesWritten++ + } + + if h.HasCRC { + // TODO, not supported + } + + if h.HasExtension { + // exp 10110001 + // act 10111111 + b.Write(h.HasPrivateData) + b.Write(false) // TODO pack_header_field_flag, not implemented + //b.Write(h.HasPackHeaderField) + b.Write(h.HasProgramPacketSequenceCounter) + b.Write(h.HasPSTDBuffer) + b.WriteN(uint8(0xff), 3) // reserved + b.Write(h.HasExtension2) + bytesWritten++ + + if h.HasPrivateData { + b.WriteBytesN(h.PrivateData, 16, 0) + bytesWritten += 16 + } + + if h.HasPackHeaderField { + // TODO (see parsePESOptionalHeader) + } + + if h.HasProgramPacketSequenceCounter { + b.Write(true) // marker_bit + b.WriteN(h.PacketSequenceCounter, 7) + b.Write(true) // marker_bit + b.WriteN(h.MPEG1OrMPEG2ID, 1) + b.WriteN(h.OriginalStuffingLength, 6) + bytesWritten += 2 + } + + if h.HasPSTDBuffer { + b.WriteN(uint8(0b01), 2) + b.WriteN(h.PSTDBufferScale, 1) + b.WriteN(h.PSTDBufferSize, 13) + bytesWritten += 2 + } + + if h.HasExtension2 { + b.Write(true) // marker_bit + b.WriteN(uint8(len(h.Extension2Data)), 7) + b.Write(h.Extension2Data) + bytesWritten += 1 + len(h.Extension2Data) + } + } + + return bytesWritten, b.Err() +} + +func writeDSMTrickMode(w *astikit.BitsWriter, m *DSMTrickMode) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(m.TrickModeControl, 3) + if m.TrickModeControl == TrickModeControlFastForward || m.TrickModeControl == TrickModeControlFastReverse { + b.WriteN(m.FieldID, 2) + b.Write(m.IntraSliceRefresh == 1) // it should be boolean + b.WriteN(m.FrequencyTruncation, 2) + } else if m.TrickModeControl == TrickModeControlFreezeFrame { + b.WriteN(m.FieldID, 2) + b.WriteN(uint8(0xff), 3) // reserved + } else if m.TrickModeControl == TrickModeControlSlowMotion || m.TrickModeControl == TrickModeControlSlowReverse { + b.WriteN(m.RepeatControl, 5) + } else { + b.WriteN(uint8(0xff), 5) // reserved + } + + return dsmTrickModeLength, b.Err() +} + +func writeESCR(w *astikit.BitsWriter, cr *ClockReference) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(uint8(0xff), 2) + b.WriteN(uint64(cr.Base>>30), 3) + b.Write(true) + b.WriteN(uint64(cr.Base>>15), 15) + b.Write(true) + b.WriteN(uint64(cr.Base), 15) + b.Write(true) + b.WriteN(uint64(cr.Extension), 9) + b.Write(true) + + return escrLength, b.Err() +} + +func writePTSOrDTS(w *astikit.BitsWriter, flag uint8, cr *ClockReference) (bytesWritten int, retErr error) { + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(flag, 4) + b.WriteN(uint64(cr.Base>>30), 3) + b.Write(true) + b.WriteN(uint64(cr.Base>>15), 15) + b.Write(true) + b.WriteN(uint64(cr.Base), 15) + b.Write(true) + + return ptsOrDTSByteLength, b.Err() +} diff --git a/data_pes_test.go b/data_pes_test.go index f241904..4004438 100644 --- a/data_pes_test.go +++ b/data_pes_test.go @@ -31,75 +31,169 @@ func dsmTrickModeSlowBytes() []byte { return buf.Bytes() } +type dsmTrickModeTestCase struct { + name string + bytesFunc func(w *astikit.BitsWriter) + trickMode *DSMTrickMode +} + +var dsmTrickModeTestCases = []dsmTrickModeTestCase{ + { + "fast_forward", + func(w *astikit.BitsWriter) { + w.Write("000") // Control + w.Write("10") // Field ID + w.Write("1") // Intra slice refresh + w.Write("11") // Frequency truncation + }, + &DSMTrickMode{ + FieldID: 2, + FrequencyTruncation: 3, + IntraSliceRefresh: 1, + TrickModeControl: TrickModeControlFastForward, + }, + }, + { + "slow_motion", + func(w *astikit.BitsWriter) { + w.Write("001") + w.Write("10101") + }, + &DSMTrickMode{ + RepeatControl: 0b10101, + TrickModeControl: TrickModeControlSlowMotion, + }, + }, + { + "freeze_frame", + func(w *astikit.BitsWriter) { + w.Write("010") // Control + w.Write("10") // Field ID + w.Write("111") // Reserved + }, + &DSMTrickMode{ + FieldID: 2, + TrickModeControl: TrickModeControlFreezeFrame, + }, + }, + { + "fast_reverse", + func(w *astikit.BitsWriter) { + w.Write("011") // Control + w.Write("10") // Field ID + w.Write("1") // Intra slice refresh + w.Write("11") // Frequency truncation + }, + &DSMTrickMode{ + FieldID: 2, + FrequencyTruncation: 3, + IntraSliceRefresh: 1, + TrickModeControl: TrickModeControlFastReverse, + }, + }, + { + "slow_reverse", + func(w *astikit.BitsWriter) { + w.Write("100") + w.Write("01010") + }, + &DSMTrickMode{ + RepeatControl: 0b01010, + TrickModeControl: TrickModeControlSlowReverse, + }, + }, + { + "reserved", + func(w *astikit.BitsWriter) { + w.Write("101") + w.Write("11111") + }, + &DSMTrickMode{ + TrickModeControl: 5, // reserved + }, + }, +} + func TestParseDSMTrickMode(t *testing.T) { - // Fast - buf := &bytes.Buffer{} - w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write("011") // Control - w.Write("10") // Field ID - w.Write("1") // Intra slice refresh - w.Write("11") // Frequency truncation - assert.Equal(t, parseDSMTrickMode(buf.Bytes()[0]), &DSMTrickMode{ - FieldID: 2, - FrequencyTruncation: 3, - IntraSliceRefresh: 1, - TrickModeControl: TrickModeControlFastReverse, - }) - - // Freeze - buf.Reset() - w.Write("010") // Control - w.Write("10") // Field ID - w.Write("000") // Reserved - assert.Equal(t, parseDSMTrickMode(buf.Bytes()[0]), &DSMTrickMode{ - FieldID: 2, - TrickModeControl: TrickModeControlFreezeFrame, - }) - - // Slow - assert.Equal(t, parseDSMTrickMode(dsmTrickModeSlowBytes()[0]), dsmTrickModeSlow) + for _, tc := range dsmTrickModeTestCases { + t.Run(tc.name, func(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + tc.bytesFunc(w) + assert.Equal(t, parseDSMTrickMode(buf.Bytes()[0]), tc.trickMode) + }) + } +} + +func TestWriteDSMTrickMode(t *testing.T) { + for _, tc := range dsmTrickModeTestCases { + t.Run(tc.name, func(t *testing.T) { + bufExpected := &bytes.Buffer{} + wExpected := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: bufExpected}) + tc.bytesFunc(wExpected) + + bufActual := &bytes.Buffer{} + wActual := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: bufActual}) + + n, err := writeDSMTrickMode(wActual, tc.trickMode) + assert.NoError(t, err) + assert.Equal(t, 1, n) + assert.Equal(t, n, bufActual.Len()) + assert.Equal(t, bufExpected.Bytes(), bufActual.Bytes()) + }) + } } var ptsClockReference = &ClockReference{Base: 5726623061} -func ptsBytes() []byte { +func ptsBytes(flag string) []byte { buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write("0010") // Flag + w.Write(flag) // Flag w.Write("101") // 32...30 - w.Write("0") // Dummy + w.Write("1") // Dummy w.Write("010101010101010") // 29...15 - w.Write("0") // Dummy + w.Write("1") // Dummy w.Write("101010101010101") // 14...0 - w.Write("0") // Dummy + w.Write("1") // Dummy return buf.Bytes() } var dtsClockReference = &ClockReference{Base: 5726623060} -func dtsBytes() []byte { +func dtsBytes(flag string) []byte { buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write("0010") // Flag + w.Write(flag) // Flag w.Write("101") // 32...30 - w.Write("0") // Dummy + w.Write("1") // Dummy w.Write("010101010101010") // 29...15 - w.Write("0") // Dummy + w.Write("1") // Dummy w.Write("101010101010100") // 14...0 - w.Write("0") // Dummy + w.Write("1") // Dummy return buf.Bytes() } func TestParsePTSOrDTS(t *testing.T) { - v, err := parsePTSOrDTS(astikit.NewBytesIterator(ptsBytes())) + v, err := parsePTSOrDTS(astikit.NewBytesIterator(ptsBytes("0010"))) assert.Equal(t, v, ptsClockReference) assert.NoError(t, err) } +func TestWritePTSOrDTS(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + n, err := writePTSOrDTS(w, uint8(0b0010), dtsClockReference) + assert.NoError(t, err) + assert.Equal(t, n, 5) + assert.Equal(t, n, buf.Len()) + assert.Equal(t, dtsBytes("0010"), buf.Bytes()) +} + func escrBytes() []byte { buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write("00") // Dummy + w.Write("11") // Dummy w.Write("011") // 32...30 w.Write("1") // Dummy w.Write("000010111110000") // 29...15 @@ -117,121 +211,277 @@ func TestParseESCR(t *testing.T) { assert.NoError(t, err) } -var pesWithoutHeader = &PESData{ - Data: []byte("data"), - Header: &PESHeader{ - PacketLength: 4, - StreamID: StreamIDPaddingStream, - }, -} - -func pesWithoutHeaderBytes() []byte { +func TestWriteESCR(t *testing.T) { buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write("000000000000000000000001") // Prefix - w.Write(uint8(StreamIDPaddingStream)) // Stream ID - w.Write(uint16(4)) // Packet length - w.Write([]byte("datastuff")) // Data - return buf.Bytes() + n, err := writeESCR(w, clockReference) + assert.NoError(t, err) + assert.Equal(t, n, 6) + assert.Equal(t, n, buf.Len()) + assert.Equal(t, escrBytes(), buf.Bytes()) +} + +type pesTestCase struct { + name string + headerBytesFunc func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) + optionalHeaderBytesFunc func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) + bytesFunc func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) + pesData *PESData } -var pesWithHeader = &PESData{ - Data: []byte("data"), - Header: &PESHeader{ - OptionalHeader: &PESOptionalHeader{ - AdditionalCopyInfo: 127, - CRC: 4, - DataAlignmentIndicator: true, - DSMTrickMode: dsmTrickModeSlow, - DTS: dtsClockReference, - ESCR: clockReference, - ESRate: 1398101, - Extension2Data: []byte("extension2"), - Extension2Length: 10, - HasAdditionalCopyInfo: true, - HasCRC: true, - HasDSMTrickMode: true, - HasESCR: true, - HasESRate: true, - HasExtension: true, - HasExtension2: true, - HasPackHeaderField: true, - HasPrivateData: true, - HasProgramPacketSequenceCounter: true, - HasPSTDBuffer: true, - HeaderLength: 62, - IsCopyrighted: true, - IsOriginal: true, - MarkerBits: 2, - MPEG1OrMPEG2ID: 1, - OriginalStuffingLength: 21, - PacketSequenceCounter: 85, - PackField: 5, - Priority: true, - PrivateData: []byte("1234567890123456"), - PSTDBufferScale: 1, - PSTDBufferSize: 5461, - PTSDTSIndicator: 3, - PTS: ptsClockReference, - ScramblingControl: 1, - }, - PacketLength: 69, - StreamID: 1, +var pesTestCases = []pesTestCase{ + { + "without_header", + func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) { + w.Write("000000000000000000000001") // Prefix + w.Write(uint8(StreamIDPaddingStream)) // Stream ID + w.Write(uint16(4)) // Packet length + }, + func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) { + // do nothing here + }, + func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) { + w.Write([]byte("data")) // Data + if withStuffing { + w.Write([]byte("stuff")) // Stuffing + } + }, + &PESData{ + Data: []byte("data"), + Header: &PESHeader{ + PacketLength: 4, + StreamID: StreamIDPaddingStream, + }, + }, + }, + { + "with_header", + func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) { + packetLength := 67 + stuffing := []byte("stuff") + + if !withStuffing { + packetLength -= len(stuffing) + } + + if !withCRC { + packetLength -= 2 + } + + w.Write("000000000000000000000001") // Prefix + w.Write(uint8(1)) // Stream ID + w.Write(uint16(packetLength)) // Packet length + + }, + func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) { + optionalHeaderLength := 60 + stuffing := []byte("stuff") + + if !withStuffing { + optionalHeaderLength -= len(stuffing) + } + + if !withCRC { + optionalHeaderLength -= 2 + } + + w.Write("10") // Marker bits + w.Write("01") // Scrambling control + w.Write("1") // Priority + w.Write("1") // Data alignment indicator + w.Write("1") // Copyright + w.Write("1") // Original or copy + w.Write("11") // PTS/DTS indicator + w.Write("1") // ESCR flag + w.Write("1") // ES rate flag + w.Write("1") // DSM trick mode flag + w.Write("1") // Additional copy flag + w.Write(withCRC) // CRC flag + w.Write("1") // Extension flag + w.Write(uint8(optionalHeaderLength)) // Header length + w.Write(ptsBytes("0011")) // PTS + w.Write(dtsBytes("0001")) // DTS + w.Write(escrBytes()) // ESCR + w.Write("101010101010101010101011") // ES rate + w.Write(dsmTrickModeSlowBytes()) // DSM trick mode + w.Write("11111111") // Additional copy info + if withCRC { + w.Write(uint16(4)) // CRC + } + // Extension starts here + w.Write("1") // Private data flag + w.Write("0") // Pack header field flag + w.Write("1") // Program packet sequence counter flag + w.Write("1") // PSTD buffer flag + w.Write("111") // Dummy + w.Write("1") // Extension 2 flag + w.Write([]byte("1234567890123456")) // Private data + //w.Write(uint8(5)) // Pack field + w.Write("1101010111010101") // Packet sequence counter + w.Write("0111010101010101") // PSTD buffer + w.Write("10001010") // Extension 2 header + w.Write([]byte("extension2")) // Extension 2 data + if withStuffing { + w.Write(stuffing) // Optional header stuffing bytes + } + }, + func(w *astikit.BitsWriter, withStuffing bool, withCRC bool) { + stuffing := []byte("stuff") + w.Write([]byte("data")) // Data + if withStuffing { + w.Write(stuffing) // Stuffing + } + }, + &PESData{ + Data: []byte("data"), + Header: &PESHeader{ + OptionalHeader: &PESOptionalHeader{ + AdditionalCopyInfo: 127, + CRC: 4, + DataAlignmentIndicator: true, + DSMTrickMode: dsmTrickModeSlow, + DTS: dtsClockReference, + ESCR: clockReference, + ESRate: 1398101, + Extension2Data: []byte("extension2"), + Extension2Length: 10, + HasAdditionalCopyInfo: true, + HasCRC: true, + HasDSMTrickMode: true, + HasESCR: true, + HasESRate: true, + HasExtension: true, + HasExtension2: true, + HasPackHeaderField: false, + HasPrivateData: true, + HasProgramPacketSequenceCounter: true, + HasPSTDBuffer: true, + HeaderLength: 60, + IsCopyrighted: true, + IsOriginal: true, + MarkerBits: 2, + MPEG1OrMPEG2ID: 1, + OriginalStuffingLength: 21, + PacketSequenceCounter: 85, + //PackField: 5, + Priority: true, + PrivateData: []byte("1234567890123456"), + PSTDBufferScale: 1, + PSTDBufferSize: 5461, + PTSDTSIndicator: 3, + PTS: ptsClockReference, + ScramblingControl: 1, + }, + PacketLength: 67, + StreamID: 1, + }, + }, }, } +// used by TestParseData func pesWithHeaderBytes() []byte { - buf := &bytes.Buffer{} - w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write("000000000000000000000001") // Prefix - w.Write(uint8(1)) // Stream ID - w.Write(uint16(69)) // Packet length - w.Write("10") // Marker bits - w.Write("01") // Scrambling control - w.Write("1") // Priority - w.Write("1") // Data alignment indicator - w.Write("1") // Copyright - w.Write("1") // Original or copy - w.Write("11") // PTS/DTS indicator - w.Write("1") // ESCR flag - w.Write("1") // ES rate flag - w.Write("1") // DSM trick mode flag - w.Write("1") // Additional copy flag - w.Write("1") // CRC flag - w.Write("1") // Extension flag - w.Write(uint8(62)) // Header length - w.Write(ptsBytes()) // PTS - w.Write(dtsBytes()) // DTS - w.Write(escrBytes()) // ESCR - w.Write("101010101010101010101010") // ES rate - w.Write(dsmTrickModeSlowBytes()) // DSM trick mode - w.Write("11111111") // Additional copy info - w.Write(uint16(4)) // CRC - w.Write("1") // Private data flag - w.Write("1") // Pack header field flag - w.Write("1") // Program packet sequence counter flag - w.Write("1") // PSTD buffer flag - w.Write("000") // Dummy - w.Write("1") // Extension 2 flag - w.Write([]byte("1234567890123456")) // Private data - w.Write(uint8(5)) // Pack field - w.Write("0101010101010101") // Packet sequence counter - w.Write("0111010101010101") // PSTD buffer - w.Write("0000101000000000") // Extension 2 header - w.Write([]byte("extension2")) // Extension 2 data - w.Write([]byte("stuff")) // Optional header stuffing bytes - w.Write([]byte("datastuff")) // Data + buf := bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + pesTestCases[1].headerBytesFunc(w, true, true) + pesTestCases[1].optionalHeaderBytesFunc(w, true, true) + pesTestCases[1].bytesFunc(w, true, true) return buf.Bytes() } +// used by TestParseData +func pesWithHeader() *PESData { + return pesTestCases[1].pesData +} + func TestParsePESData(t *testing.T) { - // No optional header and specific packet length - d, err := parsePESData(astikit.NewBytesIterator(pesWithoutHeaderBytes())) - assert.NoError(t, err) - assert.Equal(t, pesWithoutHeader, d) + for _, tc := range pesTestCases { + t.Run(tc.name, func(t *testing.T) { + buf := bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + tc.headerBytesFunc(w, true, true) + tc.optionalHeaderBytesFunc(w, true, true) + tc.bytesFunc(w, true, true) + d, err := parsePESData(astikit.NewBytesIterator(buf.Bytes())) + assert.NoError(t, err) + assert.Equal(t, tc.pesData, d) + }) + } +} - // Optional header and no specific header length - d, err = parsePESData(astikit.NewBytesIterator(pesWithHeaderBytes())) - assert.NoError(t, err) - assert.Equal(t, pesWithHeader, d) +func TestWritePESData(t *testing.T) { + for _, tc := range pesTestCases { + t.Run(tc.name, func(t *testing.T) { + bufExpected := bytes.Buffer{} + wExpected := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufExpected}) + tc.headerBytesFunc(wExpected, false, false) + tc.optionalHeaderBytesFunc(wExpected, false, false) + tc.bytesFunc(wExpected, false, false) + + bufActual := bytes.Buffer{} + wActual := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufActual}) + + start := true + totalBytes := 0 + payloadPos := 0 + + for payloadPos+1 < len(tc.pesData.Data) { + n, payloadN, err := writePESData( + wActual, + tc.pesData.Header, + tc.pesData.Data[payloadPos:], + start, + MpegTsPacketSize-mpegTsPacketHeaderSize, + ) + assert.NoError(t, err) + start = false + + totalBytes += n + payloadPos += payloadN + } + + assert.Equal(t, totalBytes, bufActual.Len()) + assert.Equal(t, bufExpected.Len(), bufActual.Len()) + assert.Equal(t, bufExpected.Bytes(), bufActual.Bytes()) + }) + } +} + +func TestWritePESHeader(t *testing.T) { + for _, tc := range pesTestCases { + t.Run(tc.name, func(t *testing.T) { + bufExpected := bytes.Buffer{} + wExpected := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufExpected}) + tc.headerBytesFunc(wExpected, false, false) + tc.optionalHeaderBytesFunc(wExpected, false, false) + + bufActual := bytes.Buffer{} + wActual := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufActual}) + + n, err := writePESHeader(wActual, tc.pesData.Header, len(tc.pesData.Data)) + assert.NoError(t, err) + assert.Equal(t, n, bufActual.Len()) + assert.Equal(t, bufExpected.Len(), bufActual.Len()) + assert.Equal(t, bufExpected.Bytes(), bufActual.Bytes()) + }) + } +} + +func TestWritePESOptionalHeader(t *testing.T) { + for _, tc := range pesTestCases { + t.Run(tc.name, func(t *testing.T) { + bufExpected := bytes.Buffer{} + wExpected := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufExpected}) + tc.optionalHeaderBytesFunc(wExpected, false, false) + + bufActual := bytes.Buffer{} + wActual := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufActual}) + + n, err := writePESOptionalHeader(wActual, tc.pesData.Header.OptionalHeader) + assert.NoError(t, err) + assert.Equal(t, n, bufActual.Len()) + assert.Equal(t, bufExpected.Len(), bufActual.Len()) + assert.Equal(t, bufExpected.Bytes(), bufActual.Bytes()) + }) + } } diff --git a/data_pmt.go b/data_pmt.go index 3cbbfeb..13e21d2 100644 --- a/data_pmt.go +++ b/data_pmt.go @@ -6,14 +6,33 @@ import ( "github.com/asticode/go-astikit" ) +type StreamType uint8 + // Stream types const ( - StreamTypeMPEG1Audio = 0x03 // ISO/IEC 11172-3 - StreamTypeMPEG2HalvedSampleRateAudio = 0x04 // ISO/IEC 13818-3 - StreamTypeMPEG2PacketizedData = 0x06 // Rec. ITU-T H.222 | ISO/IEC 13818-1 i.e., DVB subtitles/VBI and AC-3 - StreamTypeADTS = 0x0F // ISO/IEC 13818-7 Audio with ADTS transport syntax - StreamTypeH264Video = 0x1B // Rec. ITU-T H.264 | ISO/IEC 14496-10 - StreamTypeH265Video = 0x24 // Rec. ITU-T H.265 | ISO/IEC 23008-2 + StreamTypeMPEG1Video StreamType = 0x01 + StreamTypeMPEG2Video StreamType = 0x02 + StreamTypeMPEG1Audio StreamType = 0x03 // ISO/IEC 11172-3 + StreamTypeMPEG2HalvedSampleRateAudio StreamType = 0x04 // ISO/IEC 13818-3 + StreamTypeMPEG2Audio StreamType = 0x04 + StreamTypePrivateSection StreamType = 0x05 + StreamTypePrivateData StreamType = 0x06 + StreamTypeMPEG2PacketizedData StreamType = 0x06 // Rec. ITU-T H.222 | ISO/IEC 13818-1 i.e., DVB subtitles/VBI and AC-3 + StreamTypeADTS StreamType = 0x0F // ISO/IEC 13818-7 Audio with ADTS transport syntax + StreamTypeAACAudio StreamType = 0x0f + StreamTypeMPEG4Video StreamType = 0x10 + StreamTypeAACLATMAudio StreamType = 0x11 + StreamTypeMetadata StreamType = 0x15 + StreamTypeH264Video StreamType = 0x1B // Rec. ITU-T H.264 | ISO/IEC 14496-10 + StreamTypeH265Video StreamType = 0x24 // Rec. ITU-T H.265 | ISO/IEC 23008-2 + StreamTypeHEVCVideo StreamType = 0x24 + StreamTypeCAVSVideo StreamType = 0x42 + StreamTypeVC1Video StreamType = 0xea + StreamTypeDIRACVideo StreamType = 0xd1 + StreamTypeAC3Audio StreamType = 0x81 + StreamTypeDTSAudio StreamType = 0x82 + StreamTypeTRUEHDAudio StreamType = 0x83 + StreamTypeEAC3Audio StreamType = 0x87 ) // PMTData represents a PMT data @@ -29,7 +48,7 @@ type PMTData struct { type PMTElementaryStream struct { ElementaryPID uint16 // The packet identifier that contains the stream type data. ElementaryStreamDescriptors []*Descriptor // Elementary stream descriptors - StreamType uint8 // This defines the structure of the data contained within the elementary packet identifier. + StreamType StreamType // This defines the structure of the data contained within the elementary packet identifier. } // parsePMTSection parses a PMT section @@ -66,7 +85,7 @@ func parsePMTSection(i *astikit.BytesIterator, offsetSectionsEnd int, tableIDExt } // Stream type - e.StreamType = uint8(b) + e.StreamType = StreamType(b) // Get next bytes if bs, err = i.NextBytes(2); err != nil { @@ -88,3 +107,150 @@ func parsePMTSection(i *astikit.BytesIterator, offsetSectionsEnd int, tableIDExt } return } + +func calcPMTProgramInfoLength(d *PMTData) uint16 { + ret := uint16(2) // program_info_length + ret += calcDescriptorsLength(d.ProgramDescriptors) + + for _, es := range d.ElementaryStreams { + ret += 5 // stream_type, elementary_pid, es_info_length + ret += calcDescriptorsLength(es.ElementaryStreamDescriptors) + } + + return ret +} + +func calcPMTSectionLength(d *PMTData) uint16 { + ret := uint16(4) + ret += calcDescriptorsLength(d.ProgramDescriptors) + + for _, es := range d.ElementaryStreams { + ret += 5 + ret += calcDescriptorsLength(es.ElementaryStreamDescriptors) + } + + return ret +} + +func writePMTSection(w *astikit.BitsWriter, d *PMTData) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + // TODO split into sections + + b.WriteN(uint8(0xff), 3) + b.WriteN(d.PCRPID, 13) + bytesWritten := 2 + + n, err := writeDescriptorsWithLength(w, d.ProgramDescriptors) + if err != nil { + return 0, err + } + bytesWritten += n + + for _, es := range d.ElementaryStreams { + b.Write(uint8(es.StreamType)) + b.WriteN(uint8(0xff), 3) + b.WriteN(es.ElementaryPID, 13) + bytesWritten += 3 + + n, err = writeDescriptorsWithLength(w, es.ElementaryStreamDescriptors) + if err != nil { + return 0, err + } + bytesWritten += n + } + + return bytesWritten, b.Err() +} + +func (t StreamType) IsVideo() bool { + switch t { + case StreamTypeMPEG1Video, + StreamTypeMPEG2Video, + StreamTypeMPEG4Video, + StreamTypeH264Video, + StreamTypeH265Video, + StreamTypeCAVSVideo, + StreamTypeVC1Video, + StreamTypeDIRACVideo: + return true + } + return false +} + +func (t StreamType) IsAudio() bool { + switch t { + case StreamTypeMPEG1Audio, + StreamTypeMPEG2Audio, + StreamTypeAACAudio, + StreamTypeAACLATMAudio, + StreamTypeAC3Audio, + StreamTypeDTSAudio, + StreamTypeTRUEHDAudio, + StreamTypeEAC3Audio: + return true + } + return false +} + +func (t StreamType) String() string { + switch t { + case StreamTypeMPEG1Video: + return "MPEG1 Video" + case StreamTypeMPEG2Video: + return "MPEG2 Video" + case StreamTypeMPEG1Audio: + return "MPEG1 Audio" + case StreamTypeMPEG2Audio: + return "MPEG2 Audio" + case StreamTypePrivateSection: + return "Private Section" + case StreamTypePrivateData: + return "Private Data" + case StreamTypeAACAudio: + return "AAC Audio" + case StreamTypeMPEG4Video: + return "MPEG4 Video" + case StreamTypeAACLATMAudio: + return "AAC LATM Audio" + case StreamTypeMetadata: + return "Metadata" + case StreamTypeH264Video: + return "H264 Video" + case StreamTypeH265Video: + return "H265 Video" + case StreamTypeCAVSVideo: + return "CAVS Video" + case StreamTypeVC1Video: + return "VC1 Video" + case StreamTypeDIRACVideo: + return "DIRAC Video" + case StreamTypeAC3Audio: + return "AC3 Audio" + case StreamTypeDTSAudio: + return "DTS Audio" + case StreamTypeTRUEHDAudio: + return "TRUEHD Audio" + case StreamTypeEAC3Audio: + return "EAC3 Audio" + } + return "Unknown" +} + +func (t StreamType) ToPESStreamID() uint8 { + switch t { + case StreamTypeMPEG1Video, StreamTypeMPEG2Video, StreamTypeMPEG4Video, StreamTypeH264Video, + StreamTypeH265Video, StreamTypeCAVSVideo, StreamTypeVC1Video: + return 0xe0 + case StreamTypeDIRACVideo: + return 0xfd + case StreamTypeMPEG2Audio, StreamTypeAACAudio, StreamTypeAACLATMAudio: + return 0xc0 + case StreamTypeAC3Audio, StreamTypeEAC3Audio: // m2ts_mode??? + return 0xfd + case StreamTypePrivateSection, StreamTypePrivateData, StreamTypeMetadata: + return 0xfc + default: + return 0xbd + } +} diff --git a/data_pmt_test.go b/data_pmt_test.go index 15b8423..064a1a0 100644 --- a/data_pmt_test.go +++ b/data_pmt_test.go @@ -40,3 +40,12 @@ func TestParsePMTSection(t *testing.T) { assert.Equal(t, d, pmt) assert.NoError(t, err) } + +func TestWritePMTSection(t *testing.T) { + buf := bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + n, err := writePMTSection(w, pmt) + assert.NoError(t, err) + assert.Equal(t, n, buf.Len()) + assert.Equal(t, pmtBytes(), buf.Bytes()) +} diff --git a/data_psi.go b/data_psi.go index f9146a4..242fd53 100644 --- a/data_psi.go +++ b/data_psi.go @@ -24,6 +24,28 @@ const ( PSITableTypeUnknown = "Unknown" ) +type PSITableID uint16 + +const ( + PSITableIDPAT PSITableID = 0x00 + PSITableIDPMT PSITableID = 0x02 + PSITableIDBAT PSITableID = 0x4a + PSITableIDDIT PSITableID = 0x7e + PSITableIDRST PSITableID = 0x71 + PSITableIDSIT PSITableID = 0x7f + PSITableIDST PSITableID = 0x72 + PSITableIDTDT PSITableID = 0x70 + PSITableIDTOT PSITableID = 0x73 + PSITableIDNull PSITableID = 0xff + + PSITableIDEITStart PSITableID = 0x4e + PSITableIDEITEnd PSITableID = 0x6f + PSITableIDSDTVariant1 PSITableID = 0x42 + PSITableIDSDTVariant2 PSITableID = 0x46 + PSITableIDNITVariant1 PSITableID = 0x40 + PSITableIDNITVariant2 PSITableID = 0x41 +) + // PSIData represents a PSI data // https://en.wikipedia.org/wiki/Program-specific_information type PSIData struct { @@ -40,10 +62,10 @@ type PSISection struct { // PSISectionHeader represents a PSI section header type PSISectionHeader struct { - PrivateBit bool // The PAT, PMT, and CAT all set this to 0. Other tables set this to 1. - SectionLength uint16 // The number of bytes that follow for the syntax section (with CRC value) and/or table data. These bytes must not exceed a value of 1021. - SectionSyntaxIndicator bool // A flag that indicates if the syntax section follows the section length. The PAT, PMT, and CAT all set this to 1. - TableID int // Table Identifier, that defines the structure of the syntax section and other contained data. As an exception, if this is the byte that immediately follow previous table section and is set to 0xFF, then it indicates that the repeat of table section end here and the rest of TS data payload shall be stuffed with 0xFF. Consequently the value 0xFF shall not be used for the Table Identifier. + PrivateBit bool // The PAT, PMT, and CAT all set this to 0. Other tables set this to 1. + SectionLength uint16 // The number of bytes that follow for the syntax section (with CRC value) and/or table data. These bytes must not exceed a value of 1021. + SectionSyntaxIndicator bool // A flag that indicates if the syntax section follows the section length. The PAT, PMT, and CAT all set this to 1. + TableID PSITableID // Table Identifier, that defines the structure of the syntax section and other contained data. As an exception, if this is the byte that immediately follow previous table section and is set to 0xFF, then it indicates that the repeat of table section end here and the rest of TS data payload shall be stuffed with 0xFF. Consequently the value 0xFF shall not be used for the Table Identifier. TableType string } @@ -116,7 +138,7 @@ func parsePSISection(i *astikit.BytesIterator) (s *PSISection, stop bool, err er } // Check whether we need to stop the parsing - if shouldStopPSIParsing(s.Header.TableType) { + if shouldStopPSIParsing(s.Header.TableID) { stop = true return } @@ -130,7 +152,7 @@ func parsePSISection(i *astikit.BytesIterator) (s *PSISection, stop bool, err er } // Process CRC32 - if hasCRC32(s.Header.TableType) { + if s.Header.TableID.hasCRC32() { // Seek to the end of the sections i.Seek(offsetSectionsEnd) @@ -149,11 +171,7 @@ func parsePSISection(i *astikit.BytesIterator) (s *PSISection, stop bool, err er } // Compute CRC32 - var crc32 uint32 - if crc32, err = computeCRC32(crc32Data); err != nil { - err = fmt.Errorf("astits: computing CRC32 failed: %w", err) - return - } + crc32 := computeCRC32(crc32Data) // Check CRC32 if crc32 != s.CRC32 { @@ -179,26 +197,10 @@ func parseCRC32(i *astikit.BytesIterator) (c uint32, err error) { return } -// computeCRC32 computes a CRC32 -// https://stackoverflow.com/questions/35034042/how-to-calculate-crc32-in-psi-si-packet -func computeCRC32(bs []byte) (o uint32, err error) { - o = uint32(0xffffffff) - for _, b := range bs { - for i := 0; i < 8; i++ { - if (o >= uint32(0x80000000)) != (b >= uint8(0x80)) { - o = (o << 1) ^ 0x04C11DB7 - } else { - o = o << 1 - } - b <<= 1 - } - } - return -} - // shouldStopPSIParsing checks whether the PSI parsing should be stopped -func shouldStopPSIParsing(tableType string) bool { - return tableType == PSITableTypeNull || tableType == PSITableTypeUnknown +func shouldStopPSIParsing(tableID PSITableID) bool { + return tableID == PSITableIDNull || + tableID.isUnknown() } // parsePSISectionHeader parses a PSI section header @@ -215,13 +217,13 @@ func parsePSISectionHeader(i *astikit.BytesIterator) (h *PSISectionHeader, offse } // Table ID - h.TableID = int(b) + h.TableID = PSITableID(b) // Table type - h.TableType = psiTableType(h.TableID) + h.TableType = h.TableID.Type() // Check whether we need to stop the parsing - if shouldStopPSIParsing(h.TableType) { + if shouldStopPSIParsing(h.TableID) { return } @@ -245,64 +247,96 @@ func parsePSISectionHeader(i *astikit.BytesIterator) (h *PSISectionHeader, offse offsetSectionsStart = i.Offset() offsetEnd = offsetSectionsStart + int(h.SectionLength) offsetSectionsEnd = offsetEnd - if hasCRC32(h.TableType) { + if h.TableID.hasCRC32() { offsetSectionsEnd -= 4 } return } -// hasCRC32 checks whether the table has a CRC32 -func hasCRC32(tableType string) bool { - return tableType == PSITableTypePAT || - tableType == PSITableTypePMT || - tableType == PSITableTypeEIT || - tableType == PSITableTypeNIT || - tableType == PSITableTypeTOT || - tableType == PSITableTypeSDT -} - -// psiTableType returns the psi table type based on the table id +// PSITableID.Type() returns the psi table type based on the table id // Page: 28 | https://www.dvb.org/resources/public/standards/a38_dvb-si_specification.pdf -func psiTableType(tableID int) string { +// (barbashov) the link above can be broken, alternative: https://dvb.org/wp-content/uploads/2019/12/a038_tm1217r37_en300468v1_17_1_-_rev-134_-_si_specification.pdf +func (t PSITableID) Type() string { switch { - case tableID == 0x4a: + case t == PSITableIDBAT: return PSITableTypeBAT - case tableID >= 0x4e && tableID <= 0x6f: + case t >= PSITableIDEITStart && t <= PSITableIDEITEnd: return PSITableTypeEIT - case tableID == 0x7e: + case t == PSITableIDDIT: return PSITableTypeDIT - case tableID == 0x40, tableID == 0x41: + case t == PSITableIDNITVariant1, t == PSITableIDNITVariant2: return PSITableTypeNIT - case tableID == 0xff: + case t == PSITableIDNull: return PSITableTypeNull - case tableID == 0: + case t == PSITableIDPAT: return PSITableTypePAT - case tableID == 2: + case t == PSITableIDPMT: return PSITableTypePMT - case tableID == 0x71: + case t == PSITableIDRST: return PSITableTypeRST - case tableID == 0x42, tableID == 0x46: + case t == PSITableIDSDTVariant1, t == PSITableIDSDTVariant2: return PSITableTypeSDT - case tableID == 0x7f: + case t == PSITableIDSIT: return PSITableTypeSIT - case tableID == 0x72: + case t == PSITableIDST: return PSITableTypeST - case tableID == 0x70: + case t == PSITableIDTDT: return PSITableTypeTDT - case tableID == 0x73: + case t == PSITableIDTOT: return PSITableTypeTOT default: return PSITableTypeUnknown } } +// hasPSISyntaxHeader checks whether the section has a syntax header +func (t PSITableID) hasPSISyntaxHeader() bool { + return t == PSITableIDPAT || + t == PSITableIDPMT || + t == PSITableIDNITVariant1 || t == PSITableIDNITVariant2 || + t == PSITableIDSDTVariant1 || t == PSITableIDSDTVariant2 || + (t >= PSITableIDEITStart && t <= PSITableIDEITEnd) +} + +// hasCRC32 checks whether the table has a CRC32 +func (t PSITableID) hasCRC32() bool { + return t == PSITableIDPAT || + t == PSITableIDPMT || + t == PSITableIDTOT || + t == PSITableIDNITVariant1 || t == PSITableIDNITVariant2 || + t == PSITableIDSDTVariant1 || t == PSITableIDSDTVariant2 || + (t >= PSITableIDEITStart && t <= PSITableIDEITEnd) +} + +func (t PSITableID) isUnknown() bool { + switch t { + case PSITableIDBAT, + PSITableIDDIT, + PSITableIDNITVariant1, PSITableIDNITVariant2, + PSITableIDNull, + PSITableIDPAT, + PSITableIDPMT, + PSITableIDRST, + PSITableIDSDTVariant1, PSITableIDSDTVariant2, + PSITableIDSIT, + PSITableIDST, + PSITableIDTDT, + PSITableIDTOT: + return false + } + if t >= PSITableIDEITStart && t <= PSITableIDEITEnd { + return false + } + return true +} + // parsePSISectionSyntax parses a PSI section syntax func parsePSISectionSyntax(i *astikit.BytesIterator, h *PSISectionHeader, offsetSectionsEnd int) (s *PSISectionSyntax, err error) { // Init s = &PSISectionSyntax{} // Header - if hasPSISyntaxHeader(h.TableType) { + if h.TableID.hasPSISyntaxHeader() { if s.Header, err = parsePSISectionSyntaxHeader(i); err != nil { err = fmt.Errorf("astits: parsing PSI section syntax header failed: %w", err) return @@ -317,15 +351,6 @@ func parsePSISectionSyntax(i *astikit.BytesIterator, h *PSISectionHeader, offset return } -// hasPSISyntaxHeader checks whether the section has a syntax header -func hasPSISyntaxHeader(tableType string) bool { - return tableType == PSITableTypeEIT || - tableType == PSITableTypeNIT || - tableType == PSITableTypePAT || - tableType == PSITableTypePMT || - tableType == PSITableTypeSDT -} - // parsePSISectionSyntaxHeader parses a PSI section syntax header func parsePSISectionSyntaxHeader(i *astikit.BytesIterator) (h *PSISectionSyntaxHeader, err error) { // Init @@ -380,72 +405,204 @@ func parsePSISectionSyntaxData(i *astikit.BytesIterator, h *PSISectionHeader, sh d = &PSISectionSyntaxData{} // Switch on table type - switch h.TableType { - case PSITableTypeBAT: + switch h.TableID { + case PSITableIDBAT: // TODO Parse BAT - case PSITableTypeDIT: + case PSITableIDDIT: // TODO Parse DIT - case PSITableTypeEIT: - if d.EIT, err = parseEITSection(i, offsetSectionsEnd, sh.TableIDExtension); err != nil { - err = fmt.Errorf("astits: parsing EIT section failed: %w", err) - return - } - case PSITableTypeNIT: + case PSITableIDNITVariant1, PSITableIDNITVariant2: if d.NIT, err = parseNITSection(i, sh.TableIDExtension); err != nil { err = fmt.Errorf("astits: parsing NIT section failed: %w", err) return } - case PSITableTypePAT: + case PSITableIDPAT: if d.PAT, err = parsePATSection(i, offsetSectionsEnd, sh.TableIDExtension); err != nil { err = fmt.Errorf("astits: parsing PAT section failed: %w", err) return } - case PSITableTypePMT: + case PSITableIDPMT: if d.PMT, err = parsePMTSection(i, offsetSectionsEnd, sh.TableIDExtension); err != nil { err = fmt.Errorf("astits: parsing PMT section failed: %w", err) return } - case PSITableTypeRST: + case PSITableIDRST: // TODO Parse RST - case PSITableTypeSDT: + case PSITableIDSDTVariant1, PSITableIDSDTVariant2: if d.SDT, err = parseSDTSection(i, offsetSectionsEnd, sh.TableIDExtension); err != nil { err = fmt.Errorf("astits: parsing PMT section failed: %w", err) return } - case PSITableTypeSIT: + case PSITableIDSIT: // TODO Parse SIT - case PSITableTypeST: + case PSITableIDST: // TODO Parse ST - case PSITableTypeTOT: + case PSITableIDTOT: if d.TOT, err = parseTOTSection(i); err != nil { err = fmt.Errorf("astits: parsing TOT section failed: %w", err) return } - case PSITableTypeTDT: + case PSITableIDTDT: // TODO Parse TDT } + + if h.TableID >= PSITableIDEITStart && h.TableID <= PSITableIDEITEnd { + if d.EIT, err = parseEITSection(i, offsetSectionsEnd, sh.TableIDExtension); err != nil { + err = fmt.Errorf("astits: parsing EIT section failed: %w", err) + return + } + } + return } -// toData parses the PSI tables and returns a set of Data -func (d *PSIData) toData(firstPacket *Packet, pid uint16) (ds []*Data) { +// toData parses the PSI tables and returns a set of DemuxerData +func (d *PSIData) toData(firstPacket *Packet, pid uint16) (ds []*DemuxerData) { // Loop through sections for _, s := range d.Sections { // Switch on table type - switch s.Header.TableType { - case PSITableTypeEIT: - ds = append(ds, &Data{EIT: s.Syntax.Data.EIT, FirstPacket: firstPacket, PID: pid}) - case PSITableTypeNIT: - ds = append(ds, &Data{FirstPacket: firstPacket, NIT: s.Syntax.Data.NIT, PID: pid}) - case PSITableTypePAT: - ds = append(ds, &Data{FirstPacket: firstPacket, PAT: s.Syntax.Data.PAT, PID: pid}) - case PSITableTypePMT: - ds = append(ds, &Data{FirstPacket: firstPacket, PID: pid, PMT: s.Syntax.Data.PMT}) - case PSITableTypeSDT: - ds = append(ds, &Data{FirstPacket: firstPacket, PID: pid, SDT: s.Syntax.Data.SDT}) - case PSITableTypeTOT: - ds = append(ds, &Data{FirstPacket: firstPacket, PID: pid, TOT: s.Syntax.Data.TOT}) + switch s.Header.TableID { + case PSITableIDNITVariant1, PSITableIDNITVariant2: + ds = append(ds, &DemuxerData{FirstPacket: firstPacket, NIT: s.Syntax.Data.NIT, PID: pid}) + case PSITableIDPAT: + ds = append(ds, &DemuxerData{FirstPacket: firstPacket, PAT: s.Syntax.Data.PAT, PID: pid}) + case PSITableIDPMT: + ds = append(ds, &DemuxerData{FirstPacket: firstPacket, PID: pid, PMT: s.Syntax.Data.PMT}) + case PSITableIDSDTVariant1, PSITableIDSDTVariant2: + ds = append(ds, &DemuxerData{FirstPacket: firstPacket, PID: pid, SDT: s.Syntax.Data.SDT}) + case PSITableIDTOT: + ds = append(ds, &DemuxerData{FirstPacket: firstPacket, PID: pid, TOT: s.Syntax.Data.TOT}) + } + if s.Header.TableID >= PSITableIDEITStart && s.Header.TableID <= PSITableIDEITEnd { + ds = append(ds, &DemuxerData{EIT: s.Syntax.Data.EIT, FirstPacket: firstPacket, PID: pid}) } } return } + +func writePSIData(w *astikit.BitsWriter, d *PSIData) (int, error) { + b := astikit.NewBitsWriterBatch(w) + b.Write(uint8(d.PointerField)) + for i := 0; i < d.PointerField; i++ { + b.Write(uint8(0x00)) + } + + bytesWritten := 1 + d.PointerField + + if err := b.Err(); err != nil { + return 0, err + } + + for _, s := range d.Sections { + n, err := writePSISection(w, s) + if err != nil { + return 0, err + } + bytesWritten += n + } + + return bytesWritten, nil +} + +func calcPSISectionLength(s *PSISection) uint16 { + ret := uint16(0) + if s.Header.TableID.hasPSISyntaxHeader() { + ret += 5 // PSI syntax header length + } + + switch s.Header.TableID { + case PSITableIDPAT: + ret += calcPATSectionLength(s.Syntax.Data.PAT) + case PSITableIDPMT: + ret += calcPMTSectionLength(s.Syntax.Data.PMT) + } + + if s.Header.TableID.hasCRC32() { + ret += 4 + } + + return ret +} + +func writePSISection(w *astikit.BitsWriter, s *PSISection) (int, error) { + if s.Header.TableID != PSITableIDPAT && s.Header.TableID != PSITableIDPMT { + return 0, fmt.Errorf("writePSISection: table %s is not implemented", s.Header.TableID.Type()) + } + + b := astikit.NewBitsWriterBatch(w) + + sectionLength := calcPSISectionLength(s) + sectionCRC32 := crc32Polynomial + + if s.Header.TableID.hasCRC32() { + w.SetWriteCallback(func(bs []byte) { + sectionCRC32 = updateCRC32(sectionCRC32, bs) + }) + defer w.SetWriteCallback(nil) + } + + b.Write(uint8(s.Header.TableID)) + b.Write(s.Header.SectionSyntaxIndicator) + b.Write(s.Header.PrivateBit) + b.WriteN(uint8(0xff), 2) + b.WriteN(sectionLength, 12) + bytesWritten := 3 + + if s.Header.SectionLength > 0 { + n, err := writePSISectionSyntax(w, s) + if err != nil { + return 0, err + } + bytesWritten += n + + if s.Header.TableID.hasCRC32() { + b.Write(sectionCRC32) + bytesWritten += 4 + } + } + + return bytesWritten, b.Err() +} + +func writePSISectionSyntax(w *astikit.BitsWriter, s *PSISection) (int, error) { + bytesWritten := 0 + if s.Header.TableID.hasPSISyntaxHeader() { + n, err := writePSISectionSyntaxHeader(w, s.Syntax.Header) + if err != nil { + return 0, err + } + bytesWritten += n + } + + n, err := writePSISectionSyntaxData(w, s.Syntax.Data, s.Header.TableID) + if err != nil { + return 0, err + } + bytesWritten += n + + return bytesWritten, nil +} + +func writePSISectionSyntaxHeader(w *astikit.BitsWriter, h *PSISectionSyntaxHeader) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + b.Write(h.TableIDExtension) + b.WriteN(uint8(0xff), 2) + b.WriteN(h.VersionNumber, 5) + b.Write(h.CurrentNextIndicator) + b.Write(h.SectionNumber) + b.Write(h.LastSectionNumber) + + return 5, b.Err() +} + +func writePSISectionSyntaxData(w *astikit.BitsWriter, d *PSISectionSyntaxData, tableID PSITableID) (int, error) { + switch tableID { + // TODO write other table types + case PSITableIDPAT: + return writePATSection(w, d.PAT) + case PSITableIDPMT: + return writePMTSection(w, d.PMT) + } + + return 0, nil +} diff --git a/data_psi_test.go b/data_psi_test.go index 187c710..e63898b 100644 --- a/data_psi_test.go +++ b/data_psi_test.go @@ -220,25 +220,25 @@ func TestParsePSISectionHeader(t *testing.T) { } func TestPSITableType(t *testing.T) { - assert.Equal(t, PSITableTypeBAT, psiTableType(74)) - for i := 78; i <= 111; i++ { - assert.Equal(t, PSITableTypeEIT, psiTableType(i)) + for i := PSITableIDEITStart; i <= PSITableIDEITEnd; i++ { + assert.Equal(t, PSITableTypeEIT, i.Type()) } - assert.Equal(t, PSITableTypeDIT, psiTableType(126)) - for i := 64; i <= 65; i++ { - assert.Equal(t, PSITableTypeNIT, psiTableType(i)) - } - assert.Equal(t, PSITableTypeNull, psiTableType(255)) - assert.Equal(t, PSITableTypePAT, psiTableType(0)) - assert.Equal(t, PSITableTypePMT, psiTableType(2)) - assert.Equal(t, PSITableTypeRST, psiTableType(113)) - assert.Equal(t, PSITableTypeSDT, psiTableType(66)) - assert.Equal(t, PSITableTypeSDT, psiTableType(70)) - assert.Equal(t, PSITableTypeSIT, psiTableType(127)) - assert.Equal(t, PSITableTypeST, psiTableType(114)) - assert.Equal(t, PSITableTypeTDT, psiTableType(112)) - assert.Equal(t, PSITableTypeTOT, psiTableType(115)) - assert.Equal(t, PSITableTypeUnknown, psiTableType(1)) + assert.Equal(t, PSITableTypeDIT, PSITableIDDIT.Type()) + assert.Equal(t, PSITableTypeNIT, PSITableIDNITVariant1.Type()) + assert.Equal(t, PSITableTypeNIT, PSITableIDNITVariant2.Type()) + assert.Equal(t, PSITableTypeSDT, PSITableIDSDTVariant1.Type()) + assert.Equal(t, PSITableTypeSDT, PSITableIDSDTVariant2.Type()) + + assert.Equal(t, PSITableTypeBAT, PSITableIDBAT.Type()) + assert.Equal(t, PSITableTypeNull, PSITableIDNull.Type()) + assert.Equal(t, PSITableTypePAT, PSITableIDPAT.Type()) + assert.Equal(t, PSITableTypePMT, PSITableIDPMT.Type()) + assert.Equal(t, PSITableTypeRST, PSITableIDRST.Type()) + assert.Equal(t, PSITableTypeSIT, PSITableIDSIT.Type()) + assert.Equal(t, PSITableTypeST, PSITableIDST.Type()) + assert.Equal(t, PSITableTypeTDT, PSITableIDTDT.Type()) + assert.Equal(t, PSITableTypeTOT, PSITableIDTOT.Type()) + assert.Equal(t, PSITableTypeUnknown, PSITableID(1).Type()) } var psiSectionSyntaxHeader = &PSISectionSyntaxHeader{ @@ -269,7 +269,7 @@ func TestParsePSISectionSyntaxHeader(t *testing.T) { func TestPSIToData(t *testing.T) { p := &Packet{} - assert.Equal(t, []*Data{ + assert.Equal(t, []*DemuxerData{ {EIT: eit, FirstPacket: p, PID: 2}, {FirstPacket: p, NIT: nit, PID: 2}, {FirstPacket: p, PAT: pat, PID: 2}, @@ -278,3 +278,99 @@ func TestPSIToData(t *testing.T) { {FirstPacket: p, TOT: tot, PID: 2}, }, psi.toData(p, uint16(2))) } + +type psiDataTestCase struct { + name string + bytesFunc func(*astikit.BitsWriter) + data *PSIData +} + +var psiDataTestCases = []psiDataTestCase{ + { + "PAT", + func(w *astikit.BitsWriter) { + w.Write(uint8(4)) // Pointer field + w.Write([]byte{0, 0, 0, 0}) // Pointer field bytes + w.Write(uint8(0)) // PAT table ID + w.Write("1") // PAT syntax section indicator + w.Write("1") // PAT private bit + w.Write("11") // PAT reserved + w.Write("000000010001") // PAT section length + w.Write(psiSectionSyntaxHeaderBytes()) // PAT syntax section header + w.Write(patBytes()) // PAT data + w.Write(uint32(0x60739f61)) // PAT CRC32 + }, + &PSIData{ + PointerField: 4, + Sections: []*PSISection{ + { + CRC32: uint32(0x60739f61), + Header: &PSISectionHeader{ + PrivateBit: true, + SectionLength: 17, + SectionSyntaxIndicator: true, + TableID: 0, + TableType: PSITableTypePAT, + }, + Syntax: &PSISectionSyntax{ + Data: &PSISectionSyntaxData{PAT: pat}, + Header: psiSectionSyntaxHeader, + }, + }, + }, + }, + }, + { + "PMT", + func(w *astikit.BitsWriter) { + w.Write(uint8(4)) // Pointer field + w.Write([]byte{0, 0, 0, 0}) // Pointer field bytes + w.Write(uint8(2)) // PMT table ID + w.Write("1") // PMT syntax section indicator + w.Write("1") // PMT private bit + w.Write("11") // PMT reserved + w.Write("000000011000") // PMT section length + w.Write(psiSectionSyntaxHeaderBytes()) // PMT syntax section header + w.Write(pmtBytes()) // PMT data + w.Write(uint32(0xc68442e8)) // PMT CRC32 + }, + &PSIData{ + PointerField: 4, + Sections: []*PSISection{ + { + CRC32: uint32(0xc68442e8), + Header: &PSISectionHeader{ + PrivateBit: true, + SectionLength: 24, + SectionSyntaxIndicator: true, + TableID: 2, + TableType: PSITableTypePMT, + }, + Syntax: &PSISectionSyntax{ + Data: &PSISectionSyntaxData{PMT: pmt}, + Header: psiSectionSyntaxHeader, + }, + }, + }, + }, + }, +} + +func TestWritePSIData(t *testing.T) { + for _, tc := range psiDataTestCases { + t.Run(tc.name, func(t *testing.T) { + bufExpected := bytes.Buffer{} + wExpected := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufExpected}) + bufActual := bytes.Buffer{} + wActual := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufActual}) + + tc.bytesFunc(wExpected) + + n, err := writePSIData(wActual, tc.data) + assert.NoError(t, err) + assert.Equal(t, bufExpected.Len(), n) + assert.Equal(t, n, bufActual.Len()) + assert.Equal(t, bufExpected.Bytes(), bufActual.Bytes()) + }) + } +} diff --git a/data_sdt.go b/data_sdt.go index ce48af1..24eb51a 100644 --- a/data_sdt.go +++ b/data_sdt.go @@ -18,6 +18,7 @@ const ( // SDTData represents an SDT data // Page: 33 | Chapter: 5.2.3 | Link: https://www.dvb.org/resources/public/standards/a38_dvb-si_specification.pdf +// (barbashov) the link above can be broken, alternative: https://dvb.org/wp-content/uploads/2019/12/a038_tm1217r37_en300468v1_17_1_-_rev-134_-_si_specification.pdf type SDTData struct { OriginalNetworkID uint16 Services []*SDTDataService diff --git a/data_test.go b/data_test.go index 428671b..72e0ea0 100644 --- a/data_test.go +++ b/data_test.go @@ -14,8 +14,8 @@ func TestParseData(t *testing.T) { ps := []*Packet{} // Custom parser - cds := []*Data{{PID: 1}} - var c = func(ps []*Packet) (o []*Data, skip bool, err error) { + cds := []*DemuxerData{{PID: 1}} + var c = func(ps []*Packet) (o []*DemuxerData, skip bool, err error) { o = cds skip = true return @@ -44,7 +44,7 @@ func TestParseData(t *testing.T) { } ds, err = parseData(ps, nil, pm) assert.NoError(t, err) - assert.Equal(t, []*Data{{FirstPacket: ps[0], PES: pesWithHeader, PID: uint16(256)}}, ds) + assert.Equal(t, []*DemuxerData{{FirstPacket: ps[0], PES: pesWithHeader(), PID: uint16(256)}}, ds) // PSI pm.set(uint16(256), uint16(1)) diff --git a/data_tot.go b/data_tot.go index 4f569d2..0bd64d2 100644 --- a/data_tot.go +++ b/data_tot.go @@ -9,6 +9,7 @@ import ( // TOTData represents a TOT data // Page: 39 | Chapter: 5.2.6 | Link: https://www.dvb.org/resources/public/standards/a38_dvb-si_specification.pdf +// (barbashov) the link above can be broken, alternative: https://dvb.org/wp-content/uploads/2019/12/a038_tm1217r37_en300468v1_17_1_-_rev-134_-_si_specification.pdf type TOTData struct { Descriptors []*Descriptor UTCTime time.Time diff --git a/demuxer.go b/demuxer.go index 4db168f..b17e82f 100644 --- a/demuxer.go +++ b/demuxer.go @@ -22,7 +22,7 @@ var ( // http://www.etsi.org/deliver/etsi_en/300400_300499/300468/01.13.01_40/en_300468v011301o.pdf type Demuxer struct { ctx context.Context - dataBuffer []*Data + dataBuffer []*DemuxerData optPacketSize int optPacketsParser PacketsParser packetBuffer *packetBuffer @@ -33,10 +33,10 @@ type Demuxer struct { // PacketsParser represents an object capable of parsing a set of packets containing a unique payload spanning over those packets // Use the skip returned argument to indicate whether the default process should still be executed on the set of packets -type PacketsParser func(ps []*Packet) (ds []*Data, skip bool, err error) +type PacketsParser func(ps []*Packet) (ds []*DemuxerData, skip bool, err error) -// New creates a new transport stream based on a reader -func New(ctx context.Context, r io.Reader, opts ...func(*Demuxer)) (d *Demuxer) { +// NewDemuxer creates a new transport stream based on a reader +func NewDemuxer(ctx context.Context, r io.Reader, opts ...func(*Demuxer)) (d *Demuxer) { // Init d = &Demuxer{ ctx: ctx, @@ -49,18 +49,19 @@ func New(ctx context.Context, r io.Reader, opts ...func(*Demuxer)) (d *Demuxer) for _, opt := range opts { opt(d) } + return } -// OptPacketSize returns the option to set the packet size -func OptPacketSize(packetSize int) func(*Demuxer) { +// DemuxerOptPacketSize returns the option to set the packet size +func DemuxerOptPacketSize(packetSize int) func(*Demuxer) { return func(d *Demuxer) { d.optPacketSize = packetSize } } -// OptPacketsParser returns the option to set the packets parser -func OptPacketsParser(p PacketsParser) func(*Demuxer) { +// DemuxerOptPacketsParser returns the option to set the packets parser +func DemuxerOptPacketsParser(p PacketsParser) func(*Demuxer) { return func(d *Demuxer) { d.optPacketsParser = p } @@ -94,7 +95,7 @@ func (dmx *Demuxer) NextPacket() (p *Packet, err error) { } // NextData retrieves the next data -func (dmx *Demuxer) NextData() (d *Data, err error) { +func (dmx *Demuxer) NextData() (d *DemuxerData, err error) { // Check data buffer if len(dmx.dataBuffer) > 0 { d = dmx.dataBuffer[0] @@ -105,7 +106,7 @@ func (dmx *Demuxer) NextData() (d *Data, err error) { // Loop through packets var p *Packet var ps []*Packet - var ds []*Data + var ds []*DemuxerData for { // Get next packet if p, err = dmx.NextPacket(); err != nil { @@ -153,7 +154,7 @@ func (dmx *Demuxer) NextData() (d *Data, err error) { } } -func (dmx *Demuxer) updateData(ds []*Data) (d *Data) { +func (dmx *Demuxer) updateData(ds []*DemuxerData) (d *DemuxerData) { // Check whether there is data to be processed if len(ds) > 0 { // Process data @@ -177,7 +178,7 @@ func (dmx *Demuxer) updateData(ds []*Data) (d *Data) { // Rewind rewinds the demuxer reader func (dmx *Demuxer) Rewind() (n int64, err error) { - dmx.dataBuffer = []*Data{} + dmx.dataBuffer = []*DemuxerData{} dmx.packetBuffer = nil dmx.packetPool = newPacketPool() if n, err = rewind(dmx.r); err != nil { diff --git a/demuxer_test.go b/demuxer_test.go index d5dfee5..7249b7b 100644 --- a/demuxer_test.go +++ b/demuxer_test.go @@ -12,8 +12,8 @@ import ( func TestDemuxerNew(t *testing.T) { ps := 1 - pp := func(ps []*Packet) (ds []*Data, skip bool, err error) { return } - dmx := New(context.Background(), nil, OptPacketSize(ps), OptPacketsParser(pp)) + pp := func(ps []*Packet) (ds []*DemuxerData, skip bool, err error) { return } + dmx := NewDemuxer(context.Background(), nil, DemuxerOptPacketSize(ps), DemuxerOptPacketsParser(pp)) assert.Equal(t, ps, dmx.optPacketSize) assert.Equal(t, fmt.Sprintf("%p", pp), fmt.Sprintf("%p", dmx.optPacketsParser)) } @@ -21,7 +21,7 @@ func TestDemuxerNew(t *testing.T) { func TestDemuxerNextPacket(t *testing.T) { // Ctx error ctx, cancel := context.WithCancel(context.Background()) - dmx := New(ctx, bytes.NewReader([]byte{})) + dmx := NewDemuxer(ctx, bytes.NewReader([]byte{})) cancel() _, err := dmx.NextPacket() assert.Error(t, err) @@ -29,11 +29,11 @@ func TestDemuxerNextPacket(t *testing.T) { // Valid buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - b1, p1 := packet(*packetHeader, *packetAdaptationField, []byte("1")) + b1, p1 := packet(*packetHeader, *packetAdaptationField, []byte("1"), true) w.Write(b1) - b2, p2 := packet(*packetHeader, *packetAdaptationField, []byte("2")) + b2, p2 := packet(*packetHeader, *packetAdaptationField, []byte("2"), true) w.Write(b2) - dmx = New(context.Background(), bytes.NewReader(buf.Bytes())) + dmx = NewDemuxer(context.Background(), bytes.NewReader(buf.Bytes())) // First packet p, err := dmx.NextPacket() @@ -56,20 +56,20 @@ func TestDemuxerNextData(t *testing.T) { buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) b := psiBytes() - b1, _ := packet(PacketHeader{ContinuityCounter: uint8(0), PayloadUnitStartIndicator: true, PID: PIDPAT}, PacketAdaptationField{}, b[:147]) + b1, _ := packet(PacketHeader{ContinuityCounter: uint8(0), PayloadUnitStartIndicator: true, PID: PIDPAT}, PacketAdaptationField{}, b[:147], true) w.Write(b1) - b2, _ := packet(PacketHeader{ContinuityCounter: uint8(1), PID: PIDPAT}, PacketAdaptationField{}, b[147:]) + b2, _ := packet(PacketHeader{ContinuityCounter: uint8(1), PID: PIDPAT}, PacketAdaptationField{}, b[147:], true) w.Write(b2) - dmx := New(context.Background(), bytes.NewReader(buf.Bytes())) + dmx := NewDemuxer(context.Background(), bytes.NewReader(buf.Bytes())) p, err := dmx.NextPacket() assert.NoError(t, err) _, err = dmx.Rewind() assert.NoError(t, err) // Next data - var ds []*Data + var ds []*DemuxerData for _, s := range psi.Sections { - if s.Header.TableType != PSITableTypeUnknown { + if !s.Header.TableID.isUnknown() { d, err := dmx.NextData() assert.NoError(t, err) ds = append(ds, d) @@ -85,9 +85,9 @@ func TestDemuxerNextData(t *testing.T) { func TestDemuxerRewind(t *testing.T) { r := bytes.NewReader([]byte("content")) - dmx := New(context.Background(), r) + dmx := NewDemuxer(context.Background(), r) dmx.packetPool.add(&Packet{Header: &PacketHeader{PID: 1}}) - dmx.dataBuffer = append(dmx.dataBuffer, &Data{}) + dmx.dataBuffer = append(dmx.dataBuffer, &DemuxerData{}) b := make([]byte, 2) _, err := r.Read(b) assert.NoError(t, err) diff --git a/descriptor.go b/descriptor.go index 62d2d2c..6a0e993 100644 --- a/descriptor.go +++ b/descriptor.go @@ -693,6 +693,7 @@ func newDescriptorExtensionSupplementaryAudio(i *astikit.BytesIterator, offsetEn // DescriptorISO639LanguageAndAudioType represents an ISO639 language descriptor // https://github.com/gfto/bitstream/blob/master/mpeg/psi/desc_0a.h +// FIXME (barbashov) according to Chapter 2.6.18 ISO/IEC 13818-1:2015 there could be not one, but multiple such descriptors type DescriptorISO639LanguageAndAudioType struct { Language []byte Type uint8 @@ -1032,7 +1033,9 @@ func newDescriptorShortEvent(i *astikit.BytesIterator) (d *DescriptorShortEvent, // DescriptorStreamIdentifier represents a stream identifier descriptor // Chapter: 6.2.39 | Link: https://www.etsi.org/deliver/etsi_en/300400_300499/300468/01.15.01_60/en_300468v011501p.pdf -type DescriptorStreamIdentifier struct{ ComponentTag uint8 } +type DescriptorStreamIdentifier struct { + ComponentTag uint8 +} func newDescriptorStreamIdentifier(i *astikit.BytesIterator) (d *DescriptorStreamIdentifier, err error) { var b byte @@ -1437,3 +1440,723 @@ func parseDescriptors(i *astikit.BytesIterator) (o []*Descriptor, err error) { } return } + +func calcDescriptorUserDefinedLength(d []byte) uint8 { + return uint8(len(d)) +} + +func writeDescriptorUserDefined(w *astikit.BitsWriter, d []byte) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d) + + return b.Err() +} + +func calcDescriptorAC3Length(d *DescriptorAC3) uint8 { + ret := 1 // flags + + if d.HasComponentType { + ret++ + } + if d.HasBSID { + ret++ + } + if d.HasMainID { + ret++ + } + if d.HasASVC { + ret++ + } + + ret += len(d.AdditionalInfo) + + return uint8(ret) +} + +func writeDescriptorAC3(w *astikit.BitsWriter, d *DescriptorAC3) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.HasComponentType) + b.Write(d.HasBSID) + b.Write(d.HasMainID) + b.Write(d.HasASVC) + b.WriteN(uint8(0xff), 4) + + if d.HasComponentType { + b.Write(d.ComponentType) + } + if d.HasBSID { + b.Write(d.BSID) + } + if d.HasMainID { + b.Write(d.MainID) + } + if d.HasASVC { + b.Write(d.ASVC) + } + b.Write(d.AdditionalInfo) + + return b.Err() +} + +func calcDescriptorAVCVideoLength(d *DescriptorAVCVideo) uint8 { + return 4 +} + +func writeDescriptorAVCVideo(w *astikit.BitsWriter, d *DescriptorAVCVideo) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.ProfileIDC) + + b.Write(d.ConstraintSet0Flag) + b.Write(d.ConstraintSet1Flag) + b.Write(d.ConstraintSet2Flag) + b.WriteN(d.CompatibleFlags, 5) + + b.Write(d.LevelIDC) + + b.Write(d.AVCStillPresent) + b.Write(d.AVC24HourPictureFlag) + b.WriteN(uint8(0xff), 6) + + return b.Err() +} + +func calcDescriptorComponentLength(d *DescriptorComponent) uint8 { + return uint8(6 + len(d.Text)) +} + +func writeDescriptorComponent(w *astikit.BitsWriter, d *DescriptorComponent) error { + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(d.StreamContentExt, 4) + b.WriteN(d.StreamContent, 4) + + b.Write(d.ComponentType) + b.Write(d.ComponentTag) + + b.WriteBytesN(d.ISO639LanguageCode, 3, 0) + + b.Write(d.Text) + + return b.Err() +} + +func calcDescriptorContentLength(d *DescriptorContent) uint8 { + return uint8(2 * len(d.Items)) +} + +func writeDescriptorContent(w *astikit.BitsWriter, d *DescriptorContent) error { + b := astikit.NewBitsWriterBatch(w) + + for _, item := range d.Items { + b.WriteN(item.ContentNibbleLevel1, 4) + b.WriteN(item.ContentNibbleLevel2, 4) + b.Write(item.UserByte) + } + + return b.Err() +} + +func calcDescriptorDataStreamAlignmentLength(d *DescriptorDataStreamAlignment) uint8 { + return 1 +} + +func writeDescriptorDataStreamAlignment(w *astikit.BitsWriter, d *DescriptorDataStreamAlignment) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.Type) + + return b.Err() +} + +func calcDescriptorEnhancedAC3Length(d *DescriptorEnhancedAC3) uint8 { + ret := 1 // flags + + if d.HasComponentType { + ret++ + } + if d.HasBSID { + ret++ + } + if d.HasMainID { + ret++ + } + if d.HasASVC { + ret++ + } + if d.HasSubStream1 { + ret++ + } + if d.HasSubStream2 { + ret++ + } + if d.HasSubStream3 { + ret++ + } + + ret += len(d.AdditionalInfo) + + return uint8(ret) +} + +func writeDescriptorEnhancedAC3(w *astikit.BitsWriter, d *DescriptorEnhancedAC3) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.HasComponentType) + b.Write(d.HasBSID) + b.Write(d.HasMainID) + b.Write(d.HasASVC) + b.Write(d.MixInfoExists) + b.Write(d.HasSubStream1) + b.Write(d.HasSubStream2) + b.Write(d.HasSubStream3) + + if d.HasComponentType { + b.Write(d.ComponentType) + } + if d.HasBSID { + b.Write(d.BSID) + } + if d.HasMainID { + b.Write(d.MainID) + } + if d.HasASVC { + b.Write(d.ASVC) + } + if d.HasSubStream1 { + b.Write(d.SubStream1) + } + if d.HasSubStream2 { + b.Write(d.SubStream2) + } + if d.HasSubStream3 { + b.Write(d.SubStream3) + } + + b.Write(d.AdditionalInfo) + + return b.Err() +} + +func calcDescriptorExtendedEventLength(d *DescriptorExtendedEvent) (descriptorLength, lengthOfItems uint8) { + ret := 1 + 3 + 1 // numbers, language and items length + + itemsRet := 0 + for _, item := range d.Items { + itemsRet += 1 // description length + itemsRet += len(item.Description) + itemsRet += 1 // content length + itemsRet += len(item.Content) + } + + ret += itemsRet + + ret += 1 // text length + ret += len(d.Text) + + return uint8(ret), uint8(itemsRet) +} + +func writeDescriptorExtendedEvent(w *astikit.BitsWriter, d *DescriptorExtendedEvent) error { + b := astikit.NewBitsWriterBatch(w) + + var lengthOfItems uint8 + + _, lengthOfItems = calcDescriptorExtendedEventLength(d) + + b.WriteN(d.Number, 4) + b.WriteN(d.LastDescriptorNumber, 4) + + b.WriteBytesN(d.ISO639LanguageCode, 3, 0) + + b.Write(lengthOfItems) + for _, item := range d.Items { + b.Write(uint8(len(item.Description))) + b.Write(item.Description) + b.Write(uint8(len(item.Content))) + b.Write(item.Content) + } + + b.Write(uint8(len(d.Text))) + b.Write(d.Text) + + return b.Err() +} + +func calcDescriptorExtensionSupplementaryAudioLength(d *DescriptorExtensionSupplementaryAudio) int { + ret := 1 + if d.HasLanguageCode { + ret += 3 + } + ret += len(d.PrivateData) + return ret +} + +func calcDescriptorExtensionLength(d *DescriptorExtension) uint8 { + ret := 1 // tag + + switch d.Tag { + case DescriptorTagExtensionSupplementaryAudio: + ret += calcDescriptorExtensionSupplementaryAudioLength(d.SupplementaryAudio) + default: + if d.Unknown != nil { + ret += len(*d.Unknown) + } + } + + return uint8(ret) +} + +func writeDescriptorExtensionSupplementaryAudio(w *astikit.BitsWriter, d *DescriptorExtensionSupplementaryAudio) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.MixType) + b.WriteN(d.EditorialClassification, 5) + b.Write(true) // reserved + b.Write(d.HasLanguageCode) + + if d.HasLanguageCode { + b.WriteBytesN(d.LanguageCode, 3, 0) + } + + b.Write(d.PrivateData) + + return b.Err() +} + +func writeDescriptorExtension(w *astikit.BitsWriter, d *DescriptorExtension) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.Tag) + + switch d.Tag { + case DescriptorTagExtensionSupplementaryAudio: + err := writeDescriptorExtensionSupplementaryAudio(w, d.SupplementaryAudio) + if err != nil { + return err + } + default: + if d.Unknown != nil { + b.Write(*d.Unknown) + } + } + + return b.Err() +} + +func calcDescriptorISO639LanguageAndAudioTypeLength(d *DescriptorISO639LanguageAndAudioType) uint8 { + return 3 + 1 // language code + type +} + +func writeDescriptorISO639LanguageAndAudioType(w *astikit.BitsWriter, d *DescriptorISO639LanguageAndAudioType) error { + b := astikit.NewBitsWriterBatch(w) + + b.WriteBytesN(d.Language, 3, 0) + b.Write(d.Type) + + return b.Err() +} + +func calcDescriptorLocalTimeOffsetLength(d *DescriptorLocalTimeOffset) uint8 { + return uint8(13 * len(d.Items)) +} + +func writeDescriptorLocalTimeOffset(w *astikit.BitsWriter, d *DescriptorLocalTimeOffset) error { + b := astikit.NewBitsWriterBatch(w) + + for _, item := range d.Items { + b.WriteBytesN(item.CountryCode, 3, 0) + + b.WriteN(item.CountryRegionID, 6) + b.WriteN(uint8(0xff), 1) + b.Write(item.LocalTimeOffsetPolarity) + + if _, err := writeDVBDurationMinutes(w, item.LocalTimeOffset); err != nil { + return err + } + if _, err := writeDVBTime(w, item.TimeOfChange); err != nil { + return err + } + if _, err := writeDVBDurationMinutes(w, item.NextTimeOffset); err != nil { + return err + } + } + + return b.Err() +} + +func calcDescriptorMaximumBitrateLength(d *DescriptorMaximumBitrate) uint8 { + return 3 +} + +func writeDescriptorMaximumBitrate(w *astikit.BitsWriter, d *DescriptorMaximumBitrate) error { + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(uint8(0xff), 2) + b.WriteN(uint32(d.Bitrate/50), 22) + + return b.Err() +} + +func calcDescriptorNetworkNameLength(d *DescriptorNetworkName) uint8 { + return uint8(len(d.Name)) +} + +func writeDescriptorNetworkName(w *astikit.BitsWriter, d *DescriptorNetworkName) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.Name) + + return b.Err() +} + +func calcDescriptorParentalRatingLength(d *DescriptorParentalRating) uint8 { + return uint8(4 * len(d.Items)) +} + +func writeDescriptorParentalRating(w *astikit.BitsWriter, d *DescriptorParentalRating) error { + b := astikit.NewBitsWriterBatch(w) + + for _, item := range d.Items { + b.WriteBytesN(item.CountryCode, 3, 0) + b.Write(item.Rating) + } + + return b.Err() +} + +func calcDescriptorPrivateDataIndicatorLength(d *DescriptorPrivateDataIndicator) uint8 { + return 4 +} + +func writeDescriptorPrivateDataIndicator(w *astikit.BitsWriter, d *DescriptorPrivateDataIndicator) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.Indicator) + + return b.Err() +} + +func calcDescriptorPrivateDataSpecifierLength(d *DescriptorPrivateDataSpecifier) uint8 { + return 4 +} + +func writeDescriptorPrivateDataSpecifier(w *astikit.BitsWriter, d *DescriptorPrivateDataSpecifier) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.Specifier) + + return b.Err() +} + +func calcDescriptorRegistrationLength(d *DescriptorRegistration) uint8 { + return uint8(4 + len(d.AdditionalIdentificationInfo)) +} + +func writeDescriptorRegistration(w *astikit.BitsWriter, d *DescriptorRegistration) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.FormatIdentifier) + b.Write(d.AdditionalIdentificationInfo) + + return b.Err() +} + +func calcDescriptorServiceLength(d *DescriptorService) uint8 { + ret := 3 // type and lengths + ret += len(d.Name) + ret += len(d.Provider) + return uint8(ret) +} + +func writeDescriptorService(w *astikit.BitsWriter, d *DescriptorService) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.Type) + b.Write(uint8(len(d.Provider))) + b.Write(d.Provider) + b.Write(uint8(len(d.Name))) + b.Write(d.Name) + + return b.Err() +} + +func calcDescriptorShortEventLength(d *DescriptorShortEvent) uint8 { + ret := 3 + 1 + 1 // language code and lengths + ret += len(d.EventName) + ret += len(d.Text) + return uint8(ret) +} + +func writeDescriptorShortEvent(w *astikit.BitsWriter, d *DescriptorShortEvent) error { + b := astikit.NewBitsWriterBatch(w) + + b.WriteBytesN(d.Language, 3, 0) + + b.Write(uint8(len(d.EventName))) + b.Write(d.EventName) + + b.Write(uint8(len(d.Text))) + b.Write(d.Text) + + return b.Err() +} + +func calcDescriptorStreamIdentifierLength(d *DescriptorStreamIdentifier) uint8 { + return 1 +} + +func writeDescriptorStreamIdentifier(w *astikit.BitsWriter, d *DescriptorStreamIdentifier) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.ComponentTag) + + return b.Err() +} + +func calcDescriptorSubtitlingLength(d *DescriptorSubtitling) uint8 { + return uint8(8 * len(d.Items)) +} + +func writeDescriptorSubtitling(w *astikit.BitsWriter, d *DescriptorSubtitling) error { + b := astikit.NewBitsWriterBatch(w) + + for _, item := range d.Items { + b.WriteBytesN(item.Language, 3, 0) + b.Write(item.Type) + b.Write(item.CompositionPageID) + b.Write(item.AncillaryPageID) + } + + return b.Err() +} + +func calcDescriptorTeletextLength(d *DescriptorTeletext) uint8 { + return uint8(5 * len(d.Items)) +} + +func writeDescriptorTeletext(w *astikit.BitsWriter, d *DescriptorTeletext) error { + b := astikit.NewBitsWriterBatch(w) + + for _, item := range d.Items { + b.WriteBytesN(item.Language, 3, 0) + b.WriteN(item.Type, 5) + b.WriteN(item.Magazine, 3) + b.WriteN(item.Page/10, 4) + b.WriteN(item.Page%10, 4) + } + + return b.Err() +} + +func calcDescriptorVBIDataLength(d *DescriptorVBIData) uint8 { + return uint8(3 * len(d.Services)) +} + +func writeDescriptorVBIData(w *astikit.BitsWriter, d *DescriptorVBIData) error { + b := astikit.NewBitsWriterBatch(w) + + for _, item := range d.Services { + b.Write(item.DataServiceID) + + if item.DataServiceID == VBIDataServiceIDClosedCaptioning || + item.DataServiceID == VBIDataServiceIDEBUTeletext || + item.DataServiceID == VBIDataServiceIDInvertedTeletext || + item.DataServiceID == VBIDataServiceIDMonochrome442Samples || + item.DataServiceID == VBIDataServiceIDVPS || + item.DataServiceID == VBIDataServiceIDWSS { + + b.Write(uint8(len(item.Descriptors))) // each descriptor is 1 byte + for _, desc := range item.Descriptors { + b.WriteN(uint8(0xff), 2) + b.Write(desc.FieldParity) + b.WriteN(desc.LineOffset, 5) + } + } else { + // let's put one reserved byte + b.Write(uint8(1)) + b.Write(uint8(0xff)) + } + } + + return b.Err() +} + +func calcDescriptorUnknownLength(d *DescriptorUnknown) uint8 { + return uint8(len(d.Content)) +} + +func writeDescriptorUnknown(w *astikit.BitsWriter, d *DescriptorUnknown) error { + b := astikit.NewBitsWriterBatch(w) + + b.Write(d.Content) + + return b.Err() +} + +func calcDescriptorLength(d *Descriptor) uint8 { + if d.Tag >= 0x80 && d.Tag <= 0xfe { + return calcDescriptorUserDefinedLength(d.UserDefined) + } + + switch d.Tag { + case DescriptorTagAC3: + return calcDescriptorAC3Length(d.AC3) + case DescriptorTagAVCVideo: + return calcDescriptorAVCVideoLength(d.AVCVideo) + case DescriptorTagComponent: + return calcDescriptorComponentLength(d.Component) + case DescriptorTagContent: + return calcDescriptorContentLength(d.Content) + case DescriptorTagDataStreamAlignment: + return calcDescriptorDataStreamAlignmentLength(d.DataStreamAlignment) + case DescriptorTagEnhancedAC3: + return calcDescriptorEnhancedAC3Length(d.EnhancedAC3) + case DescriptorTagExtendedEvent: + ret, _ := calcDescriptorExtendedEventLength(d.ExtendedEvent) + return ret + case DescriptorTagExtension: + return calcDescriptorExtensionLength(d.Extension) + case DescriptorTagISO639LanguageAndAudioType: + return calcDescriptorISO639LanguageAndAudioTypeLength(d.ISO639LanguageAndAudioType) + case DescriptorTagLocalTimeOffset: + return calcDescriptorLocalTimeOffsetLength(d.LocalTimeOffset) + case DescriptorTagMaximumBitrate: + return calcDescriptorMaximumBitrateLength(d.MaximumBitrate) + case DescriptorTagNetworkName: + return calcDescriptorNetworkNameLength(d.NetworkName) + case DescriptorTagParentalRating: + return calcDescriptorParentalRatingLength(d.ParentalRating) + case DescriptorTagPrivateDataIndicator: + return calcDescriptorPrivateDataIndicatorLength(d.PrivateDataIndicator) + case DescriptorTagPrivateDataSpecifier: + return calcDescriptorPrivateDataSpecifierLength(d.PrivateDataSpecifier) + case DescriptorTagRegistration: + return calcDescriptorRegistrationLength(d.Registration) + case DescriptorTagService: + return calcDescriptorServiceLength(d.Service) + case DescriptorTagShortEvent: + return calcDescriptorShortEventLength(d.ShortEvent) + case DescriptorTagStreamIdentifier: + return calcDescriptorStreamIdentifierLength(d.StreamIdentifier) + case DescriptorTagSubtitling: + return calcDescriptorSubtitlingLength(d.Subtitling) + case DescriptorTagTeletext: + return calcDescriptorTeletextLength(d.Teletext) + case DescriptorTagVBIData: + return calcDescriptorVBIDataLength(d.VBIData) + case DescriptorTagVBITeletext: + return calcDescriptorTeletextLength(d.VBITeletext) + } + + return calcDescriptorUnknownLength(d.Unknown) +} + +func writeDescriptor(w *astikit.BitsWriter, d *Descriptor) (int, error) { + b := astikit.NewBitsWriterBatch(w) + length := calcDescriptorLength(d) + + b.Write(d.Tag) + b.Write(length) + + if err := b.Err(); err != nil { + return 0, err + } + + written := int(length) + 2 + + if d.Tag >= 0x80 && d.Tag <= 0xfe { + return written, writeDescriptorUserDefined(w, d.UserDefined) + } + + switch d.Tag { + case DescriptorTagAC3: + return written, writeDescriptorAC3(w, d.AC3) + case DescriptorTagAVCVideo: + return written, writeDescriptorAVCVideo(w, d.AVCVideo) + case DescriptorTagComponent: + return written, writeDescriptorComponent(w, d.Component) + case DescriptorTagContent: + return written, writeDescriptorContent(w, d.Content) + case DescriptorTagDataStreamAlignment: + return written, writeDescriptorDataStreamAlignment(w, d.DataStreamAlignment) + case DescriptorTagEnhancedAC3: + return written, writeDescriptorEnhancedAC3(w, d.EnhancedAC3) + case DescriptorTagExtendedEvent: + return written, writeDescriptorExtendedEvent(w, d.ExtendedEvent) + case DescriptorTagExtension: + return written, writeDescriptorExtension(w, d.Extension) + case DescriptorTagISO639LanguageAndAudioType: + return written, writeDescriptorISO639LanguageAndAudioType(w, d.ISO639LanguageAndAudioType) + case DescriptorTagLocalTimeOffset: + return written, writeDescriptorLocalTimeOffset(w, d.LocalTimeOffset) + case DescriptorTagMaximumBitrate: + return written, writeDescriptorMaximumBitrate(w, d.MaximumBitrate) + case DescriptorTagNetworkName: + return written, writeDescriptorNetworkName(w, d.NetworkName) + case DescriptorTagParentalRating: + return written, writeDescriptorParentalRating(w, d.ParentalRating) + case DescriptorTagPrivateDataIndicator: + return written, writeDescriptorPrivateDataIndicator(w, d.PrivateDataIndicator) + case DescriptorTagPrivateDataSpecifier: + return written, writeDescriptorPrivateDataSpecifier(w, d.PrivateDataSpecifier) + case DescriptorTagRegistration: + return written, writeDescriptorRegistration(w, d.Registration) + case DescriptorTagService: + return written, writeDescriptorService(w, d.Service) + case DescriptorTagShortEvent: + return written, writeDescriptorShortEvent(w, d.ShortEvent) + case DescriptorTagStreamIdentifier: + return written, writeDescriptorStreamIdentifier(w, d.StreamIdentifier) + case DescriptorTagSubtitling: + return written, writeDescriptorSubtitling(w, d.Subtitling) + case DescriptorTagTeletext: + return written, writeDescriptorTeletext(w, d.Teletext) + case DescriptorTagVBIData: + return written, writeDescriptorVBIData(w, d.VBIData) + case DescriptorTagVBITeletext: + return written, writeDescriptorTeletext(w, d.VBITeletext) + } + + return written, writeDescriptorUnknown(w, d.Unknown) +} + +func calcDescriptorsLength(ds []*Descriptor) uint16 { + length := uint16(0) + for _, d := range ds { + length += 2 // tag and length + length += uint16(calcDescriptorLength(d)) + } + return length +} + +func writeDescriptors(w *astikit.BitsWriter, ds []*Descriptor) (int, error) { + written := 0 + + for _, d := range ds { + n, err := writeDescriptor(w, d) + if err != nil { + return 0, err + } + written += n + } + + return written, nil +} + +func writeDescriptorsWithLength(w *astikit.BitsWriter, ds []*Descriptor) (int, error) { + length := calcDescriptorsLength(ds) + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(uint8(0xff), 4) // reserved + b.WriteN(length, 12) // program_info_length + + if err := b.Err(); err != nil { + return 0, err + } + + written, err := writeDescriptors(w, ds) + return written + 2, err // 2 for length +} diff --git a/descriptor_test.go b/descriptor_test.go index 700bab7..27a3a62 100644 --- a/descriptor_test.go +++ b/descriptor_test.go @@ -2,10 +2,9 @@ package astits import ( "bytes" - "testing" - "github.com/asticode/go-astikit" "github.com/stretchr/testify/assert" + "testing" ) var descriptors = []*Descriptor{{ @@ -21,368 +20,645 @@ func descriptorsBytes(w *astikit.BitsWriter) { w.Write(uint8(7)) // Component tag } -func TestParseDescriptor(t *testing.T) { - // Init - buf := &bytes.Buffer{} - w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write(uint16(255)) // Descriptors length - // AC3 - w.Write(uint8(DescriptorTagAC3)) // Tag - w.Write(uint8(9)) // Length - w.Write("1") // Component type flag - w.Write("1") // BSID flag - w.Write("1") // MainID flag - w.Write("1") // ASVC flag - w.Write("0000") // Reserved flags - w.Write(uint8(1)) // Component type - w.Write(uint8(2)) // BSID - w.Write(uint8(3)) // MainID - w.Write(uint8(4)) // ASVC - w.Write([]byte("info")) // Additional info - // ISO639 language and audio type - w.Write(uint8(DescriptorTagISO639LanguageAndAudioType)) // Tag - w.Write(uint8(4)) // Length - w.Write([]byte("eng")) // Language - w.Write(uint8(AudioTypeCleanEffects)) // Audio type - // Maximum bitrate - w.Write(uint8(DescriptorTagMaximumBitrate)) // Tag - w.Write(uint8(3)) // Length - w.Write("000000000000000000000001") // Maximum bitrate - // Network name - w.Write(uint8(DescriptorTagNetworkName)) // Tag - w.Write(uint8(4)) // Length - w.Write([]byte("name")) // Name - // Service - w.Write(uint8(DescriptorTagService)) // Tag - w.Write(uint8(18)) // Length - w.Write(uint8(ServiceTypeDigitalTelevisionService)) // Type - w.Write(uint8(8)) // Provider name length - w.Write([]byte("provider")) // Provider name - w.Write(uint8(7)) // Service name length - w.Write([]byte("service")) // Service name - // Short event - w.Write(uint8(DescriptorTagShortEvent)) // Tag - w.Write(uint8(14)) // Length - w.Write([]byte("eng")) // Language code - w.Write(uint8(5)) // Event name length - w.Write([]byte("event")) // Event name - w.Write(uint8(4)) // Text length - w.Write([]byte("text")) - // Stream identifier - w.Write(uint8(DescriptorTagStreamIdentifier)) // Tag - w.Write(uint8(1)) // Length - w.Write(uint8(2)) // Component tag - // Subtitling - w.Write(uint8(DescriptorTagSubtitling)) // Tag - w.Write(uint8(16)) // Length - w.Write([]byte("lg1")) // Item #1 language - w.Write(uint8(1)) // Item #1 type - w.Write(uint16(2)) // Item #1 composition page - w.Write(uint16(3)) // Item #1 ancillary page - w.Write([]byte("lg2")) // Item #2 language - w.Write(uint8(4)) // Item #2 type - w.Write(uint16(5)) // Item #2 composition page - w.Write(uint16(6)) // Item #2 ancillary page - // Teletext - w.Write(uint8(DescriptorTagTeletext)) // Tag - w.Write(uint8(10)) // Length - w.Write([]byte("lg1")) // Item #1 language - w.Write("00001") // Item #1 type - w.Write("010") // Item #1 magazine - w.Write("00010010") // Item #1 page number - w.Write([]byte("lg2")) // Item #2 language - w.Write("00011") // Item #2 type - w.Write("100") // Item #2 magazine - w.Write("00100011") // Item #2 page number - // Extended event - w.Write(uint8(DescriptorTagExtendedEvent)) // Tag - w.Write(uint8(30)) // Length - w.Write("0001") // Number - w.Write("0010") // Last descriptor number - w.Write([]byte("lan")) // ISO 639 language code - w.Write(uint8(20)) // Length of items - w.Write(uint8(11)) // Item #1 description length - w.Write([]byte("description")) // Item #1 description - w.Write(uint8(7)) // Item #1 content length - w.Write([]byte("content")) // Item #1 content - w.Write(uint8(4)) // Text length - w.Write([]byte("text")) // Text - // Enhanced AC3 - w.Write(uint8(DescriptorTagEnhancedAC3)) // Tag - w.Write(uint8(12)) // Length - w.Write("1") // Component type flag - w.Write("1") // BSID flag - w.Write("1") // MainID flag - w.Write("1") // ASVC flag - w.Write("1") // Mix info exists - w.Write("1") // SubStream1 flag - w.Write("1") // SubStream2 flag - w.Write("1") // SubStream3 flag - w.Write(uint8(1)) // Component type - w.Write(uint8(2)) // BSID - w.Write(uint8(3)) // MainID - w.Write(uint8(4)) // ASVC - w.Write(uint8(5)) // SubStream1 - w.Write(uint8(6)) // SubStream2 - w.Write(uint8(7)) // SubStream3 - w.Write([]byte("info")) // Additional info - // Extension supplementary audio - w.Write(uint8(DescriptorTagExtension)) // Tag - w.Write(uint8(12)) // Length - w.Write(uint8(DescriptorTagExtensionSupplementaryAudio)) // Extension tag - w.Write("1") // Mix type - w.Write("10101") // Editorial classification - w.Write("1") // Reserved - w.Write("1") // Language code flag - w.Write([]byte("lan")) // Language code - w.Write([]byte("private")) // Private data - // Component - w.Write(uint8(DescriptorTagComponent)) // Tag - w.Write(uint8(10)) // Length - w.Write("1010") // Stream content ext - w.Write("0101") // Stream content - w.Write(uint8(1)) // Component type - w.Write(uint8(2)) // Component tag - w.Write([]byte("lan")) // ISO639 language code - w.Write([]byte("text")) // Text - // Content - w.Write(uint8(DescriptorTagContent)) // Tag - w.Write(uint8(2)) // Length - w.Write("0001") // Item #1 content nibble level 1 - w.Write("0010") // Item #1 content nibble level 2 - w.Write(uint8(3)) // Item #1 user byte - // Parental rating - w.Write(uint8(DescriptorTagParentalRating)) // Tag - w.Write(uint8(4)) // Length - w.Write([]byte("cou")) // Item #1 country code - w.Write(uint8(2)) // Item #1 rating - // Local time offset - w.Write(uint8(DescriptorTagLocalTimeOffset)) // Tag - w.Write(uint8(13)) // Length - w.Write([]byte("cou")) // Country code - w.Write("101010") // Country region ID - w.Write("1") // Reserved - w.Write("1") // Local time offset polarity - w.Write(dvbDurationMinutesBytes) // Local time offset - w.Write(dvbTimeBytes) // Time of change - w.Write(dvbDurationMinutesBytes) // Next time offset - // VBI data - w.Write(uint8(DescriptorTagVBIData)) // Tag - w.Write(uint8(3)) // Length - w.Write(uint8(VBIDataServiceIDEBUTeletext)) // Service #1 id - w.Write(uint8(1)) // Service #1 descriptor length - w.Write("00") // Service #1 descriptor reserved - w.Write("1") // Service #1 descriptor field polarity - w.Write("10101") // Service #1 descriptor line offset - // VBI Teletext - w.Write(uint8(DescriptorTagVBITeletext)) // Tag - w.Write(uint8(5)) // Length - w.Write([]byte("lan")) // Item #1 language - w.Write("00001") // Item #1 type - w.Write("010") // Item #1 magazine - w.Write("00010010") // Item #1 page number - // AVC video - w.Write(uint8(DescriptorTagAVCVideo)) // Tag - w.Write(uint8(4)) // Length - w.Write(uint8(1)) // Profile idc - w.Write("1") // Constraint set0 flag - w.Write("1") // Constraint set1 flag - w.Write("1") // Constraint set1 flag - w.Write("10101") // Compatible flags - w.Write(uint8(2)) // Level idc - w.Write("1") // AVC still present - w.Write("1") // AVC 24 hour picture flag - w.Write("000000") // Reserved - // Private data specifier - w.Write(uint8(DescriptorTagPrivateDataSpecifier)) // Tag - w.Write(uint8(4)) // Length - w.Write(uint32(128)) // Private data specifier - // Data stream alignment - w.Write(uint8(DescriptorTagDataStreamAlignment)) // Tag - w.Write(uint8(1)) // Length - w.Write(uint8(2)) // Type - // Private data indicator - w.Write(uint8(DescriptorTagPrivateDataIndicator)) // Tag - w.Write(uint8(4)) // Length - w.Write(uint32(127)) // Private data indicator - // User defined - w.Write(uint8(0x80)) // Tag - w.Write(uint8(4)) // Length - w.Write([]byte("test")) // User defined - // Registration - w.Write(uint8(0x5)) // Tag - w.Write(uint8(8)) // Length - w.Write(uint32(1)) // Format identifier - w.Write([]byte("test")) // Additional identification info - // Unknown - w.Write(uint8(0x1)) // Tag - w.Write(uint8(4)) // Length - w.Write([]byte("test")) // Content - // Extension unknown - w.Write(uint8(DescriptorTagExtension)) // Tag - w.Write(uint8(5)) // Length - w.Write(uint8(0)) // Extension tag - w.Write([]byte("test")) // Content +type descriptorTest struct { + name string + bytesFunc func(w *astikit.BitsWriter) + desc Descriptor +} - // Assert - ds, err := parseDescriptors(astikit.NewBytesIterator(buf.Bytes())) - assert.NoError(t, err) - assert.Equal(t, *ds[0].AC3, DescriptorAC3{ - AdditionalInfo: []byte("info"), - ASVC: uint8(4), - BSID: uint8(2), - ComponentType: uint8(1), - HasASVC: true, - HasBSID: true, - HasComponentType: true, - HasMainID: true, - MainID: uint8(3), - }) - assert.Equal(t, *ds[1].ISO639LanguageAndAudioType, DescriptorISO639LanguageAndAudioType{ - Language: []byte("eng"), - Type: AudioTypeCleanEffects, - }) - assert.Equal(t, *ds[2].MaximumBitrate, DescriptorMaximumBitrate{Bitrate: uint32(50)}) - assert.Equal(t, *ds[3].NetworkName, DescriptorNetworkName{Name: []byte("name")}) - assert.Equal(t, *ds[4].Service, DescriptorService{ - Name: []byte("service"), - Provider: []byte("provider"), - Type: ServiceTypeDigitalTelevisionService, - }) - assert.Equal(t, *ds[5].ShortEvent, DescriptorShortEvent{ - EventName: []byte("event"), - Language: []byte("eng"), - Text: []byte("text"), - }) - assert.Equal(t, *ds[6].StreamIdentifier, DescriptorStreamIdentifier{ComponentTag: 0x2}) - assert.Equal(t, *ds[7].Subtitling, DescriptorSubtitling{Items: []*DescriptorSubtitlingItem{ - { - AncillaryPageID: 3, - CompositionPageID: 2, - Language: []byte("lg1"), - Type: 1, +var descriptorTestTable = []descriptorTest{ + { + "AC3", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagAC3)) // Tag + w.Write(uint8(9)) // Length + w.Write("1") // Component type flag + w.Write("1") // BSID flag + w.Write("1") // MainID flag + w.Write("1") // ASVC flag + w.Write("1111") // Reserved flags + w.Write(uint8(1)) // Component type + w.Write(uint8(2)) // BSID + w.Write(uint8(3)) // MainID + w.Write(uint8(4)) // ASVC + w.Write([]byte("info")) // Additional info + }, + Descriptor{ + Tag: DescriptorTagAC3, + Length: 9, + AC3: &DescriptorAC3{ + AdditionalInfo: []byte("info"), + ASVC: uint8(4), + BSID: uint8(2), + ComponentType: uint8(1), + HasASVC: true, + HasBSID: true, + HasComponentType: true, + HasMainID: true, + MainID: uint8(3), + }}, + }, + { + "ISO639LanguageAndAudioType", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagISO639LanguageAndAudioType)) // Tag + w.Write(uint8(4)) // Length + w.Write([]byte("eng")) // Language + w.Write(uint8(AudioTypeCleanEffects)) // Audio type + }, + Descriptor{ + Tag: DescriptorTagISO639LanguageAndAudioType, + Length: 4, + ISO639LanguageAndAudioType: &DescriptorISO639LanguageAndAudioType{ + Language: []byte("eng"), + Type: AudioTypeCleanEffects, + }}, + }, + { + "MaximumBitrate", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagMaximumBitrate)) // Tag + w.Write(uint8(3)) // Length + w.Write("110000000000000000000001") // Maximum bitrate + }, + Descriptor{ + Tag: DescriptorTagMaximumBitrate, + Length: 3, + MaximumBitrate: &DescriptorMaximumBitrate{Bitrate: uint32(50)}}, + }, + { + "NetworkName", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagNetworkName)) // Tag + w.Write(uint8(4)) // Length + w.Write([]byte("name")) // Name + }, + Descriptor{ + Tag: DescriptorTagNetworkName, + Length: 4, + NetworkName: &DescriptorNetworkName{Name: []byte("name")}}, + }, + { + "Service", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagService)) // Tag + w.Write(uint8(18)) // Length + w.Write(uint8(ServiceTypeDigitalTelevisionService)) // Type + w.Write(uint8(8)) // Provider name length + w.Write([]byte("provider")) // Provider name + w.Write(uint8(7)) // Service name length + w.Write([]byte("service")) // Service name + }, + Descriptor{ + Tag: DescriptorTagService, + Length: 18, + Service: &DescriptorService{ + Name: []byte("service"), + Provider: []byte("provider"), + Type: ServiceTypeDigitalTelevisionService, + }}, + }, + { + "ShortEvent", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagShortEvent)) // Tag + w.Write(uint8(14)) // Length + w.Write([]byte("eng")) // Language code + w.Write(uint8(5)) // Event name length + w.Write([]byte("event")) // Event name + w.Write(uint8(4)) // Text length + w.Write([]byte("text")) + }, + Descriptor{ + Tag: DescriptorTagShortEvent, + Length: 14, + ShortEvent: &DescriptorShortEvent{ + EventName: []byte("event"), + Language: []byte("eng"), + Text: []byte("text"), + }}, + }, + { + "StreamIdentifier", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagStreamIdentifier)) // Tag + w.Write(uint8(1)) // Length + w.Write(uint8(2)) // Component tag + }, + Descriptor{ + Tag: DescriptorTagStreamIdentifier, + Length: 1, + StreamIdentifier: &DescriptorStreamIdentifier{ComponentTag: 0x2}}, + }, + { + "Subtitling", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagSubtitling)) // Tag + w.Write(uint8(16)) // Length + w.Write([]byte("lg1")) // Item #1 language + w.Write(uint8(1)) // Item #1 type + w.Write(uint16(2)) // Item #1 composition page + w.Write(uint16(3)) // Item #1 ancillary page + w.Write([]byte("lg2")) // Item #2 language + w.Write(uint8(4)) // Item #2 type + w.Write(uint16(5)) // Item #2 composition page + w.Write(uint16(6)) // Item #2 ancillary page + }, + Descriptor{ + Tag: DescriptorTagSubtitling, + Length: 16, + Subtitling: &DescriptorSubtitling{Items: []*DescriptorSubtitlingItem{ + { + AncillaryPageID: 3, + CompositionPageID: 2, + Language: []byte("lg1"), + Type: 1, + }, + { + AncillaryPageID: 6, + CompositionPageID: 5, + Language: []byte("lg2"), + Type: 4, + }, + }}}, + }, + { + "Teletext", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagTeletext)) // Tag + w.Write(uint8(10)) // Length + w.Write([]byte("lg1")) // Item #1 language + w.Write("00001") // Item #1 type + w.Write("010") // Item #1 magazine + w.Write("00010010") // Item #1 page number + w.Write([]byte("lg2")) // Item #2 language + w.Write("00011") // Item #2 type + w.Write("100") // Item #2 magazine + w.Write("00100011") // Item #2 page number + }, + Descriptor{ + Tag: DescriptorTagTeletext, + Length: 10, + Teletext: &DescriptorTeletext{Items: []*DescriptorTeletextItem{ + { + Language: []byte("lg1"), + Magazine: uint8(2), + Page: uint8(12), + Type: uint8(1), + }, + { + Language: []byte("lg2"), + Magazine: uint8(4), + Page: uint8(23), + Type: uint8(3), + }, + }}}, + }, + { + "ExtendedEvent", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagExtendedEvent)) // Tag + w.Write(uint8(30)) // Length + w.Write("0001") // Number + w.Write("0010") // Last descriptor number + w.Write([]byte("lan")) // ISO 639 language code + w.Write(uint8(20)) // Length of items + w.Write(uint8(11)) // Item #1 description length + w.Write([]byte("description")) // Item #1 description + w.Write(uint8(7)) // Item #1 content length + w.Write([]byte("content")) // Item #1 content + w.Write(uint8(4)) // Text length + w.Write([]byte("text")) // Text + }, + Descriptor{ + Tag: DescriptorTagExtendedEvent, + Length: 30, + ExtendedEvent: &DescriptorExtendedEvent{ + ISO639LanguageCode: []byte("lan"), + Items: []*DescriptorExtendedEventItem{{ + Content: []byte("content"), + Description: []byte("description"), + }}, + LastDescriptorNumber: 0x2, + Number: 0x1, + Text: []byte("text"), + }}, + }, + { + "EnhancedAC3", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagEnhancedAC3)) // Tag + w.Write(uint8(12)) // Length + w.Write("1") // Component type flag + w.Write("1") // BSID flag + w.Write("1") // MainID flag + w.Write("1") // ASVC flag + w.Write("1") // Mix info exists + w.Write("1") // SubStream1 flag + w.Write("1") // SubStream2 flag + w.Write("1") // SubStream3 flag + w.Write(uint8(1)) // Component type + w.Write(uint8(2)) // BSID + w.Write(uint8(3)) // MainID + w.Write(uint8(4)) // ASVC + w.Write(uint8(5)) // SubStream1 + w.Write(uint8(6)) // SubStream2 + w.Write(uint8(7)) // SubStream3 + w.Write([]byte("info")) // Additional info + }, + Descriptor{ + Tag: DescriptorTagEnhancedAC3, + Length: 12, + EnhancedAC3: &DescriptorEnhancedAC3{ + AdditionalInfo: []byte("info"), + ASVC: uint8(4), + BSID: uint8(2), + ComponentType: uint8(1), + HasASVC: true, + HasBSID: true, + HasComponentType: true, + HasMainID: true, + HasSubStream1: true, + HasSubStream2: true, + HasSubStream3: true, + MainID: uint8(3), + MixInfoExists: true, + SubStream1: 5, + SubStream2: 6, + SubStream3: 7, + }}, + }, + { + "Extension", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagExtension)) // Tag + w.Write(uint8(12)) // Length + w.Write(uint8(DescriptorTagExtensionSupplementaryAudio)) // Extension tag + w.Write("1") // Mix type + w.Write("10101") // Editorial classification + w.Write("1") // Reserved + w.Write("1") // Language code flag + w.Write([]byte("lan")) // Language code + w.Write([]byte("private")) // Private data + }, + Descriptor{ + Tag: DescriptorTagExtension, + Length: 12, + Extension: &DescriptorExtension{ + SupplementaryAudio: &DescriptorExtensionSupplementaryAudio{ + EditorialClassification: 21, + HasLanguageCode: true, + LanguageCode: []byte("lan"), + MixType: true, + PrivateData: []byte("private"), + }, + Tag: DescriptorTagExtensionSupplementaryAudio, + Unknown: nil, + }}, + }, + { + "Component", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagComponent)) // Tag + w.Write(uint8(10)) // Length + w.Write("1010") // Stream content ext + w.Write("0101") // Stream content + w.Write(uint8(1)) // Component type + w.Write(uint8(2)) // Component tag + w.Write([]byte("lan")) // ISO639 language code + w.Write([]byte("text")) // Text + }, + Descriptor{ + Tag: DescriptorTagComponent, + Length: 10, + Component: &DescriptorComponent{ + ComponentTag: 2, + ComponentType: 1, + ISO639LanguageCode: []byte("lan"), + StreamContentExt: 10, + StreamContent: 5, + Text: []byte("text"), + }}, + }, + { + "Content", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagContent)) // Tag + w.Write(uint8(2)) // Length + w.Write("0001") // Item #1 content nibble level 1 + w.Write("0010") // Item #1 content nibble level 2 + w.Write(uint8(3)) // Item #1 user byte + }, + Descriptor{ + Tag: DescriptorTagContent, + Length: 2, + Content: &DescriptorContent{Items: []*DescriptorContentItem{{ + ContentNibbleLevel1: 1, + ContentNibbleLevel2: 2, + UserByte: 3, + }}}}, + }, + { + "ParentalRating", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagParentalRating)) // Tag + w.Write(uint8(4)) // Length + w.Write([]byte("cou")) // Item #1 country code + w.Write(uint8(2)) // Item #1 rating + }, + Descriptor{ + Tag: DescriptorTagParentalRating, + Length: 4, + ParentalRating: &DescriptorParentalRating{Items: []*DescriptorParentalRatingItem{{ + CountryCode: []byte("cou"), + Rating: 2, + }}}}, + }, + { + "LocalTimeOffset", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagLocalTimeOffset)) // Tag + w.Write(uint8(13)) // Length + w.Write([]byte("cou")) // Country code + w.Write("101010") // Country region ID + w.Write("1") // Reserved + w.Write("1") // Local time offset polarity + w.Write(dvbDurationMinutesBytes) // Local time offset + w.Write(dvbTimeBytes) // Time of change + w.Write(dvbDurationMinutesBytes) // Next time offset + }, + Descriptor{ + Tag: DescriptorTagLocalTimeOffset, + Length: 13, + LocalTimeOffset: &DescriptorLocalTimeOffset{Items: []*DescriptorLocalTimeOffsetItem{{ + CountryCode: []byte("cou"), + CountryRegionID: 42, + LocalTimeOffset: dvbDurationMinutes, + LocalTimeOffsetPolarity: true, + NextTimeOffset: dvbDurationMinutes, + TimeOfChange: dvbTime, + }}}}, + }, + { + "VBIData", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagVBIData)) // Tag + w.Write(uint8(3)) // Length + w.Write(uint8(VBIDataServiceIDEBUTeletext)) // Service #1 id + w.Write(uint8(1)) // Service #1 descriptor length + w.Write("11") // Service #1 descriptor reserved + w.Write("1") // Service #1 descriptor field polarity + w.Write("10101") // Service #1 descriptor line offset + }, + Descriptor{ + Tag: DescriptorTagVBIData, + Length: 3, + VBIData: &DescriptorVBIData{Services: []*DescriptorVBIDataService{{ + DataServiceID: VBIDataServiceIDEBUTeletext, + Descriptors: []*DescriptorVBIDataDescriptor{{ + FieldParity: true, + LineOffset: 21, + }}, + }}}}, + }, + { + "VBITeletext", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagVBITeletext)) // Tag + w.Write(uint8(5)) // Length + w.Write([]byte("lan")) // Item #1 language + w.Write("00001") // Item #1 type + w.Write("010") // Item #1 magazine + w.Write("00010010") // Item #1 page number + }, + Descriptor{ + Tag: DescriptorTagVBITeletext, + Length: 5, + VBITeletext: &DescriptorTeletext{Items: []*DescriptorTeletextItem{{ + Language: []byte("lan"), + Magazine: uint8(2), + Page: uint8(12), + Type: uint8(1), + }}}}, + }, + { + "AVCVideo", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagAVCVideo)) // Tag + w.Write(uint8(4)) // Length + w.Write(uint8(1)) // Profile idc + w.Write("1") // Constraint set0 flag + w.Write("1") // Constraint set1 flag + w.Write("1") // Constraint set1 flag + w.Write("10101") // Compatible flags + w.Write(uint8(2)) // Level idc + w.Write("1") // AVC still present + w.Write("1") // AVC 24 hour picture flag + w.Write("111111") // Reserved }, - { - AncillaryPageID: 6, - CompositionPageID: 5, - Language: []byte("lg2"), - Type: 4, + Descriptor{ + Tag: DescriptorTagAVCVideo, + Length: 4, + AVCVideo: &DescriptorAVCVideo{ + AVC24HourPictureFlag: true, + AVCStillPresent: true, + CompatibleFlags: 21, + ConstraintSet0Flag: true, + ConstraintSet1Flag: true, + ConstraintSet2Flag: true, + LevelIDC: 2, + ProfileIDC: 1, + }}, + }, + { + "PrivateDataSpecifier", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagPrivateDataSpecifier)) // Tag + w.Write(uint8(4)) // Length + w.Write(uint32(128)) // Private data specifier }, - }}) - assert.Equal(t, *ds[8].Teletext, DescriptorTeletext{Items: []*DescriptorTeletextItem{ - { - Language: []byte("lg1"), - Magazine: uint8(2), - Page: uint8(12), - Type: uint8(1), + Descriptor{ + Tag: DescriptorTagPrivateDataSpecifier, + Length: 4, + PrivateDataSpecifier: &DescriptorPrivateDataSpecifier{ + Specifier: 128, + }}, + }, + { + "DataStreamAlignment", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagDataStreamAlignment)) // Tag + w.Write(uint8(1)) // Length + w.Write(uint8(2)) // Type }, - { - Language: []byte("lg2"), - Magazine: uint8(4), - Page: uint8(23), - Type: uint8(3), + Descriptor{ + Tag: DescriptorTagDataStreamAlignment, + Length: 1, + DataStreamAlignment: &DescriptorDataStreamAlignment{ + Type: 2, + }}, + }, + { + "PrivateDataIndicator", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagPrivateDataIndicator)) // Tag + w.Write(uint8(4)) // Length + w.Write(uint32(127)) // Private data indicator }, - }}) - assert.Equal(t, *ds[9].ExtendedEvent, DescriptorExtendedEvent{ - ISO639LanguageCode: []byte("lan"), - Items: []*DescriptorExtendedEventItem{{ - Content: []byte("content"), - Description: []byte("description"), - }}, - LastDescriptorNumber: 0x2, - Number: 0x1, - Text: []byte("text"), - }) - assert.Equal(t, *ds[10].EnhancedAC3, DescriptorEnhancedAC3{ - AdditionalInfo: []byte("info"), - ASVC: uint8(4), - BSID: uint8(2), - ComponentType: uint8(1), - HasASVC: true, - HasBSID: true, - HasComponentType: true, - HasMainID: true, - HasSubStream1: true, - HasSubStream2: true, - HasSubStream3: true, - MainID: uint8(3), - MixInfoExists: true, - SubStream1: 5, - SubStream2: 6, - SubStream3: 7, - }) - assert.Equal(t, *ds[11].Extension.SupplementaryAudio, DescriptorExtensionSupplementaryAudio{ - EditorialClassification: 21, - HasLanguageCode: true, - LanguageCode: []byte("lan"), - MixType: true, - PrivateData: []byte("private"), - }) - assert.Equal(t, *ds[12].Component, DescriptorComponent{ - ComponentTag: 2, - ComponentType: 1, - ISO639LanguageCode: []byte("lan"), - StreamContentExt: 10, - StreamContent: 5, - Text: []byte("text"), - }) - assert.Equal(t, *ds[13].Content, DescriptorContent{Items: []*DescriptorContentItem{{ - ContentNibbleLevel1: 1, - ContentNibbleLevel2: 2, - UserByte: 3, - }}}) - assert.Equal(t, *ds[14].ParentalRating, DescriptorParentalRating{Items: []*DescriptorParentalRatingItem{{ - CountryCode: []byte("cou"), - Rating: 2, - }}}) - assert.Equal(t, *ds[15].LocalTimeOffset, DescriptorLocalTimeOffset{Items: []*DescriptorLocalTimeOffsetItem{{ - CountryCode: []byte("cou"), - CountryRegionID: 42, - LocalTimeOffset: dvbDurationMinutes, - LocalTimeOffsetPolarity: true, - NextTimeOffset: dvbDurationMinutes, - TimeOfChange: dvbTime, - }}}) - assert.Equal(t, *ds[16].VBIData, DescriptorVBIData{Services: []*DescriptorVBIDataService{{ - DataServiceID: VBIDataServiceIDEBUTeletext, - Descriptors: []*DescriptorVBIDataDescriptor{{ - FieldParity: true, - LineOffset: 21, - }}, - }}}) - assert.Equal(t, *ds[17].VBITeletext, DescriptorTeletext{Items: []*DescriptorTeletextItem{{ - Language: []byte("lan"), - Magazine: uint8(2), - Page: uint8(12), - Type: uint8(1), - }}}) - assert.Equal(t, *ds[18].AVCVideo, DescriptorAVCVideo{ - AVC24HourPictureFlag: true, - AVCStillPresent: true, - CompatibleFlags: 21, - ConstraintSet0Flag: true, - ConstraintSet1Flag: true, - ConstraintSet2Flag: true, - LevelIDC: 2, - ProfileIDC: 1, - }) - assert.Equal(t, *ds[19].PrivateDataSpecifier, DescriptorPrivateDataSpecifier{ - Specifier: 128, - }) - assert.Equal(t, *ds[20].DataStreamAlignment, DescriptorDataStreamAlignment{ - Type: 2, - }) - assert.Equal(t, *ds[21].PrivateDataIndicator, DescriptorPrivateDataIndicator{ - Indicator: 127, - }) - assert.Equal(t, ds[22].UserDefined, []byte("test")) - assert.Equal(t, *ds[23].Registration, DescriptorRegistration{ - AdditionalIdentificationInfo: []byte("test"), - FormatIdentifier: uint32(1), - }) - assert.Equal(t, *ds[24].Unknown, DescriptorUnknown{ - Content: []byte("test"), - Tag: 0x1, - }) - assert.Equal(t, *ds[25].Extension.Unknown, []byte("test")) + Descriptor{ + Tag: DescriptorTagPrivateDataIndicator, + Length: 4, + PrivateDataIndicator: &DescriptorPrivateDataIndicator{ + Indicator: 127, + }}, + }, + { + "UserDefined", + func(w *astikit.BitsWriter) { + w.Write(uint8(0x80)) // Tag + w.Write(uint8(4)) // Length + w.Write([]byte("test")) // User defined + }, + Descriptor{ + Tag: 0x80, + Length: 4, + UserDefined: []byte("test")}, + }, + { + "Registration", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagRegistration)) // Tag + w.Write(uint8(8)) // Length + w.Write(uint32(1)) // Format identifier + w.Write([]byte("test")) // Additional identification info + }, + Descriptor{ + Tag: DescriptorTagRegistration, + Length: 8, + Registration: &DescriptorRegistration{ + AdditionalIdentificationInfo: []byte("test"), + FormatIdentifier: uint32(1), + }}, + }, + { + "Unknown", + func(w *astikit.BitsWriter) { + w.Write(uint8(0x1)) // Tag + w.Write(uint8(4)) // Length + w.Write([]byte("test")) // Content + }, + Descriptor{ + Tag: 0x1, + Length: 4, + Unknown: &DescriptorUnknown{ + Content: []byte("test"), + Tag: 0x1, + }}, + }, + { + "Extension", + func(w *astikit.BitsWriter) { + w.Write(uint8(DescriptorTagExtension)) // Tag + w.Write(uint8(5)) // Length + w.Write(uint8(0)) // Extension tag + w.Write([]byte("test")) // Content + }, + Descriptor{ + Tag: DescriptorTagExtension, + Length: 5, + Extension: &DescriptorExtension{ + Tag: 0, + Unknown: &[]byte{'t', 'e', 's', 't'}, + }}, + }, +} + +func TestParseDescriptorOneByOne(t *testing.T) { + for _, tc := range descriptorTestTable { + t.Run(tc.name, func(t *testing.T) { + // idea is following: + // 1. get descriptor bytes and update its length + // 2. parse bytes and get a Descriptor instance + // 3. compare expected descriptor value and actual + buf := bytes.Buffer{} + buf.Write([]byte{0x00, 0x00}) // reserve two bytes for length + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + tc.bytesFunc(w) + descLen := uint16(buf.Len() - 2) + descBytes := buf.Bytes() + descBytes[0] = byte(descLen >> 8) + descBytes[1] = byte(descLen & 0xff) + + ds, err := parseDescriptors(astikit.NewBytesIterator(descBytes)) + assert.NoError(t, err) + assert.Equal(t, tc.desc, *ds[0]) + }) + } +} + +func TestParseDescriptorAll(t *testing.T) { + buf := bytes.Buffer{} + buf.Write([]byte{0x00, 0x00}) // reserve two bytes for length + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + + for _, tc := range descriptorTestTable { + tc.bytesFunc(w) + } + + descLen := uint16(buf.Len() - 2) + descBytes := buf.Bytes() + descBytes[0] = byte(descLen >> 8) + descBytes[1] = byte(descLen & 0xff) + + ds, err := parseDescriptors(astikit.NewBytesIterator(descBytes)) + assert.NoError(t, err) + + for i, tc := range descriptorTestTable { + assert.Equal(t, tc.desc, *ds[i]) + } +} + +func TestWriteDescriptorOneByOne(t *testing.T) { + for _, tc := range descriptorTestTable { + t.Run(tc.name, func(t *testing.T) { + bufExpected := bytes.Buffer{} + wExpected := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufExpected}) + tc.bytesFunc(wExpected) + + bufActual := bytes.Buffer{} + wActual := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufActual}) + n, err := writeDescriptor(wActual, &tc.desc) + assert.NoError(t, err) + assert.Equal(t, n, bufActual.Len()) + assert.Equal(t, bufExpected.Bytes(), bufActual.Bytes()) + }) + } +} + +func TestWriteDescriptorAll(t *testing.T) { + bufExpected := bytes.Buffer{} + bufExpected.Write([]byte{0x00, 0x00}) // reserve two bytes for length + wExpected := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufExpected}) + + dss := []*Descriptor{} + + for _, tc := range descriptorTestTable { + tc.bytesFunc(wExpected) + tcc := tc + dss = append(dss, &tcc.desc) + } + + descLen := uint16(bufExpected.Len() - 2) + descBytes := bufExpected.Bytes() + descBytes[0] = byte(descLen>>8) | 0b11110000 // program_info_length is preceded by 4 reserved bits + descBytes[1] = byte(descLen & 0xff) + + bufActual := bytes.Buffer{} + wActual := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &bufActual}) + + n, err := writeDescriptorsWithLength(wActual, dss) + assert.NoError(t, err) + assert.Equal(t, n, bufActual.Len()) + assert.Equal(t, bufExpected.Len(), bufActual.Len()) + assert.Equal(t, bufExpected.Bytes(), bufActual.Bytes()) } diff --git a/dvb.go b/dvb.go index ec8e9c9..1cbaa97 100644 --- a/dvb.go +++ b/dvb.go @@ -13,6 +13,7 @@ import ( // field are set to "1". // I apologize for the computation which is really messy but details are given in the documentation // Page: 160 | Annex C | Link: https://www.dvb.org/resources/public/standards/a38_dvb-si_specification.pdf +// (barbashov) the link above can be broken, alternative: https://dvb.org/wp-content/uploads/2019/12/a038_tm1217r37_en300468v1_17_1_-_rev-134_-_si_specification.pdf func parseDVBTime(i *astikit.BytesIterator) (t time.Time, err error) { // Get next 2 bytes var bs []byte @@ -72,3 +73,58 @@ func parseDVBDurationSeconds(i *astikit.BytesIterator) (d time.Duration, err err func parseDVBDurationByte(i byte) time.Duration { return time.Duration(uint8(i)>>4*10 + uint8(i)&0xf) } + +func writeDVBTime(w *astikit.BitsWriter, t time.Time) (int, error) { + year := t.Year() - 1900 + month := t.Month() + day := t.Day() + + l := 0 + if month <= time.February { + l = 1 + } + + mjd := 14956 + day + int(float64(year-l)*365.25) + int(float64(int(month)+1+l*12)*30.6001) + + d := t.Sub(t.Truncate(24 * time.Hour)) + + b := astikit.NewBitsWriterBatch(w) + + b.Write(uint16(mjd)) + bytesWritten, err := writeDVBDurationSeconds(w, d) + if err != nil { + return 2, err + } + + return bytesWritten + 2, b.Err() +} + +func writeDVBDurationMinutes(w *astikit.BitsWriter, d time.Duration) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + hours := uint8(d.Hours()) + minutes := uint8(int(d.Minutes()) % 60) + + b.Write(dvbDurationByteRepresentation(hours)) + b.Write(dvbDurationByteRepresentation(minutes)) + + return 2, b.Err() +} + +func writeDVBDurationSeconds(w *astikit.BitsWriter, d time.Duration) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + hours := uint8(d.Hours()) + minutes := uint8(int(d.Minutes()) % 60) + seconds := uint8(int(d.Seconds()) % 60) + + b.Write(dvbDurationByteRepresentation(hours)) + b.Write(dvbDurationByteRepresentation(minutes)) + b.Write(dvbDurationByteRepresentation(seconds)) + + return 3, b.Err() +} + +func dvbDurationByteRepresentation(n uint8) uint8 { + return (n/10)<<4 | n%10 +} diff --git a/dvb_test.go b/dvb_test.go index 8e71ef2..da09c49 100644 --- a/dvb_test.go +++ b/dvb_test.go @@ -1,6 +1,7 @@ package astits import ( + "bytes" "testing" "time" @@ -34,3 +35,30 @@ func TestParseDVBDurationSeconds(t *testing.T) { assert.Equal(t, dvbDurationSeconds, d) assert.NoError(t, err) } + +func TestWriteDVBTime(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + n, err := writeDVBTime(w, dvbTime) + assert.NoError(t, err) + assert.Equal(t, n, buf.Len()) + assert.Equal(t, dvbTimeBytes, buf.Bytes()) +} + +func TestWriteDVBDurationMinutes(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + n, err := writeDVBDurationMinutes(w, dvbDurationMinutes) + assert.NoError(t, err) + assert.Equal(t, n, buf.Len()) + assert.Equal(t, dvbDurationMinutesBytes, buf.Bytes()) +} + +func TestWriteDVBDurationSeconds(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + n, err := writeDVBDurationSeconds(w, dvbDurationSeconds) + assert.NoError(t, err) + assert.Equal(t, n, buf.Len()) + assert.Equal(t, dvbDurationSecondsBytes, buf.Bytes()) +} diff --git a/go.mod b/go.mod index 0ba7d11..77f4f3b 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/asticode/go-astits go 1.13 require ( - github.com/asticode/go-astikit v0.2.0 + github.com/asticode/go-astikit v0.19.0 github.com/pkg/profile v1.4.0 github.com/stretchr/testify v1.4.0 ) diff --git a/go.sum b/go.sum index 6e232b2..440b1af 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/asticode/go-astikit v0.2.0 h1:QonRVJKQB2btMYZGW+YkibMDOXje2F49RLW4UCnyjns= -github.com/asticode/go-astikit v0.2.0/go.mod h1:h4ly7idim1tNhaVkdVBeXQZEE3L0xblP7fCWbgwipF0= +github.com/asticode/go-astikit v0.19.0 h1:NEeyjodbwGTZN7Pn8IYXFw1Occl59okjg6wYfDce7uM= +github.com/asticode/go-astikit v0.19.0/go.mod h1:h4ly7idim1tNhaVkdVBeXQZEE3L0xblP7fCWbgwipF0= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pkg/profile v1.4.0 h1:uCmaf4vVbWAOZz36k1hrQD7ijGRzLwaME8Am/7a4jZI= diff --git a/muxer.go b/muxer.go new file mode 100644 index 0000000..0312169 --- /dev/null +++ b/muxer.go @@ -0,0 +1,422 @@ +package astits + +import ( + "bytes" + "context" + "errors" + "github.com/asticode/go-astikit" + "io" +) + +const ( + startPID uint16 = 0x0100 + pmtStartPID uint16 = 0x1000 + programNumberStart uint16 = 1 +) + +var ( + ErrPIDNotFound = errors.New("astits: PID not found") + ErrPIDAlreadyExists = errors.New("astits: PID already exists") + ErrPCRPIDInvalid = errors.New("astits: PCR PID invalid") +) + +type Muxer struct { + ctx context.Context + w io.Writer + bitsWriter *astikit.BitsWriter + + packetSize int + tablesRetransmitPeriod int // period in PES packets + + pm programMap // pid -> programNumber + pmt PMTData + nextPID uint16 + patVersion wrappingCounter + pmtVersion wrappingCounter + + patBytes bytes.Buffer + pmtBytes bytes.Buffer + + buf bytes.Buffer + bufWriter *astikit.BitsWriter + + esContexts map[uint16]*esContext + tablesRetransmitCounter int +} + +type esContext struct { + es *PMTElementaryStream + cc wrappingCounter +} + +func newEsContext(es *PMTElementaryStream) *esContext { + return &esContext{ + es: es, + cc: newWrappingCounter(0b1111), // CC is 4 bits + } +} + +func MuxerOptTablesRetransmitPeriod(newPeriod int) func(*Muxer) { + return func(m *Muxer) { + m.tablesRetransmitPeriod = newPeriod + } +} + +// TODO MuxerOptAutodetectPCRPID selecting first video PID for each PMT, falling back to first audio, falling back to any other + +func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer { + m := &Muxer{ + ctx: ctx, + w: w, + + packetSize: MpegTsPacketSize, // no 192-byte packet support yet + tablesRetransmitPeriod: 40, + + pm: newProgramMap(), + pmt: PMTData{ + ElementaryStreams: []*PMTElementaryStream{}, + ProgramNumber: programNumberStart, + }, + + // table version is 5-bit field + patVersion: newWrappingCounter(0b11111), + pmtVersion: newWrappingCounter(0b11111), + + esContexts: map[uint16]*esContext{}, + } + + m.bufWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf}) + m.bitsWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: m.w}) + + // TODO multiple programs support + m.pm.set(pmtStartPID, programNumberStart) + + for _, opt := range opts { + opt(m) + } + + // to output tables at the very start + m.tablesRetransmitCounter = m.tablesRetransmitPeriod + + return m +} + +// if es.ElementaryPID is zero, it will be generated automatically +func (m *Muxer) AddElementaryStream(es PMTElementaryStream) error { + if es.ElementaryPID != 0 { + for _, oes := range m.pmt.ElementaryStreams { + if oes.ElementaryPID == es.ElementaryPID { + return ErrPIDAlreadyExists + } + } + } else { + es.ElementaryPID = m.nextPID + m.nextPID++ + } + + m.pmt.ElementaryStreams = append(m.pmt.ElementaryStreams, &es) + + m.esContexts[es.ElementaryPID] = newEsContext(&es) + // invalidate pmt cache + m.pmtBytes.Reset() + return nil +} + +func (m *Muxer) RemoveElementaryStream(pid uint16) error { + foundIdx := -1 + for i, oes := range m.pmt.ElementaryStreams { + if oes.ElementaryPID == pid { + foundIdx = i + break + } + } + + if foundIdx == -1 { + return ErrPIDNotFound + } + + m.pmt.ElementaryStreams = append(m.pmt.ElementaryStreams[:foundIdx], m.pmt.ElementaryStreams[foundIdx+1:]...) + delete(m.esContexts, pid) + m.pmtBytes.Reset() + return nil +} + +// SetPCRPID marks pid as one to look PCRs in +func (m *Muxer) SetPCRPID(pid uint16) { + m.pmt.PCRPID = pid +} + +// WriteData writes MuxerData to TS stream +// Currently only PES packets are supported +// Be aware that after successful call WriteData will set d.AdaptationField.StuffingLength value to zero +func (m *Muxer) WriteData(d *MuxerData) (int, error) { + ctx, ok := m.esContexts[d.PID] + if !ok { + return 0, ErrPIDNotFound + } + + bytesWritten := 0 + + forceTables := d.AdaptationField != nil && + d.AdaptationField.RandomAccessIndicator && + d.PID == m.pmt.PCRPID + + n, err := m.retransmitTables(forceTables) + if err != nil { + return n, err + } + + bytesWritten += n + + payloadStart := true + writeAf := d.AdaptationField != nil + payloadBytesWritten := 0 + for payloadBytesWritten < len(d.PES.Data) { + pktLen := 1 + mpegTsPacketHeaderSize // sync byte + header + pkt := Packet{ + Header: &PacketHeader{ + ContinuityCounter: uint8(ctx.cc.get()), + HasAdaptationField: writeAf, + HasPayload: false, + PayloadUnitStartIndicator: false, + PID: d.PID, + }, + } + + if writeAf { + pkt.AdaptationField = d.AdaptationField + // one byte for adaptation field length field + pktLen += 1 + int(calcPacketAdaptationFieldLength(d.AdaptationField)) + writeAf = false + } + + bytesAvailable := m.packetSize - pktLen + if payloadStart { + pesHeaderLengthCurrent := pesHeaderLength + int(calcPESOptionalHeaderLength(d.PES.Header.OptionalHeader)) + // d.AdaptationField with pes header are too big, we don't have space to write pes header + if bytesAvailable < pesHeaderLengthCurrent { + pkt.Header.HasAdaptationField = true + if pkt.AdaptationField == nil { + pkt.AdaptationField = newStuffingAdaptationField(bytesAvailable) + } else { + pkt.AdaptationField.StuffingLength = bytesAvailable + } + } else { + pkt.Header.HasPayload = true + pkt.Header.PayloadUnitStartIndicator = true + } + } else { + pkt.Header.HasPayload = true + } + + if pkt.Header.HasPayload { + m.buf.Reset() + if d.PES.Header.StreamID == 0 { + d.PES.Header.StreamID = ctx.es.StreamType.ToPESStreamID() + } + + ntot, npayload, err := writePESData( + m.bufWriter, + d.PES.Header, + d.PES.Data[payloadBytesWritten:], + payloadStart, + bytesAvailable, + ) + if err != nil { + return bytesWritten, err + } + + payloadBytesWritten += npayload + + pkt.Payload = m.buf.Bytes() + + bytesAvailable -= ntot + // if we still have some space in packet, we should stuff it with adaptation field stuffing + // we can't stuff packets with 0xff at the end of a packet since it's not uncommon for PES payloads to have length unspecified + if bytesAvailable > 0 { + pkt.Header.HasAdaptationField = true + if pkt.AdaptationField == nil { + pkt.AdaptationField = newStuffingAdaptationField(bytesAvailable) + } else { + pkt.AdaptationField.StuffingLength = bytesAvailable + } + } + + n, err = writePacket(m.bitsWriter, &pkt, m.packetSize) + if err != nil { + return bytesWritten, err + } + + bytesWritten += n + + payloadStart = false + } + } + + if d.AdaptationField != nil { + d.AdaptationField.StuffingLength = 0 + } + + return bytesWritten, nil +} + +// Writes given packet to MPEG-TS stream +// Stuffs with 0xffs if packet turns out to be shorter than target packet length +func (m *Muxer) WritePacket(p *Packet) (int, error) { + return writePacket(m.bitsWriter, p, m.packetSize) +} + +func (m *Muxer) retransmitTables(force bool) (int, error) { + m.tablesRetransmitCounter++ + if !force && m.tablesRetransmitCounter < m.tablesRetransmitPeriod { + return 0, nil + } + + n, err := m.WriteTables() + if err != nil { + return n, err + } + + m.tablesRetransmitCounter = 0 + return n, nil +} + +func (m *Muxer) WriteTables() (int, error) { + bytesWritten := 0 + + if m.patBytes.Len() != m.packetSize { + if err := m.generatePAT(); err != nil { + return bytesWritten, err + } + } + + if m.pmtBytes.Len() != m.packetSize { + if err := m.generatePMT(); err != nil { + return bytesWritten, err + } + } + + n, err := m.w.Write(m.patBytes.Bytes()) + if err != nil { + return bytesWritten, err + } + bytesWritten += n + + n, err = m.w.Write(m.pmtBytes.Bytes()) + if err != nil { + return bytesWritten, err + } + bytesWritten += n + + return bytesWritten, nil +} + +func (m *Muxer) generatePAT() error { + d := m.pm.toPATData() + syntax := &PSISectionSyntax{ + Data: &PSISectionSyntaxData{PAT: d}, + Header: &PSISectionSyntaxHeader{ + CurrentNextIndicator: true, + // TODO support for PAT tables longer than 1 TS packet + //LastSectionNumber: 0, + //SectionNumber: 0, + TableIDExtension: d.TransportStreamID, + VersionNumber: uint8(m.patVersion.get()), + }, + } + section := PSISection{ + Header: &PSISectionHeader{ + SectionLength: calcPATSectionLength(d), + SectionSyntaxIndicator: true, + TableID: PSITableID(d.TransportStreamID), + }, + Syntax: syntax, + } + psiData := PSIData{ + Sections: []*PSISection{§ion}, + } + + m.buf.Reset() + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf}) + if _, err := writePSIData(w, &psiData); err != nil { + return err + } + + m.patBytes.Reset() + wPacket := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.patBytes}) + + pkt := Packet{ + Header: &PacketHeader{ + HasPayload: true, + PayloadUnitStartIndicator: true, + PID: PIDPAT, + }, + Payload: m.buf.Bytes(), + } + if _, err := writePacket(wPacket, &pkt, m.packetSize); err != nil { + // FIXME save old PAT and rollback to it here maybe? + return err + } + + return nil +} + +func (m *Muxer) generatePMT() error { + hasPCRPID := false + for _, es := range m.pmt.ElementaryStreams { + if es.ElementaryPID == m.pmt.PCRPID { + hasPCRPID = true + break + } + } + if !hasPCRPID { + return ErrPCRPIDInvalid + } + + syntax := &PSISectionSyntax{ + Data: &PSISectionSyntaxData{PMT: &m.pmt}, + Header: &PSISectionSyntaxHeader{ + CurrentNextIndicator: true, + // TODO support for PMT tables longer than 1 TS packet + //LastSectionNumber: 0, + //SectionNumber: 0, + TableIDExtension: m.pmt.ProgramNumber, + VersionNumber: uint8(m.pmtVersion.get()), + }, + } + section := PSISection{ + Header: &PSISectionHeader{ + SectionLength: calcPMTSectionLength(&m.pmt), + SectionSyntaxIndicator: true, + TableID: PSITableIDPMT, + }, + Syntax: syntax, + } + psiData := PSIData{ + Sections: []*PSISection{§ion}, + } + + m.buf.Reset() + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf}) + if _, err := writePSIData(w, &psiData); err != nil { + return err + } + + m.pmtBytes.Reset() + wPacket := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.pmtBytes}) + + pkt := Packet{ + Header: &PacketHeader{ + HasPayload: true, + PayloadUnitStartIndicator: true, + PID: pmtStartPID, // FIXME multiple programs support + }, + Payload: m.buf.Bytes(), + } + if _, err := writePacket(wPacket, &pkt, m.packetSize); err != nil { + // FIXME save old PMT and rollback to it here maybe? + return err + } + + return nil +} diff --git a/muxer_test.go b/muxer_test.go new file mode 100644 index 0000000..3af2df1 --- /dev/null +++ b/muxer_test.go @@ -0,0 +1,321 @@ +package astits + +import ( + "bytes" + "context" + "github.com/asticode/go-astikit" + "github.com/stretchr/testify/assert" + "testing" +) + +func patExpectedBytes(versionNumber uint8) []byte { + buf := bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + w.Write(uint8(syncByte)) + w.Write("010") // no transport error, payload start, no priority + w.WriteN(PIDPAT, 13) + w.Write("0001") // no scrambling, no AF, payload present + w.Write("0000") // CC + + w.Write(uint16(0)) // Table ID + w.Write("1011") // Syntax section indicator, private bit, reserved + w.WriteN(uint16(13), 12) // Section length + + w.Write(uint16(PSITableIDPAT)) + w.Write("11") // Reserved bits + w.WriteN(versionNumber, 5) // Version number + w.Write("1") // Current/next indicator + w.Write(uint8(0)) // Section number + w.Write(uint8(0)) // Last section number + + w.Write(programNumberStart) + w.Write("111") // reserved + w.WriteN(pmtStartPID, 13) + + // CRC32 + if versionNumber == 0 { + w.Write([]byte{0x71, 0x10, 0xd8, 0x78}) + } else { + w.Write([]byte{0xef, 0xbe, 0x08, 0x5a}) + } + + w.Write(bytes.Repeat([]byte{0xff}, 167)) + + return buf.Bytes() +} + +func TestMuxer_generatePAT(t *testing.T) { + muxer := NewMuxer(context.Background(), nil) + + err := muxer.generatePAT() + assert.NoError(t, err) + assert.Equal(t, MpegTsPacketSize, muxer.patBytes.Len()) + assert.Equal(t, patExpectedBytes(0), muxer.patBytes.Bytes()) + + // to check version number increment + err = muxer.generatePAT() + assert.NoError(t, err) + assert.Equal(t, MpegTsPacketSize, muxer.patBytes.Len()) + assert.Equal(t, patExpectedBytes(1), muxer.patBytes.Bytes()) +} + +func pmtExpectedBytesVideoOnly(versionNumber uint8) []byte { + buf := bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + w.Write(uint8(syncByte)) + w.Write("010") // no transport error, payload start, no priority + w.WriteN(pmtStartPID, 13) + w.Write("0001") // no scrambling, no AF, payload present + w.Write("0000") // CC + + w.Write(uint16(PSITableIDPMT)) // Table ID + w.Write("1011") // Syntax section indicator, private bit, reserved + w.WriteN(uint16(18), 12) // Section length + + w.Write(programNumberStart) + w.Write("11") // Reserved bits + w.WriteN(versionNumber, 5) // Version number + w.Write("1") // Current/next indicator + w.Write(uint8(0)) // Section number + w.Write(uint8(0)) // Last section number + + w.Write("111") // reserved + w.WriteN(uint16(0x1234), 13) // PCR PID + + w.Write("1111") // reserved + w.WriteN(uint16(0), 12) // program info length + + w.Write(uint8(StreamTypeH264Video)) + w.Write("111") // reserved + w.WriteN(uint16(0x1234), 13) + + w.Write("1111") // reserved + w.WriteN(uint16(0), 12) // es info length + + w.Write([]byte{0x31, 0x48, 0x5b, 0xa2}) // CRC32 + + w.Write(bytes.Repeat([]byte{0xff}, 162)) + + return buf.Bytes() +} + +func pmtExpectedBytesVideoAndAudio(versionNumber uint8) []byte { + buf := bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf}) + w.Write(uint8(syncByte)) + w.Write("010") // no transport error, payload start, no priority + w.WriteN(pmtStartPID, 13) + w.Write("0001") // no scrambling, no AF, payload present + w.Write("0000") // CC + + w.Write(uint16(PSITableIDPMT)) // Table ID + w.Write("1011") // Syntax section indicator, private bit, reserved + w.WriteN(uint16(23), 12) // Section length + + w.Write(programNumberStart) + w.Write("11") // Reserved bits + w.WriteN(versionNumber, 5) // Version number + w.Write("1") // Current/next indicator + w.Write(uint8(0)) // Section number + w.Write(uint8(0)) // Last section number + + w.Write("111") // reserved + w.WriteN(uint16(0x1234), 13) // PCR PID + + w.Write("1111") // reserved + w.WriteN(uint16(0), 12) // program info length + + w.Write(uint8(StreamTypeH264Video)) + w.Write("111") // reserved + w.WriteN(uint16(0x1234), 13) + w.Write("1111") // reserved + w.WriteN(uint16(0), 12) // es info length + + w.Write(uint8(StreamTypeADTS)) + w.Write("111") // reserved + w.WriteN(uint16(0x0234), 13) + w.Write("1111") // reserved + w.WriteN(uint16(0), 12) // es info length + + // CRC32 + if versionNumber == 0 { + w.Write([]byte{0x29, 0x52, 0xc4, 0x50}) + } else { + w.Write([]byte{0x06, 0xf4, 0xa6, 0xea}) + } + + w.Write(bytes.Repeat([]byte{0xff}, 157)) + + return buf.Bytes() +} + +func TestMuxer_generatePMT(t *testing.T) { + muxer := NewMuxer(context.Background(), nil) + err := muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x1234, + StreamType: StreamTypeH264Video, + }) + muxer.SetPCRPID(0x1234) + assert.NoError(t, err) + + err = muxer.generatePMT() + assert.NoError(t, err) + assert.Equal(t, MpegTsPacketSize, muxer.pmtBytes.Len()) + assert.Equal(t, pmtExpectedBytesVideoOnly(0), muxer.pmtBytes.Bytes()) + + err = muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x0234, + StreamType: StreamTypeAACAudio, + }) + assert.NoError(t, err) + + err = muxer.generatePMT() + assert.NoError(t, err) + assert.Equal(t, MpegTsPacketSize, muxer.pmtBytes.Len()) + assert.Equal(t, pmtExpectedBytesVideoAndAudio(1), muxer.pmtBytes.Bytes()) +} + +func TestMuxer_WriteTables(t *testing.T) { + buf := bytes.Buffer{} + muxer := NewMuxer(context.Background(), &buf) + err := muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x1234, + StreamType: StreamTypeH264Video, + }) + muxer.SetPCRPID(0x1234) + assert.NoError(t, err) + + n, err := muxer.WriteTables() + assert.NoError(t, err) + assert.Equal(t, 2*MpegTsPacketSize, n) + assert.Equal(t, n, buf.Len()) + + expectedBytes := append(patExpectedBytes(0), pmtExpectedBytesVideoOnly(0)...) + assert.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestMuxer_WriteTables_Error(t *testing.T) { + muxer := NewMuxer(context.Background(), nil) + err := muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x1234, + StreamType: StreamTypeH264Video, + }) + assert.NoError(t, err) + + _, err = muxer.WriteTables() + assert.Equal(t, ErrPCRPIDInvalid, err) +} + +func TestMuxer_AddElementaryStream(t *testing.T) { + muxer := NewMuxer(context.Background(), nil) + err := muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x1234, + StreamType: StreamTypeH264Video, + }) + assert.NoError(t, err) + + err = muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x1234, + StreamType: StreamTypeH264Video, + }) + assert.Equal(t, ErrPIDAlreadyExists, err) +} + +func TestMuxer_RemoveElementaryStream(t *testing.T) { + muxer := NewMuxer(context.Background(), nil) + err := muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x1234, + StreamType: StreamTypeH264Video, + }) + assert.NoError(t, err) + + err = muxer.RemoveElementaryStream(0x1234) + assert.NoError(t, err) + + err = muxer.RemoveElementaryStream(0x1234) + assert.Equal(t, ErrPIDNotFound, err) +} + +func testPayload() []byte { + ret := make([]byte, 0xff+1) + for i := 0; i <= 0xff; i++ { + ret[i] = byte(i) + } + return ret +} + +func TestMuxer_WritePayload(t *testing.T) { + buf := bytes.Buffer{} + muxer := NewMuxer(context.Background(), &buf) + + err := muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x1234, + StreamType: StreamTypeH264Video, + }) + muxer.SetPCRPID(0x1234) + assert.NoError(t, err) + + err = muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: 0x0234, + StreamType: StreamTypeAACAudio, + }) + assert.NoError(t, err) + + payload := testPayload() + pcr := ClockReference{ + Base: 5726623061, + Extension: 341, + } + pts := ClockReference{Base: 5726623060} + + n, err := muxer.WriteData(&MuxerData{ + PID: 0x1234, + AdaptationField: &PacketAdaptationField{ + HasPCR: true, + PCR: &pcr, + RandomAccessIndicator: true, + }, + PES: &PESData{ + Data: payload, + Header: &PESHeader{ + OptionalHeader: &PESOptionalHeader{ + DTS: &pts, + PTS: &pts, + PTSDTSIndicator: PTSDTSIndicatorBothPresent, + }, + }, + }, + }) + + assert.NoError(t, err) + assert.Equal(t, buf.Len(), n) + + bytesTotal := n + + n, err = muxer.WriteData(&MuxerData{ + PID: 0x0234, + AdaptationField: &PacketAdaptationField{ + HasPCR: true, + PCR: &pcr, + RandomAccessIndicator: true, + }, + PES: &PESData{ + Data: payload, + Header: &PESHeader{ + OptionalHeader: &PESOptionalHeader{ + DTS: &pts, + PTS: &pts, + PTSDTSIndicator: PTSDTSIndicatorBothPresent, + }, + }, + }, + }) + + assert.NoError(t, err) + assert.Equal(t, buf.Len(), bytesTotal+n) + assert.Equal(t, 0, buf.Len()%MpegTsPacketSize) + + bs := buf.Bytes() + assert.Equal(t, patExpectedBytes(0), bs[:MpegTsPacketSize]) + assert.Equal(t, pmtExpectedBytesVideoAndAudio(0), bs[MpegTsPacketSize:MpegTsPacketSize*2]) +} diff --git a/packet.go b/packet.go index 375980e..1e699d1 100644 --- a/packet.go +++ b/packet.go @@ -2,7 +2,6 @@ package astits import ( "fmt" - "github.com/asticode/go-astikit" ) @@ -14,6 +13,12 @@ const ( ScramblingControlScrambledWithOddKey = 3 ) +const ( + MpegTsPacketSize = 188 + mpegTsPacketHeaderSize = 3 + pcrBytesSize = 6 +) + // Packet represents a packet // https://en.wikipedia.org/wiki/MPEG_transport_stream type Packet struct { @@ -45,6 +50,8 @@ type PacketAdaptationField struct { HasTransportPrivateData bool HasSplicingCountdown bool Length int + IsOneByteStuffing bool // Only used for one byte stuffing - if true, adaptation field will be written as one uint8(0). Not part of TS format + StuffingLength int // Only used in writePacketAdaptationField to request stuffing OPCR *ClockReference // Original Program clock reference. Helps when one TS is copied into another PCR *ClockReference // Program clock reference RandomAccessIndicator bool // Set when the stream may be decoded without errors from this point @@ -85,7 +92,7 @@ func parsePacket(i *astikit.BytesIterator) (p *Packet, err error) { p = &Packet{} // In case packet size is bigger than 188 bytes, we don't care for the first bytes - i.Seek(i.Len() - 188 + 1) + i.Seek(i.Len() - MpegTsPacketSize + 1) offsetStart := i.Offset() // Parse header @@ -157,6 +164,8 @@ func parsePacketAdaptationField(i *astikit.BytesIterator) (a *PacketAdaptationFi // Length a.Length = int(b) + afStartOffset := i.Offset() + // Valid length if a.Length > 0 { // Get next byte @@ -287,6 +296,9 @@ func parsePacketAdaptationField(i *astikit.BytesIterator) (a *PacketAdaptationFi } } } + + a.StuffingLength = a.Length - (i.Offset() - afStartOffset) + return } @@ -302,3 +314,230 @@ func parsePCR(i *astikit.BytesIterator) (cr *ClockReference, err error) { cr = newClockReference(int64(pcr>>15), int64(pcr&0x1ff)) return } + +func writePacket(w *astikit.BitsWriter, p *Packet, targetPacketSize int) (written int, retErr error) { + if retErr = w.Write(uint8(syncByte)); retErr != nil { + return + } + written += 1 + + n, retErr := writePacketHeader(w, p.Header) + if retErr != nil { + return + } + written += n + + if p.Header.HasAdaptationField { + n, retErr = writePacketAdaptationField(w, p.AdaptationField) + if retErr != nil { + return + } + written += n + } + + if targetPacketSize-written < len(p.Payload) { + return 0, fmt.Errorf( + "writePacket: can't write %d bytes of payload: only %d is available", + len(p.Payload), + targetPacketSize-written, + ) + } + + if p.Header.HasPayload { + retErr = w.Write(p.Payload) + if retErr != nil { + return + } + written += len(p.Payload) + } + + for written < targetPacketSize { + if retErr = w.Write(uint8(0xff)); retErr != nil { + return + } + written++ + } + + return written, nil +} + +func writePacketHeader(w *astikit.BitsWriter, h *PacketHeader) (written int, retErr error) { + b := astikit.NewBitsWriterBatch(w) + + b.Write(h.TransportErrorIndicator) + b.Write(h.PayloadUnitStartIndicator) + b.Write(h.TransportPriority) + b.WriteN(h.PID, 13) + b.WriteN(h.TransportScramblingControl, 2) + b.Write(h.HasAdaptationField) // adaptation_field_control higher bit + b.Write(h.HasPayload) // adaptation_field_control lower bit + b.WriteN(h.ContinuityCounter, 4) + + return mpegTsPacketHeaderSize, b.Err() +} + +func writePCR(w *astikit.BitsWriter, cr *ClockReference) (int, error) { + b := astikit.NewBitsWriterBatch(w) + + b.WriteN(uint64(cr.Base), 33) + b.WriteN(uint8(0xff), 6) + b.WriteN(uint64(cr.Extension), 9) + return pcrBytesSize, b.Err() +} + +func calcPacketAdaptationFieldLength(af *PacketAdaptationField) (length uint8) { + length++ + if af.HasPCR { + length += pcrBytesSize + } + if af.HasOPCR { + length += pcrBytesSize + } + if af.HasSplicingCountdown { + length++ + } + if af.HasTransportPrivateData { + length += 1 + uint8(len(af.TransportPrivateData)) + } + if af.HasAdaptationExtensionField { + length += 1 + calcPacketAdaptationFieldExtensionLength(af.AdaptationExtensionField) + } + length += uint8(af.StuffingLength) + return +} + +func writePacketAdaptationField(w *astikit.BitsWriter, af *PacketAdaptationField) (bytesWritten int, retErr error) { + b := astikit.NewBitsWriterBatch(w) + + if af.IsOneByteStuffing { + b.Write(uint8(0)) + return 1, nil + } + + length := calcPacketAdaptationFieldLength(af) + b.Write(length) + bytesWritten++ + + b.Write(af.DiscontinuityIndicator) + b.Write(af.RandomAccessIndicator) + b.Write(af.ElementaryStreamPriorityIndicator) + b.Write(af.HasPCR) + b.Write(af.HasOPCR) + b.Write(af.HasSplicingCountdown) + b.Write(af.HasTransportPrivateData) + b.Write(af.HasAdaptationExtensionField) + + bytesWritten++ + + if af.HasPCR { + n, err := writePCR(w, af.PCR) + if err != nil { + return 0, err + } + bytesWritten += n + } + + if af.HasOPCR { + n, err := writePCR(w, af.OPCR) + if err != nil { + return 0, err + } + bytesWritten += n + } + + if af.HasSplicingCountdown { + b.Write(uint8(af.SpliceCountdown)) + bytesWritten++ + } + + if af.HasTransportPrivateData { + // we can get length from TransportPrivateData itself, why do we need separate field? + b.Write(uint8(af.TransportPrivateDataLength)) + bytesWritten++ + if af.TransportPrivateDataLength > 0 { + b.Write(af.TransportPrivateData) + } + bytesWritten += len(af.TransportPrivateData) + } + + if af.HasAdaptationExtensionField { + n, err := writePacketAdaptationFieldExtension(w, af.AdaptationExtensionField) + if err != nil { + return 0, err + } + bytesWritten += n + } + + // stuffing + for i := 0; i < af.StuffingLength; i++ { + b.Write(uint8(0xff)) + bytesWritten++ + } + + retErr = b.Err() + return +} + +func calcPacketAdaptationFieldExtensionLength(afe *PacketAdaptationExtensionField) (length uint8) { + length++ + if afe.HasLegalTimeWindow { + length += 2 + } + if afe.HasPiecewiseRate { + length += 3 + } + if afe.HasSeamlessSplice { + length += ptsOrDTSByteLength + } + return length +} + +func writePacketAdaptationFieldExtension(w *astikit.BitsWriter, afe *PacketAdaptationExtensionField) (bytesWritten int, retErr error) { + b := astikit.NewBitsWriterBatch(w) + + length := calcPacketAdaptationFieldExtensionLength(afe) + b.Write(length) + bytesWritten++ + + b.Write(afe.HasLegalTimeWindow) + b.Write(afe.HasPiecewiseRate) + b.Write(afe.HasSeamlessSplice) + b.WriteN(uint8(0xff), 5) // reserved + bytesWritten++ + + if afe.HasLegalTimeWindow { + b.Write(afe.LegalTimeWindowIsValid) + b.WriteN(afe.LegalTimeWindowOffset, 15) + bytesWritten += 2 + } + + if afe.HasPiecewiseRate { + b.WriteN(uint8(0xff), 2) + b.WriteN(afe.PiecewiseRate, 22) + bytesWritten += 3 + } + + if afe.HasSeamlessSplice { + n, err := writePTSOrDTS(w, afe.SpliceType, afe.DTSNextAccessUnit) + if err != nil { + return 0, err + } + bytesWritten += n + } + + retErr = b.Err() + return +} + +func newStuffingAdaptationField(bytesToStuff int) *PacketAdaptationField { + if bytesToStuff == 1 { + return &PacketAdaptationField{ + IsOneByteStuffing: true, + } + } + + return &PacketAdaptationField{ + // one byte for length and one for flags + StuffingLength: bytesToStuff - 2, + } +} diff --git a/packet_buffer.go b/packet_buffer.go index 5273b2e..db3cf3a 100644 --- a/packet_buffer.go +++ b/packet_buffer.go @@ -1,6 +1,7 @@ package astits import ( + "bufio" "fmt" "io" @@ -9,8 +10,9 @@ import ( // packetBuffer represents a packet buffer type packetBuffer struct { - packetSize int - r io.Reader + packetSize int + r io.Reader + packetReadBuffer []byte } // newPacketBuffer creates a new packet buffer @@ -39,8 +41,9 @@ func autoDetectPacketSize(r io.Reader) (packetSize int, err error) { // Read first bytes const l = 193 var b = make([]byte, l) - if _, err = r.Read(b); err != nil { - err = fmt.Errorf("astits: reading first %d bytes failed: %w", l, err) + shouldRewind, rerr := peek(r, b) + if rerr != nil { + err = fmt.Errorf("astits: reading first %d bytes failed: %w", l, rerr) return } @@ -52,10 +55,14 @@ func autoDetectPacketSize(r io.Reader) (packetSize int, err error) { // Look for sync bytes for idx, b := range b { - if b == syncByte && idx >= 188 { + if b == syncByte && idx >= MpegTsPacketSize { // Update packet size packetSize = idx + if !shouldRewind { + return + } + // Rewind or sync reader var n int64 if n, err = rewind(r); err != nil { @@ -75,6 +82,25 @@ func autoDetectPacketSize(r io.Reader) (packetSize int, err error) { return } +// bufio.Reader can't be rewinded, which leads to packet loss on packet size autodetection +// but it has handy Peek() method +// so what we do here is peeking bytes for bufio.Reader and falling back to rewinding/syncing for all other readers +func peek(r io.Reader, b []byte) (shouldRewind bool, err error) { + if br, ok := r.(*bufio.Reader); ok { + var bs []byte + bs, err = br.Peek(len(b)) + if err != nil { + return + } + copy(b, bs) + return false, nil + } + + _, err = r.Read(b) + shouldRewind = true + return +} + // rewind rewinds the reader if possible, otherwise n = -1 func rewind(r io.Reader) (n int64, err error) { if s, ok := r.(io.Seeker); ok { @@ -91,8 +117,11 @@ func rewind(r io.Reader) (n int64, err error) { // next fetches the next packet from the buffer func (pb *packetBuffer) next() (p *Packet, err error) { // Read - var b = make([]byte, pb.packetSize) - if _, err = io.ReadFull(pb.r, b); err != nil { + if pb.packetReadBuffer == nil || len(pb.packetReadBuffer) != pb.packetSize { + pb.packetReadBuffer = make([]byte, pb.packetSize) + } + + if _, err = io.ReadFull(pb.r, pb.packetReadBuffer); err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { err = ErrNoMorePackets } else { @@ -102,7 +131,7 @@ func (pb *packetBuffer) next() (p *Packet, err error) { } // Parse packet - if p, err = parsePacket(astikit.NewBytesIterator(b)); err != nil { + if p, err = parsePacket(astikit.NewBytesIterator(pb.packetReadBuffer)); err != nil { err = fmt.Errorf("astits: building packet failed: %w", err) return } diff --git a/packet_buffer_test.go b/packet_buffer_test.go index b0c3d6e..a4b9e35 100644 --- a/packet_buffer_test.go +++ b/packet_buffer_test.go @@ -29,6 +29,6 @@ func TestAutoDetectPacketSize(t *testing.T) { r := bytes.NewReader(buf.Bytes()) p, err := autoDetectPacketSize(r) assert.NoError(t, err) - assert.Equal(t, 188, p) + assert.Equal(t, MpegTsPacketSize, p) assert.Equal(t, 380, r.Len()) } diff --git a/packet_test.go b/packet_test.go index 873582d..42c281d 100644 --- a/packet_test.go +++ b/packet_test.go @@ -9,14 +9,16 @@ import ( "github.com/stretchr/testify/assert" ) -func packet(h PacketHeader, a PacketAdaptationField, i []byte) ([]byte, *Packet) { +func packet(h PacketHeader, a PacketAdaptationField, i []byte, packet192bytes bool) ([]byte, *Packet) { buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) - w.Write(uint8(syncByte)) // Sync byte - w.Write([]byte("test")) // Sometimes packets are 192 bytes - w.Write(packetHeaderBytes(h)) // Header - w.Write(packetAdaptationFieldBytes(a)) // Adaptation field - var payload = append(i, make([]byte, 147-len(i))...) // Payload + w.Write(uint8(syncByte)) // Sync byte + if packet192bytes { + w.Write([]byte("test")) // Sometimes packets are 192 bytes + } + w.Write(packetHeaderBytes(h, "11")) // Header + w.Write(packetAdaptationFieldBytes(a)) // Adaptation field + var payload = append(i, bytes.Repeat([]byte{0}, 147-len(i))...) // Payload w.Write(payload) return buf.Bytes(), &Packet{ AdaptationField: packetAdaptationField, @@ -25,6 +27,19 @@ func packet(h PacketHeader, a PacketAdaptationField, i []byte) ([]byte, *Packet) } } +func packetShort(h PacketHeader, payload []byte) ([]byte, *Packet) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + w.Write(uint8(syncByte)) // Sync byte + w.Write(packetHeaderBytes(h, "01")) // Header + p := append(payload, bytes.Repeat([]byte{0}, MpegTsPacketSize-buf.Len())...) + w.Write(p) + return buf.Bytes(), &Packet{ + Header: &h, + Payload: payload, + } +} + func TestParsePacket(t *testing.T) { // Packet not starting with a sync buf := &bytes.Buffer{} @@ -34,7 +49,7 @@ func TestParsePacket(t *testing.T) { assert.EqualError(t, err, ErrPacketMustStartWithASyncByte.Error()) // Valid - b, ep := packet(*packetHeader, *packetAdaptationField, []byte("payload")) + b, ep := packet(*packetHeader, *packetAdaptationField, []byte("payload"), true) p, err := parsePacket(astikit.NewBytesIterator(b)) assert.NoError(t, err) assert.Equal(t, p, ep) @@ -45,6 +60,40 @@ func TestPayloadOffset(t *testing.T) { assert.Equal(t, 7, payloadOffset(1, &PacketHeader{HasAdaptationField: true}, &PacketAdaptationField{Length: 2})) } +func TestWritePacket(t *testing.T) { + eb, ep := packet(*packetHeader, *packetAdaptationField, []byte("payload"), false) + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + n, err := writePacket(w, ep, MpegTsPacketSize) + assert.NoError(t, err) + assert.Equal(t, MpegTsPacketSize, n) + assert.Equal(t, n, buf.Len()) + assert.Equal(t, len(eb), buf.Len()) + assert.Equal(t, eb, buf.Bytes()) +} + +func TestWritePacket_HeaderOnly(t *testing.T) { + shortPacketHeader := *packetHeader + shortPacketHeader.HasPayload = false + shortPacketHeader.HasAdaptationField = false + _, ep := packetShort(shortPacketHeader, nil) + + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + + n, err := writePacket(w, ep, MpegTsPacketSize) + assert.NoError(t, err) + assert.Equal(t, MpegTsPacketSize, n) + assert.Equal(t, n, buf.Len()) + + // we can't just compare bytes returned by packetShort since they're not completely correct, + // so we just cross-check writePacket with parsePacket + i := astikit.NewBytesIterator(buf.Bytes()) + p, err := parsePacket(i) + assert.NoError(t, err) + assert.Equal(t, ep, p) +} + var packetHeader = &PacketHeader{ ContinuityCounter: 10, HasAdaptationField: true, @@ -56,7 +105,7 @@ var packetHeader = &PacketHeader{ TransportScramblingControl: ScramblingControlScrambledWithEvenKey, } -func packetHeaderBytes(h PacketHeader) []byte { +func packetHeaderBytes(h PacketHeader, afControl string) []byte { buf := &bytes.Buffer{} w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) w.Write(h.TransportErrorIndicator) // Transport error indicator @@ -64,17 +113,27 @@ func packetHeaderBytes(h PacketHeader) []byte { w.Write("1") // Transport priority w.Write(fmt.Sprintf("%.13b", h.PID)) // PID w.Write("10") // Scrambling control - w.Write("11") // Adaptation field control + w.Write(afControl) // Adaptation field control w.Write(fmt.Sprintf("%.4b", h.ContinuityCounter)) // Continuity counter return buf.Bytes() } func TestParsePacketHeader(t *testing.T) { - v, err := parsePacketHeader(astikit.NewBytesIterator(packetHeaderBytes(*packetHeader))) + v, err := parsePacketHeader(astikit.NewBytesIterator(packetHeaderBytes(*packetHeader, "11"))) assert.Equal(t, packetHeader, v) assert.NoError(t, err) } +func TestWritePacketHeader(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + bytesWritten, err := writePacketHeader(w, packetHeader) + assert.NoError(t, err) + assert.Equal(t, bytesWritten, 3) + assert.Equal(t, bytesWritten, buf.Len()) + assert.Equal(t, packetHeaderBytes(*packetHeader, "11"), buf.Bytes()) +} + var packetAdaptationField = &PacketAdaptationField{ AdaptationExtensionField: &PacketAdaptationExtensionField{ DTSNextAccessUnit: dtsClockReference, @@ -101,6 +160,7 @@ var packetAdaptationField = &PacketAdaptationField{ SpliceCountdown: 2, TransportPrivateDataLength: 4, TransportPrivateData: []byte("test"), + StuffingLength: 5, } func packetAdaptationFieldBytes(a PacketAdaptationField) []byte { @@ -129,8 +189,8 @@ func packetAdaptationFieldBytes(a PacketAdaptationField) []byte { w.Write("010101010101010") // LTW offset w.Write("11") // Piecewise rate reserved w.Write("1010101010101010101010") // Piecewise rate - w.Write(dtsBytes()) // Splice type + DTS next access unit - w.Write([]byte("stuff")) // Stuffing bytes + w.Write(dtsBytes("0010")) // Splice type + DTS next access unit + w.WriteN(^uint64(0), 40) // Stuffing bytes return buf.Bytes() } @@ -140,6 +200,17 @@ func TestParsePacketAdaptationField(t *testing.T) { assert.NoError(t, err) } +func TestWritePacketAdaptationField(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + eb := packetAdaptationFieldBytes(*packetAdaptationField) + bytesWritten, err := writePacketAdaptationField(w, packetAdaptationField) + assert.NoError(t, err) + assert.Equal(t, bytesWritten, buf.Len()) + assert.Equal(t, len(eb), buf.Len()) + assert.Equal(t, eb, buf.Bytes()) +} + var pcr = &ClockReference{ Base: 5726623061, Extension: 341, @@ -159,3 +230,25 @@ func TestParsePCR(t *testing.T) { assert.Equal(t, pcr, v) assert.NoError(t, err) } + +func TestWritePCR(t *testing.T) { + buf := &bytes.Buffer{} + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + bytesWritten, err := writePCR(w, pcr) + assert.NoError(t, err) + assert.Equal(t, bytesWritten, 6) + assert.Equal(t, bytesWritten, buf.Len()) + assert.Equal(t, pcrBytes(), buf.Bytes()) +} + +func BenchmarkWritePCR(b *testing.B) { + buf := &bytes.Buffer{} + buf.Grow(6) + w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: buf}) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + writePCR(w, pcr) + } +} diff --git a/program_map.go b/program_map.go index b4f77b5..855f2d7 100644 --- a/program_map.go +++ b/program_map.go @@ -30,3 +30,28 @@ func (m programMap) set(pid, number uint16) { defer m.m.Unlock() m.p[pid] = number } + +func (m programMap) unset(pid uint16) { + m.m.Lock() + defer m.m.Unlock() + delete(m.p, pid) +} + +func (m programMap) toPATData() *PATData { + m.m.Lock() + defer m.m.Unlock() + + d := &PATData{ + Programs: []*PATProgram{}, + TransportStreamID: uint16(PSITableIDPAT), + } + + for pid, pnr := range m.p { + d.Programs = append(d.Programs, &PATProgram{ + ProgramMapID: pid, + ProgramNumber: pnr, + }) + } + + return d +} diff --git a/program_map_test.go b/program_map_test.go index cc22f76..de9b4a0 100644 --- a/program_map_test.go +++ b/program_map_test.go @@ -11,4 +11,6 @@ func TestProgramMap(t *testing.T) { assert.False(t, pm.exists(1)) pm.set(1, 1) assert.True(t, pm.exists(1)) + pm.unset(1) + assert.False(t, pm.exists(1)) } diff --git a/wrapping_counter.go b/wrapping_counter.go new file mode 100644 index 0000000..025bc4f --- /dev/null +++ b/wrapping_counter.go @@ -0,0 +1,22 @@ +package astits + +type wrappingCounter struct { + wrapAt int + value int +} + +func newWrappingCounter(wrapAt int) wrappingCounter { + return wrappingCounter{ + wrapAt: wrapAt, + } +} + +// returns current counter state and increments internal value +func (c *wrappingCounter) get() int { + ret := c.value + c.value++ + if c.value > c.wrapAt { + c.value = 0 + } + return ret +}