Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for scripts/imgtool #1983

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions scripts/imgtool/dumpinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

import click
import yaml
from intelhex import IntelHex

from imgtool import image
from imgtool.image import INTEL_HEX_EXT, IMAGE_MAGIC_LE

HEADER_ITEMS = ("magic", "load_addr", "hdr_size", "protected_tlv_size",
"img_size", "flags", "version")
Expand Down Expand Up @@ -129,14 +131,23 @@ def dump_imginfo(imgfile, outfile=None, silent=False):
trailer = {}
key_field_len = None

ext = os.path.splitext(imgfile)[1][1:].lower()
try:
with open(imgfile, "rb") as f:
b = f.read()
if ext == INTEL_HEX_EXT:
ih = IntelHex(imgfile)
b = ih.tobinstr()
else:
with open(imgfile, "rb") as f:
b = f.read()
except FileNotFoundError:
raise click.UsageError("Image file not found ({})".format(imgfile))
raise click.UsageError(f"Image file not found: {imgfile}")

# Detect image byteorder by image magic
magic = int.from_bytes(b[:4], "big")
order = '<' if magic == IMAGE_MAGIC_LE else '>'

# Parsing the image header
_header = struct.unpack('IIHHIIBBHI', b[:28])
_header = struct.unpack(order + 'IIHHIIBBHI', b[:28])
# Image version consists of the last 4 item ('BBHI')
_version = _header[-4:]
header = {}
Expand All @@ -155,26 +166,24 @@ def dump_imginfo(imgfile, outfile=None, silent=False):
protected_tlv_size = header["protected_tlv_size"]

if protected_tlv_size != 0:
_tlv_prot_head = struct.unpack(
'HH',
b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
_tlv_prot_head = struct.unpack(order + 'HH',
b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
tlv_area["tlv_hdr_prot"]["magic"] = _tlv_prot_head[0]
tlv_area["tlv_hdr_prot"]["tlv_tot"] = _tlv_prot_head[1]
tlv_end = tlv_off + tlv_area["tlv_hdr_prot"]["tlv_tot"]
tlv_off += image.TLV_INFO_SIZE

# Iterating through the protected TLV area
while tlv_off < tlv_end:
tlv_type, tlv_len = struct.unpack(
'HH',
b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
tlv_type, tlv_len = struct.unpack(order + 'HH',
b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
tlv_off += image.TLV_INFO_SIZE
tlv_data = b[tlv_off:(tlv_off + tlv_len)]
tlv_area["tlvs_prot"].append(
{"type": tlv_type, "len": tlv_len, "data": tlv_data})
tlv_off += tlv_len

_tlv_head = struct.unpack('HH', b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
_tlv_head = struct.unpack(order + 'HH', b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
tlv_area["tlv_hdr"]["magic"] = _tlv_head[0]
tlv_area["tlv_hdr"]["tlv_tot"] = _tlv_head[1]

Expand All @@ -183,9 +192,8 @@ def dump_imginfo(imgfile, outfile=None, silent=False):

# Iterating through the TLV area
while tlv_off < tlv_end:
tlv_type, tlv_len = struct.unpack(
'HH',
b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
tlv_type, tlv_len = struct.unpack(order + 'HH',
b[tlv_off:(tlv_off + image.TLV_INFO_SIZE)])
tlv_off += image.TLV_INFO_SIZE
tlv_data = b[tlv_off:(tlv_off + tlv_len)]
tlv_area["tlvs"].append(
Expand All @@ -205,7 +213,7 @@ def dump_imginfo(imgfile, outfile=None, silent=False):
max_align = 8
elif trailer_magic[-len(BOOT_MAGIC_2):] == BOOT_MAGIC_2:
# The alignment value is encoded in the magic field
max_align = int.from_bytes(trailer_magic[:2], "little")
max_align = int.from_bytes(trailer_magic[:2], "big" if order == '>' else "little")
else:
# Invalid magic: the rest of the image trailer cannot be processed.
print("Warning: the trailer magic value is invalid!")
Expand All @@ -228,7 +236,7 @@ def dump_imginfo(imgfile, outfile=None, silent=False):

trailer_off -= max_align
swap_size = int.from_bytes(b[trailer_off:(trailer_off + 4)],
"little")
"big" if order == '>' else "little")
trailer["swap_size"] = swap_size

# Encryption key 0/1
Expand All @@ -254,7 +262,8 @@ def dump_imginfo(imgfile, outfile=None, silent=False):
sys.exit(0)

print("Printing content of signed image:", os.path.basename(imgfile), "\n")

byteorder_text = "Byte order: " + "little" if order == "<" else "big"
print_in_row(byteorder_text)
# Image header
section_name = "Image header (offset: 0x0)"
print_in_row(section_name)
Expand Down
32 changes: 19 additions & 13 deletions scripts/imgtool/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .keys import rsa, ecdsa, x25519

IMAGE_MAGIC = 0x96f3b83d
IMAGE_MAGIC_LE = 0x3db8f396
IMAGE_HEADER_SIZE = 32
BIN_EXT = "bin"
INTEL_HEX_EXT = "hex"
Expand Down Expand Up @@ -101,7 +102,7 @@ def align_up(num, align):
return (num + (align - 1)) & ~(align - 1)


class TLV():
class TLV:
def __init__(self, endian, magic=TLV_INFO_MAGIC):
self.magic = magic
self.buf = bytearray()
Expand All @@ -123,7 +124,7 @@ def add(self, kind, payload):
raise click.UsageError(msg)
buf = struct.pack(e + 'HH', kind, len(payload))
else:
buf = struct.pack(e + 'BBH', TLV_VALUES[kind], 0, len(payload))
buf = struct.pack(e + 'HH', TLV_VALUES[kind], len(payload))
self.buf += buf
self.buf += payload

Expand Down Expand Up @@ -194,10 +195,11 @@ def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
lsb = self.max_align & 0x00ff
msb = (self.max_align & 0xff00) >> 8
align = bytes([msb, lsb]) if self.endian == "big" else bytes([lsb, msb])
self.boot_magic = align + bytes([0x2d, 0xe1,
0x5d, 0x29, 0x41, 0x0b,
0x8d, 0x77, 0x67, 0x9c,
0x11, 0x0f, 0x1f, 0x8a, ])
self.boot_magic = align + bytes([
0x2d, 0xe1,
0x5d, 0x29, 0x41, 0x0b,
0x8d, 0x77, 0x67, 0x9c,
0x11, 0x0f, 0x1f, 0x8a, ])

if security_counter == 'auto':
# Security counter has not been explicitly provided,
Expand Down Expand Up @@ -238,7 +240,7 @@ def load(self, path):
with open(path, 'rb') as f:
self.payload = f.read()
except FileNotFoundError:
raise click.UsageError("Input file not found")
raise click.UsageError(f"Image file not found: {path}")

# Add the image header if needed.
if self.pad_header and self.header_size > 0:
Expand Down Expand Up @@ -647,21 +649,25 @@ def verify(imgfile, key):
with open(imgfile, 'rb') as f:
b = f.read()
except FileNotFoundError:
raise click.UsageError(f"Image file {imgfile} not found")
raise click.UsageError(f"Image file not found: {imgfile}")

magic, _, header_size, _, img_size = struct.unpack('IIHHI', b[:16])
version = struct.unpack('BBHI', b[20:28])
# Detect image byteorder by image magic
magic = int.from_bytes(b[:4], "big")
e = '<' if magic == IMAGE_MAGIC_LE else '>'

magic, _, header_size, _, img_size = struct.unpack(e + 'IIHHI', b[:16])
version = struct.unpack(e + 'BBHI', b[20:28])

if magic != IMAGE_MAGIC:
return VerifyResult.INVALID_MAGIC, None, None

tlv_off = header_size + img_size
tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
magic, tlv_tot = struct.unpack('HH', tlv_info)
magic, tlv_tot = struct.unpack(e + 'HH', tlv_info)
if magic == TLV_PROT_INFO_MAGIC:
tlv_off += tlv_tot
tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
magic, tlv_tot = struct.unpack('HH', tlv_info)
magic, tlv_tot = struct.unpack(e + 'HH', tlv_info)

if magic != TLV_INFO_MAGIC:
return VerifyResult.INVALID_TLV_INFO_MAGIC, None, None
Expand All @@ -673,7 +679,7 @@ def verify(imgfile, key):
tlv_off += TLV_INFO_SIZE # skip tlv info
while tlv_off < tlv_end:
tlv = b[tlv_off:tlv_off + TLV_SIZE]
tlv_type, _, tlv_len = struct.unpack('BBH', tlv)
tlv_type, tlv_len = struct.unpack(e + 'HH', tlv)
if tlv_type == TLV_VALUES["SHA256"] or tlv_type == TLV_VALUES["SHA384"]:
if not tlv_matches_key_type(tlv_type, key):
return VerifyResult.KEY_MISMATCH, None, None
Expand Down
16 changes: 13 additions & 3 deletions scripts/imgtool/keys/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2017 Linaro Limited
# Copyright 2023 Arm Limited
# Copyright 2023-2024 Arm Limited
#
# SPDX-License-Identifier: Apache-2.0
#
Expand Down Expand Up @@ -58,14 +58,24 @@ def load(path, passwd=None):
except TypeError as e:
msg = str(e)
if "private key is encrypted" in msg:
print(msg)
return None
raise e
except ValueError:
except ValueError as e:
msg1 = str(e)
# This seems to happen if the key is a public key, let's try
# loading it as a public key.
pk = serialization.load_pem_public_key(
try:
pk = serialization.load_pem_public_key(
raw_pem,
backend=default_backend())
except ValueError as e:
# If loading as public key also fails, that indicates wrong
# passphrase input
msg2 = str(e)
if ("password may be incorrect" in msg1 and
"Are you sure this is a public key" in msg2):
raise Exception("Invalid passphrase")

if isinstance(pk, RSAPrivateKey):
if pk.key_size not in RSA_KEY_SIZES:
Expand Down
72 changes: 41 additions & 31 deletions scripts/imgtool/keys/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,25 @@

# SPDX-License-Identifier: Apache-2.0

import binascii
import io
import os
import sys

from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes, PublicKeyTypes
from cryptography.hazmat.primitives.hashes import Hash, SHA256

from imgtool import keys

AUTOGEN_MESSAGE = "/* Autogenerated by imgtool.py, do not edit. */"


def key_types_matching(key: PrivateKeyTypes, enckey: PublicKeyTypes):
type_dict = {keys.ECDSA256P1: keys.ECDSA256P1Public,
keys.ECDSA384P1: keys.ECDSA384P1Public,
keys.Ed25519: keys.X25519Public,
keys.RSA: keys.RSAPublic}
return type_dict[type(key)] == type(enckey)


class FileHandler(object):
def __init__(self, file, *args, **kwargs):
self.file_in = file
Expand All @@ -34,7 +44,7 @@ def _emit(self, header, trailer, encoded_bytes, indent, file=sys.stdout,
len_format=None):
with FileHandler(file, 'w') as file:
self._emit_to_output(header, trailer, encoded_bytes, indent,
file, len_format)
file, len_format)

def _emit_to_output(self, header, trailer, encoded_bytes, indent, file,
len_format):
Expand Down Expand Up @@ -62,27 +72,27 @@ def _emit_raw(self, encoded_bytes, file):

def emit_c_public(self, file=sys.stdout):
self._emit(
header="const unsigned char {}_pub_key[] = {{"
.format(self.shortname()),
trailer="};",
encoded_bytes=self.get_public_bytes(),
indent=" ",
len_format="const unsigned int {}_pub_key_len = {{}};"
.format(self.shortname()),
file=file)
header="const unsigned char {}_pub_key[] = {{"
.format(self.shortname()),
trailer="};",
encoded_bytes=self.get_public_bytes(),
indent=" ",
len_format="const unsigned int {}_pub_key_len = {{}};"
.format(self.shortname()),
file=file)

def emit_c_public_hash(self, file=sys.stdout):
digest = Hash(SHA256())
digest.update(self.get_public_bytes())
self._emit(
header="const unsigned char {}_pub_key_hash[] = {{"
.format(self.shortname()),
trailer="};",
encoded_bytes=digest.finalize(),
indent=" ",
len_format="const unsigned int {}_pub_key_hash_len = {{}};"
.format(self.shortname()),
file=file)
header="const unsigned char {}_pub_key_hash[] = {{"
.format(self.shortname()),
trailer="};",
encoded_bytes=digest.finalize(),
indent=" ",
len_format="const unsigned int {}_pub_key_hash_len = {{}};"
.format(self.shortname()),
file=file)

def emit_raw_public(self, file=sys.stdout):
self._emit_raw(self.get_public_bytes(), file=file)
Expand All @@ -94,22 +104,22 @@ def emit_raw_public_hash(self, file=sys.stdout):

def emit_rust_public(self, file=sys.stdout):
self._emit(
header="static {}_PUB_KEY: &[u8] = &["
.format(self.shortname().upper()),
trailer="];",
encoded_bytes=self.get_public_bytes(),
indent=" ",
file=file)
header="static {}_PUB_KEY: &[u8] = &["
.format(self.shortname().upper()),
trailer="];",
encoded_bytes=self.get_public_bytes(),
indent=" ",
file=file)

def emit_public_pem(self, file=sys.stdout):
with FileHandler(file, 'w') as file:
print(str(self.get_public_pem(), 'utf-8'), file=file, end='')

def emit_private(self, minimal, format, file=sys.stdout):
self._emit(
header="const unsigned char enc_priv_key[] = {",
trailer="};",
encoded_bytes=self.get_private_bytes(minimal, format),
indent=" ",
len_format="const unsigned int enc_priv_key_len = {};",
file=file)
header="const unsigned char enc_priv_key[] = {",
trailer="};",
encoded_bytes=self.get_private_bytes(minimal, format),
indent=" ",
len_format="const unsigned int enc_priv_key_len = {};",
file=file)
Loading
Loading