Skip to content

Commit

Permalink
APProcedurePatch: hotfix changing class variables to instance variabl…
Browse files Browse the repository at this point in the history
…es (ArchipelagoMW#2996)

* change class variables to instance variables

* Update worlds/Files.py

Co-authored-by: black-sliver <[email protected]>

* Update worlds/Files.py

Co-authored-by: black-sliver <[email protected]>

* move required_extensions to tuple

* fix missing tuple ellipsis

* fix classvar mixup

* rename tokens to _tokens. use hasattr

* type hint cleanup

* Update Files.py

* check using isinstance instead

---------

Co-authored-by: black-sliver <[email protected]>
  • Loading branch information
Silvris and black-sliver authored Mar 20, 2024
1 parent 12864f7 commit f4b7c28
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions worlds/Files.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import threading

from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO, overload
from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO, overload, Sequence

import bsdiff4

Expand Down Expand Up @@ -41,7 +41,7 @@ def get_handler(file: str) -> Optional[AutoPatchRegister]:

class AutoPatchExtensionRegister(abc.ABCMeta):
extension_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {}
required_extensions: List[str] = []
required_extensions: Tuple[str, ...] = ()

def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchExtensionRegister:
# construct class
Expand All @@ -51,7 +51,9 @@ def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> Aut
return new_class

@staticmethod
def get_handler(game: str) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]:
def get_handler(game: Optional[str]) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]:
if not game:
return APPatchExtension
handler = AutoPatchExtensionRegister.extension_types.get(game, APPatchExtension)
if handler.required_extensions:
handlers = [handler]
Expand Down Expand Up @@ -191,7 +193,7 @@ class APProcedurePatch(APAutoPatchInterface):
hash: Optional[str] # base checksum of source file
source_data: bytes
patch_file_ending: str = ""
files: Dict[str, bytes] = {}
files: Dict[str, bytes]

@classmethod
def get_source_data(cls) -> bytes:
Expand All @@ -206,6 +208,7 @@ def get_source_data_with_cache(cls) -> bytes:

def __init__(self, *args: Any, **kwargs: Any):
super(APProcedurePatch, self).__init__(*args, **kwargs)
self.files = {}

def get_manifest(self) -> Dict[str, Any]:
manifest = super(APProcedurePatch, self).get_manifest()
Expand Down Expand Up @@ -277,7 +280,7 @@ def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
super(APDeltaPatch, self).__init__(*args, **kwargs)
self.patched_path = patched_path

def write_contents(self, opened_zipfile: zipfile.ZipFile):
def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
self.write_file("delta.bsdiff4",
bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read()))
super(APDeltaPatch, self).write_contents(opened_zipfile)
Expand All @@ -296,21 +299,21 @@ class APTokenMixin:
"""
A class that defines functions for generating a token binary, for use in patches.
"""
tokens: List[
_tokens: Sequence[
Tuple[APTokenTypes, int, Union[
bytes, # WRITE
Tuple[int, int], # COPY, RLE
int # AND_8, OR_8, XOR_8
]]] = []
]]] = ()

def get_token_binary(self) -> bytes:
"""
Returns the token binary created from stored tokens.
:return: A bytes object representing the token data.
"""
data = bytearray()
data.extend(len(self.tokens).to_bytes(4, "little"))
for token_type, offset, args in self.tokens:
data.extend(len(self._tokens).to_bytes(4, "little"))
for token_type, offset, args in self._tokens:
data.append(token_type)
data.extend(offset.to_bytes(4, "little"))
if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
Expand Down Expand Up @@ -351,11 +354,14 @@ def write_token(self,
data: bytes) -> None:
...

def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]):
def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]) -> None:
"""
Stores a token to be used by patching.
"""
self.tokens.append((token_type, offset, data))
if not isinstance(self._tokens, list):
assert len(self._tokens) == 0, f"{type(self)}._tokens was tampered with."
self._tokens = []
self._tokens.append((token_type, offset, data))


class APPatchExtension(metaclass=AutoPatchExtensionRegister):
Expand All @@ -371,10 +377,10 @@ class APPatchExtension(metaclass=AutoPatchExtensionRegister):
Patch extension functions must return the changed bytes.
"""
game: str
required_extensions: List[str] = []
required_extensions: ClassVar[Tuple[str, ...]] = ()

@staticmethod
def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str):
def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str) -> bytes:
"""Applies the given bsdiff4 from the patch onto the current file."""
return bsdiff4.patch(rom, caller.get_file(patch))

Expand Down Expand Up @@ -411,7 +417,7 @@ def apply_tokens(caller: APProcedurePatch, rom: bytes, token_file: str) -> bytes
return bytes(rom_data)

@staticmethod
def calc_snes_crc(caller: APProcedurePatch, rom: bytes):
def calc_snes_crc(caller: APProcedurePatch, rom: bytes) -> bytes:
"""Calculates and applies a valid CRC for the SNES rom header."""
rom_data = bytearray(rom)
if len(rom) < 0x8000:
Expand Down

0 comments on commit f4b7c28

Please sign in to comment.