#!/usr/bin/env python3

import struct
import argparse
from collections import namedtuple

# Little-endian layout of the fixed-size MCUboot image header (32 bytes).
# Field order below mirrors the format characters one-to-one.
ImageHeaderStruct = struct.Struct("<LLHHLLBBHLL")
ImageHeader = namedtuple(
    typename="ImageHeader",
    field_names=(
        "ih_magic",  # L
        "ih_load_addr",  # L
        "ih_hdr_size",  # H
        "ih_protect_tlv_size",  # H
        "ih_img_size",  # L
        "ih_flags",  # L
        # The ih_ver version record, flattened into its four components.
        "iv_major",  # B
        "iv_minor",  # B
        "iv_revision",  # H
        "iv_build_num",  # L
        "ih_pad1",  # L
    ),
)

# Info header that opens a TLV area: a 16-bit magic identifying the area,
# followed by a 16-bit total byte size of the area (info header included).
TLVInfoHeaderStruct = struct.Struct("<HH")
TLVInfoHeader = namedtuple(
    typename="TLVInfoHeader",
    field_names=("it_magic", "it_tlv_tot"),
)

# Header of a single TLV entry: 16-bit type followed by 16-bit payload length.
TLVHeaderStruct = struct.Struct("<HH")
TLVHeader = namedtuple(
    "TLVHeader",
    [
        "it_type",
        "it_len",
    ],
)

# One decoded TLV entry: its TLVHeader plus the raw payload bytes.
TLV = namedtuple("TLV", ["header", "data"])


class MCUBootTLVIterator(object):
    """Iterate over the TLV entries of one TLV area inside an image buffer.

    The area spans ``data[offset:end]``. Entries whose type equals one of the
    TLV info magics are treated as info headers and skipped; every other
    entry is yielded as ``TLV(header, data)``.

    Iteration never reads past ``end``: an area consisting only of its info
    header yields nothing, and an entry that would extend beyond ``end``
    raises ``ValueError``.
    """

    IMAGE_TLV_INFO_MAGIC = 0x6907
    IMAGE_TLV_PROT_INFO_MAGIC = 0x6908

    def __init__(self, data, offset, end):
        """Create an iterator over ``data[offset:end]``.

        Args:
            data: bytes-like buffer holding the whole image.
            offset: start of the TLV area (normally its info header).
            end: exclusive end of the TLV area.
        """
        self._data = data
        self._offset = offset
        self._end = end

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next TLV entry of the area.

        Raises:
            StopIteration: once the end of the area is reached.
            ValueError: if a TLV header or payload extends past the area end.
        """
        info_magics = (
            self.IMAGE_TLV_PROT_INFO_MAGIC,
            self.IMAGE_TLV_INFO_MAGIC,
        )
        offset = self._offset
        while True:
            # Check the bound before *every* header read so that skipping an
            # info header can never walk into the following area.
            if offset >= self._end:
                self._offset = offset
                raise StopIteration
            header_end = offset + TLVHeaderStruct.size
            if header_end > self._end:
                raise ValueError(
                    "Truncated TLV header at offset %d (area ends at %d)"
                    % (offset, self._end)
                )
            header = TLVHeader(
                *TLVHeaderStruct.unpack_from(self._data, offset)
            )
            offset = header_end
            if header.it_type not in info_magics:
                break

        data_end = offset + header.it_len
        if data_end > self._end:
            raise ValueError(
                "TLV 0x%04x at offset %d declares %d bytes but the area "
                "ends at %d"
                % (header.it_type, offset, header.it_len, self._end)
            )
        self._offset = data_end
        return TLV(header, self._data[offset:data_end])


class MCUBootImage(object):
    """Parsed, validated view of an MCUboot image held in memory.

    Expected layout, with offsets measured from the start of ``data``::

        [image header, ih_hdr_size bytes]
        [executable image, ih_img_size bytes]
        [optional protected TLV area, ih_protect_tlv_size bytes]
        [unprotected TLV area, it_tlv_tot bytes]

    The constructor checks this layout and records region boundaries. The
    TLV areas are exposed as iterators via ``ptlvs`` and ``utlvs``.
    """

    IMAGE_MAGIC = 0x96F3B83D
    IMAGE_MAGIC_V1 = 0x96F3B83C
    IMAGE_MAGIC_NONE = 0xFFFFFFFF

    def __init__(self, data):
        """Parse and validate ``data`` (a bytes-like object).

        Raises:
            ValueError: if the buffer is truncated, the magic is unknown,
                the TLV areas are missing or inconsistent with the header,
                or bytes remain after the unprotected TLV area.
        """
        self._data = data
        self._header = None
        self._img_end = None
        self._ptlv_start = None
        self._ptlv_header = None
        self._ptlv_end = None
        self._utlv_start = None
        self._utlv_header = None
        self._utlv_end = None

        self._header = self._read_struct(
            ImageHeaderStruct, ImageHeader, 0, "image header"
        )

        if self._header.ih_magic not in (
            self.IMAGE_MAGIC,
            self.IMAGE_MAGIC_V1,
        ):
            raise ValueError(
                "Invalid image magic: 0x%08x" % self._header.ih_magic
            )

        # The TLV areas start right after the executable image.
        offset = self._header.ih_hdr_size + self._header.ih_img_size
        self._img_end = offset

        tlv_header = self._read_struct(
            TLVInfoHeaderStruct, TLVInfoHeader, offset, "TLV info header"
        )

        if (
            tlv_header.it_magic
            == MCUBootTLVIterator.IMAGE_TLV_PROT_INFO_MAGIC
        ):
            # Protected TLVs
            if tlv_header.it_tlv_tot != self._header.ih_protect_tlv_size:
                raise ValueError(
                    "Header ih_protect_tlv_size (%d) "
                    "does not match TLV info it_tlv_tot (%d)"
                    % (
                        self._header.ih_protect_tlv_size,
                        tlv_header.it_tlv_tot,
                    )
                )

            self._ptlv_start = offset
            self._ptlv_header = tlv_header
            offset += self._header.ih_protect_tlv_size
            self._ptlv_end = offset

            # Unprotected TLVs will follow
            tlv_header = self._read_struct(
                TLVInfoHeaderStruct,
                TLVInfoHeader,
                offset,
                "unprotected TLV info header",
            )
        elif tlv_header.it_magic == MCUBootTLVIterator.IMAGE_TLV_INFO_MAGIC:
            # The header must not promise a protected area that is absent.
            if self._header.ih_protect_tlv_size != 0:
                raise ValueError(
                    "Header ih_protect_tlv_size (%d) is non-zero but no "
                    "protected TLV area is present"
                    % self._header.ih_protect_tlv_size
                )
        else:
            raise ValueError(
                "Garbage at end of image: %s" % self._data[offset:].hex()
            )

        # Unprotected TLVs should be next
        if tlv_header.it_magic != MCUBootTLVIterator.IMAGE_TLV_INFO_MAGIC:
            raise ValueError(
                "Garbage at end of image: %s" % (self._data[offset:].hex())
            )

        self._utlv_start = offset
        self._utlv_header = tlv_header
        offset += tlv_header.it_tlv_tot
        self._utlv_end = offset

        # Explicit check instead of `assert`, which is stripped under -O.
        if len(self._data) != offset:
            raise ValueError(
                "Image size mismatch: header and TLV info describe %d bytes, "
                "buffer holds %d" % (offset, len(self._data))
            )

    def _read_struct(self, fmt, ctor, offset, what):
        """Unpack ``fmt`` at ``offset`` into ``ctor``, rejecting short data."""
        if offset + fmt.size > len(self._data):
            raise ValueError(
                "Image truncated: %s at offset %d needs %d bytes, "
                "%d available"
                % (what, offset, fmt.size, max(len(self._data) - offset, 0))
            )
        return ctor(*fmt.unpack_from(self._data, offset))

    @property
    def header(self):
        """The parsed ImageHeader."""
        return self._header

    @property
    def ptlvs(self):
        """Iterator over protected TLVs; empty if the image has none."""
        if self._ptlv_start is None:
            # No protected area: iterate an empty range instead of passing
            # None bounds, which the iterator cannot compare.
            return MCUBootTLVIterator(self._data, 0, 0)
        return MCUBootTLVIterator(
            self._data, self._ptlv_start, self._ptlv_end
        )

    @property
    def utlvs(self):
        """Iterator over unprotected TLVs."""
        return MCUBootTLVIterator(
            self._data, self._utlv_start, self._utlv_end
        )

    @property
    def exec_image(self):
        """Bytes of the executable image (between header and TLV areas)."""
        return self._data[self.header.ih_hdr_size : self._img_end]


if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("filename", help="Binary image to analyse")

    args = ap.parse_args()

    # Read the whole image up front; `with` closes the handle promptly.
    # open(..., "rb").read() already returns bytes, so no conversion needed.
    with open(args.filename, "rb") as image_file:
        data = image_file.read()
    image = MCUBootImage(data)

    print(image.header)
    for idx, tlv in enumerate(image.ptlvs):
        print(
            "PTLV %2d: 0x%04x %s" % (idx, tlv.header.it_type, tlv.data.hex())
        )

    for idx, tlv in enumerate(image.utlvs):
        print(
            "UTLV %2d: 0x%04x %s" % (idx, tlv.header.it_type, tlv.data.hex())
        )
