Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core: change Region caching to on_change from on-miss-strategy #2366

Merged
merged 12 commits into from
Oct 29, 2023
Merged
185 changes: 123 additions & 62 deletions BaseClasses.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import copy
import itertools
import functools
import logging
import random
import secrets
import typing # this can go away when Python 3.8 support is dropped
from argparse import Namespace
from collections import ChainMap, Counter, deque
from collections.abc import Collection
from collections import Counter, deque
from collections.abc import Collection, MutableSequence
from enum import IntEnum, IntFlag
from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Set, Tuple, TypedDict, Union, \
Type, ClassVar
Expand Down Expand Up @@ -47,7 +48,6 @@ def __getattr__(self, name: str) -> Any:
class MultiWorld():
debug_types = False
player_name: Dict[int, str]
_region_cache: Dict[int, Dict[str, Region]]
difficulty_requirements: dict
required_medallions: dict
dark_room_logic: Dict[int, str]
Expand All @@ -57,7 +57,7 @@ class MultiWorld():
plando_connections: List
worlds: Dict[int, auto_world]
groups: Dict[int, Group]
regions: List[Region]
regions: RegionManager
itempool: List[Item]
is_race: bool = False
precollected_items: Dict[int, List[Item]]
Expand Down Expand Up @@ -92,6 +92,34 @@ def __init__(self, rule):
def __getitem__(self, player) -> bool:
return self.rule(player)

class RegionManager:
region_cache: Dict[int, Dict[str, Region]]
entrance_cache: Dict[int, Dict[str, Entrance]]
location_cache: Dict[int, Dict[str, Location]]

def __init__(self, players: int):
self.region_cache = {player: {} for player in range(1, players+1)}
self.entrance_cache = {player: {} for player in range(1, players+1)}
self.location_cache = {player: {} for player in range(1, players+1)}

def __iadd__(self, other: Iterable[Region]):
self.extend(other)
return self

def append(self, region: Region):
self.region_cache[region.player][region.name] = region

def extend(self, regions: Iterable[Region]):
for region in regions:
self.region_cache[region.player][region.name] = region

def __iter__(self) -> Iterator[Region]:
for regions in self.region_cache.values():
yield from regions.values()

def __len__(self):
return sum(len(regions) for regions in self.region_cache.values())

def __init__(self, players: int):
# world-local random state is saved for multiple generations running concurrently
self.random = ThreadBarrierProxy(random.Random())
Expand All @@ -100,16 +128,12 @@ def __init__(self, players: int):
self.glitch_triforce = False
self.algorithm = 'balanced'
self.groups = {}
self.regions = []
self.regions = self.RegionManager(players)
self.shops = []
self.itempool = []
self.seed = None
self.seed_name: str = "Unavailable"
self.precollected_items = {player: [] for player in self.player_ids}
self._cached_entrances = None
self._cached_locations = None
self._entrance_cache = {}
self._location_cache: Dict[Tuple[str, int], Location] = {}
self.required_locations = []
self.light_world_light_cone = False
self.dark_world_light_cone = False
Expand Down Expand Up @@ -137,7 +161,6 @@ def __init__(self, players: int):
def set_player_attr(attr, val):
self.__dict__.setdefault(attr, {})[player] = val

set_player_attr('_region_cache', {})
set_player_attr('shuffle', "vanilla")
set_player_attr('logic', "noglitches")
set_player_attr('mode', 'open')
Expand Down Expand Up @@ -199,7 +222,6 @@ def add_group(self, name: str, game: str, players: Set[int] = frozenset()) -> Tu

self.game[new_id] = game
self.player_types[new_id] = NetUtils.SlotType.group
self._region_cache[new_id] = {}
world_type = AutoWorld.AutoWorldRegister.world_types[game]
self.worlds[new_id] = world_type.create_group(self, new_id, players)
self.worlds[new_id].collect_item = classmethod(AutoWorld.World.collect_item).__get__(self.worlds[new_id])
Expand Down Expand Up @@ -333,41 +355,17 @@ def get_out_file_name_base(self, player: int) -> str:
def world_name_lookup(self):
return {self.player_name[player_id]: player_id for player_id in self.player_ids}

def _recache(self):
"""Rebuild world cache"""
self._cached_locations = None
for region in self.regions:
player = region.player
self._region_cache[player][region.name] = region
for exit in region.exits:
self._entrance_cache[exit.name, player] = exit

for r_location in region.locations:
self._location_cache[r_location.name, player] = r_location

def get_regions(self, player: Optional[int] = None) -> Collection[Region]:
return self.regions if player is None else self._region_cache[player].values()
return self.regions if player is None else self.regions.region_cache[player].values()

def get_region(self, regionname: str, player: int) -> Region:
try:
return self._region_cache[player][regionname]
except KeyError:
self._recache()
return self._region_cache[player][regionname]
def get_region(self, region_name: str, player: int) -> Region:
return self.regions.region_cache[player][region_name]

def get_entrance(self, entrance: str, player: int) -> Entrance:
try:
return self._entrance_cache[entrance, player]
except KeyError:
self._recache()
return self._entrance_cache[entrance, player]
def get_entrance(self, entrance_name: str, player: int) -> Entrance:
return self.regions.entrance_cache[player][entrance_name]

def get_location(self, location: str, player: int) -> Location:
try:
return self._location_cache[location, player]
except KeyError:
self._recache()
return self._location_cache[location, player]
def get_location(self, location_name: str, player: int) -> Location:
return self.regions.location_cache[player][location_name]

def get_all_state(self, use_cache: bool) -> CollectionState:
cached = getattr(self, "_all_state", None)
Expand Down Expand Up @@ -428,28 +426,22 @@ def push_item(self, location: Location, item: Item, collect: bool = True):

logging.debug('Placed %s at %s', item, location)

def get_entrances(self) -> List[Entrance]:
if self._cached_entrances is None:
self._cached_entrances = [entrance for region in self.regions for entrance in region.entrances]
return self._cached_entrances

def clear_entrance_cache(self):
self._cached_entrances = None
def get_entrances(self, player: Optional[int] = None) -> Iterable[Entrance]:
if player is not None:
return self.regions.entrance_cache[player].values()
return Utils.RepeatableChain(tuple(self.regions.entrance_cache[player].values()
for player in self.regions.entrance_cache))

def register_indirect_condition(self, region: Region, entrance: Entrance):
"""Report that access to this Region can result in unlocking this Entrance,
state.can_reach(Region) in the Entrance's traversal condition, as opposed to pure transition logic."""
self.indirect_connections.setdefault(region, set()).add(entrance)

def get_locations(self, player: Optional[int] = None) -> List[Location]:
if self._cached_locations is None:
self._cached_locations = [location for region in self.regions for location in region.locations]
def get_locations(self, player: Optional[int] = None) -> Iterable[Location]:
if player is not None:
return [location for location in self._cached_locations if location.player == player]
return self._cached_locations

def clear_location_cache(self):
self._cached_locations = None
return self.regions.location_cache[player].values()
return Utils.RepeatableChain(tuple(self.regions.location_cache[player].values()
for player in self.regions.location_cache))

def get_unfilled_locations(self, player: Optional[int] = None) -> List[Location]:
return [location for location in self.get_locations(player) if location.item is None]
Expand All @@ -471,16 +463,17 @@ def get_unfilled_locations_for_players(self, location_names: List[str], players:
valid_locations = [location.name for location in self.get_unfilled_locations(player)]
else:
valid_locations = location_names
relevant_cache = self.regions.location_cache[player]
for location_name in valid_locations:
location = self._location_cache.get((location_name, player), None)
if location is not None and location.item is None:
location = relevant_cache.get(location_name, None)
if location and location.item is None:
yield location

def unlocks_new_location(self, item: Item) -> bool:
temp_state = self.state.copy()
temp_state.collect(item, True)

for location in self.get_unfilled_locations():
for location in self.get_unfilled_locations(item.player):
if temp_state.can_reach(location) and not self.state.can_reach(location):
return True

Expand Down Expand Up @@ -820,15 +813,83 @@ class Region:
locations: List[Location]
entrance_type: ClassVar[Type[Entrance]] = Entrance

class Register(MutableSequence):
region_manager: MultiWorld.RegionManager

def __init__(self, region_manager: MultiWorld.RegionManager):
self._list = []
self.region_manager = region_manager

def __getitem__(self, index: int) -> Location:
return self._list.__getitem__(index)

def __setitem__(self, index: int, value: Location) -> None:
raise NotImplementedError()

def __len__(self) -> int:
return self._list.__len__()

# This seems to not be needed, but that's a bit suspicious.
# def __del__(self):
# self.clear()

def copy(self):
return self._list.copy()

class LocationRegister(Register):
def __delitem__(self, index: int) -> None:
location: Location = self._list.__getitem__(index)
self._list.__delitem__(index)
del(self.region_manager.location_cache[location.player][location.name])

def insert(self, index: int, value: Location) -> None:
self._list.insert(index, value)
self.region_manager.location_cache[value.player][value.name] = value

class EntranceRegister(Register):
def __delitem__(self, index: int) -> None:
entrance: Entrance = self._list.__getitem__(index)
self._list.__delitem__(index)
del(self.region_manager.entrance_cache[entrance.player][entrance.name])

def insert(self, index: int, value: Entrance) -> None:
self._list.insert(index, value)
self.region_manager.entrance_cache[value.player][value.name] = value

_locations: LocationRegister[Location]
_exits: EntranceRegister[Entrance]

def __init__(self, name: str, player: int, multiworld: MultiWorld, hint: Optional[str] = None):
self.name = name
self.entrances = []
self.exits = []
self.locations = []
self._exits = self.EntranceRegister(multiworld.regions)
self._locations = self.LocationRegister(multiworld.regions)
self.multiworld = multiworld
self._hint_text = hint
self.player = player

def get_locations(self):
return self._locations

def set_locations(self, new):
if new is self._locations:
return
self._locations.clear()
self._locations.extend(new)

locations = property(get_locations, set_locations)

def get_exits(self):
return self._exits

def set_exits(self, new):
if new is self._exits:
return
self._exits.clear()
self._exits.extend(new)

exits = property(get_exits, set_exits)

def can_reach(self, state: CollectionState) -> bool:
if state.stale[self.player]:
state.update_reachable_regions(self.player)
Expand Down
7 changes: 1 addition & 6 deletions Main.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ def main(args, seed=None, baked_server_options: Optional[Dict[str, object]] = No
logger.info('Creating Items.')
AutoWorld.call_all(world, "create_items")

# All worlds should have finished creating all regions, locations, and entrances.
# Recache to ensure that they are all visible for locality rules.
world._recache()

logger.info('Calculating Access Rules.')

for player in world.player_ids:
Expand Down Expand Up @@ -233,7 +229,7 @@ def find_common_pool(players: Set[int], shared_pool: Set[str]) -> Tuple[

region = Region("Menu", group_id, world, "ItemLink")
world.regions.append(region)
locations = region.locations = []
locations = region.locations
for item in world.itempool:
count = common_item_count.get(item.player, {}).get(item.name, 0)
if count:
Expand Down Expand Up @@ -267,7 +263,6 @@ def find_common_pool(players: Set[int], shared_pool: Set[str]) -> Tuple[
world.itempool.extend(items_to_add[:itemcount - len(world.itempool)])

if any(world.item_links.values()):
world._recache()
world._all_state = None

logger.info("Running Item Plando")
Expand Down
15 changes: 15 additions & 0 deletions Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing
import builtins
import os
import itertools
import subprocess
import sys
import pickle
Expand Down Expand Up @@ -905,3 +906,17 @@ def visualize_other_regions() -> None:

with open(file_name, "wt", encoding="utf-8") as f:
f.write("\n".join(uml))


class RepeatableChain:
def __init__(self, iterable: typing.Iterable):
self.iterable = iterable

def __iter__(self):
return itertools.chain.from_iterable(self.iterable)

def __bool__(self):
return any(sub_iterable for sub_iterable in self.iterable)

def __len__(self):
return sum(len(iterable) for iterable in self.iterable)
2 changes: 1 addition & 1 deletion test/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def test_fill(self):

# basically a shortened reimplementation of this method from core, in order to force the check is done
def fulfills_accessibility() -> bool:
locations = self.multiworld.get_locations(1).copy()
locations = list(self.multiworld.get_locations(1))
state = CollectionState(self.multiworld)
while locations:
sphere: typing.List[Location] = []
Expand Down
3 changes: 0 additions & 3 deletions test/general/test_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def test_location_creation_steps(self):
for game_name, world_type in AutoWorldRegister.world_types.items():
with self.subTest("Game", game_name=game_name):
multiworld = setup_solo_multiworld(world_type, gen_steps)
multiworld._recache()
region_count = len(multiworld.get_regions())
location_count = len(multiworld.get_locations())

Expand All @@ -46,14 +45,12 @@ def test_location_creation_steps(self):
self.assertEqual(location_count, len(multiworld.get_locations()),
f"{game_name} modified locations count during rule creation")

multiworld._recache()
call_all(multiworld, "generate_basic")
self.assertEqual(region_count, len(multiworld.get_regions()),
f"{game_name} modified region count during generate_basic")
self.assertGreaterEqual(location_count, len(multiworld.get_locations()),
f"{game_name} modified locations count during generate_basic")

multiworld._recache()
call_all(multiworld, "pre_fill")
self.assertEqual(region_count, len(multiworld.get_regions()),
f"{game_name} modified region count during pre_fill")
Expand Down
1 change: 0 additions & 1 deletion worlds/alttp/ItemPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def generate_itempool(world):
loc.access_rule = lambda state: has_triforce_pieces(state, player)

region.locations.append(loc)
multiworld.clear_location_cache()

multiworld.push_item(loc, ItemFactory('Triforce', player), False)
loc.event = True
Expand Down
Loading