diff --git a/worlds/Files.py b/worlds/Files.py index 6fee582c872..6e9bf6b31b5 100644 --- a/worlds/Files.py +++ b/worlds/Files.py @@ -7,7 +7,7 @@ import os import threading -from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO, overload +from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO, overload, Sequence import bsdiff4 @@ -41,7 +41,7 @@ def get_handler(file: str) -> Optional[AutoPatchRegister]: class AutoPatchExtensionRegister(abc.ABCMeta): extension_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {} - required_extensions: List[str] = [] + required_extensions: Tuple[str, ...] = () def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchExtensionRegister: # construct class @@ -51,7 +51,9 @@ def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> Aut return new_class @staticmethod - def get_handler(game: str) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]: + def get_handler(game: Optional[str]) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]: + if not game: + return APPatchExtension handler = AutoPatchExtensionRegister.extension_types.get(game, APPatchExtension) if handler.required_extensions: handlers = [handler] @@ -191,7 +193,7 @@ class APProcedurePatch(APAutoPatchInterface): hash: Optional[str] # base checksum of source file source_data: bytes patch_file_ending: str = "" - files: Dict[str, bytes] = {} + files: Dict[str, bytes] @classmethod def get_source_data(cls) -> bytes: @@ -206,6 +208,7 @@ def get_source_data_with_cache(cls) -> bytes: def __init__(self, *args: Any, **kwargs: Any): super(APProcedurePatch, self).__init__(*args, **kwargs) + self.files = {} def get_manifest(self) -> Dict[str, Any]: manifest = super(APProcedurePatch, self).get_manifest() @@ -277,7 +280,7 @@ def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None: super(APDeltaPatch, self).__init__(*args, **kwargs) self.patched_path = patched_path - def write_contents(self, opened_zipfile: zipfile.ZipFile): + def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None: self.write_file("delta.bsdiff4", bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read())) super(APDeltaPatch, self).write_contents(opened_zipfile) @@ -296,12 +299,12 @@ class APTokenMixin: """ A class that defines functions for generating a token binary, for use in patches. """ - tokens: List[ + _tokens: Sequence[ Tuple[APTokenTypes, int, Union[ bytes, # WRITE Tuple[int, int], # COPY, RLE int # AND_8, OR_8, XOR_8 - ]]] = [] + ]]] = () def get_token_binary(self) -> bytes: """ @@ -309,8 +312,8 @@ def get_token_binary(self) -> bytes: :return: A bytes object representing the token data. """ data = bytearray() - data.extend(len(self.tokens).to_bytes(4, "little")) - for token_type, offset, args in self.tokens: + data.extend(len(self._tokens).to_bytes(4, "little")) + for token_type, offset, args in self._tokens: data.append(token_type) data.extend(offset.to_bytes(4, "little")) if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]: @@ -351,11 +354,14 @@ def write_token(self, data: bytes) -> None: ... - def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]): + def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]) -> None: """ Stores a token to be used by patching. """ - self.tokens.append((token_type, offset, data)) + if not isinstance(self._tokens, list): + assert len(self._tokens) == 0, f"{type(self)}._tokens was tampered with." + self._tokens = [] + self._tokens.append((token_type, offset, data)) class APPatchExtension(metaclass=AutoPatchExtensionRegister): @@ -371,10 +377,10 @@ class APPatchExtension(metaclass=AutoPatchExtensionRegister): Patch extension functions must return the changed bytes. """ game: str - required_extensions: List[str] = [] + required_extensions: ClassVar[Tuple[str, ...]] = () @staticmethod - def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str): + def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str) -> bytes: """Applies the given bsdiff4 from the patch onto the current file.""" return bsdiff4.patch(rom, caller.get_file(patch)) @@ -411,7 +417,7 @@ def apply_tokens(caller: APProcedurePatch, rom: bytes, token_file: str) -> bytes return bytes(rom_data) @staticmethod - def calc_snes_crc(caller: APProcedurePatch, rom: bytes): + def calc_snes_crc(caller: APProcedurePatch, rom: bytes) -> bytes: """Calculates and applies a valid CRC for the SNES rom header.""" rom_data = bytearray(rom) if len(rom) < 0x8000: