diff --git a/Starcraft2Client.py b/Starcraft2Client.py index cf16405766b4..cdcdb39a0b44 100644 --- a/Starcraft2Client.py +++ b/Starcraft2Client.py @@ -25,11 +25,10 @@ sc2_logger = logging.getLogger("Starcraft2") import nest_asyncio -import sc2 -from sc2.bot_ai import BotAI -from sc2.data import Race -from sc2.main import run_game -from sc2.player import Bot +from worlds._sc2common import bot +from worlds._sc2common.bot.data import Race +from worlds._sc2common.bot.main import run_game +from worlds._sc2common.bot.player import Bot from worlds.sc2wol import SC2WoLWorld from worlds.sc2wol.Items import lookup_id_to_name, item_table, ItemData, type_flaggroups from worlds.sc2wol.Locations import SC2WOL_LOC_ID_OFFSET @@ -240,8 +239,6 @@ def run_gui(self): from kivy.uix.floatlayout import FloatLayout from kivy.properties import StringProperty - import Utils - class HoverableButton(HoverBehavior, Button): pass @@ -544,11 +541,11 @@ async def starcraft_launch(ctx: SC2Context, mission_id: int): sc2_logger.info(f"Launching {lookup_id_to_mission[mission_id]}. If game does not launch check log file for errors.") with DllDirectory(None): - run_game(sc2.maps.get(maps_table[mission_id - 1]), [Bot(Race.Terran, ArchipelagoBot(ctx, mission_id), + run_game(bot.maps.get(maps_table[mission_id - 1]), [Bot(Race.Terran, ArchipelagoBot(ctx, mission_id), name="Archipelago", fullscreen=True)], realtime=True) -class ArchipelagoBot(sc2.bot_ai.BotAI): +class ArchipelagoBot(bot.bot_ai.BotAI): game_running: bool = False mission_completed: bool = False boni: typing.List[bool] @@ -867,7 +864,7 @@ def check_game_install_path() -> bool: documentspath = buf.value einfo = str(documentspath / Path("StarCraft II\\ExecuteInfo.txt")) else: - einfo = str(sc2.paths.get_home() / Path(sc2.paths.USERPATH[sc2.paths.PF])) + einfo = str(bot.paths.get_home() / Path(bot.paths.USERPATH[bot.paths.PF])) # Check if the file exists. if os.path.isfile(einfo): @@ -883,7 +880,7 @@ def check_game_install_path() -> bool: f"try again.") return False if os.path.exists(base): - executable = sc2.paths.latest_executeble(Path(base).expanduser() / "Versions") + executable = bot.paths.latest_executeble(Path(base).expanduser() / "Versions") # Finally, check the path for an actual executable. # If we find one, great. Set up the SC2PATH. diff --git a/worlds/_sc2common/__init__.py b/worlds/_sc2common/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/worlds/_sc2common/bot/LICENSE b/worlds/_sc2common/bot/LICENSE new file mode 100644 index 000000000000..76b89f323537 --- /dev/null +++ b/worlds/_sc2common/bot/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Hannes Karppila + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/worlds/_sc2common/bot/README.md b/worlds/_sc2common/bot/README.md new file mode 100644 index 000000000000..d2b74a7ebbf0 --- /dev/null +++ b/worlds/_sc2common/bot/README.md @@ -0,0 +1,6 @@ +# SC2 Bot +This is client library to communicate with Starcraft 2 game +It's based on `burnysc2` python package, see https://github.com/BurnySc2/python-sc2 + +The base package is stripped down to clean up unneeded features and those not working outside a +melee game. diff --git a/worlds/_sc2common/bot/__init__.py b/worlds/_sc2common/bot/__init__.py new file mode 100644 index 000000000000..be5c5e104561 --- /dev/null +++ b/worlds/_sc2common/bot/__init__.py @@ -0,0 +1,16 @@ +from pathlib import Path +from loguru import logger + + +def is_submodule(path): + if path.is_file(): + return path.suffix == ".py" and path.stem != "__init__" + if path.is_dir(): + return (path / "__init__.py").exists() + return False + + +__all__ = [p.stem for p in Path(__file__).parent.iterdir() if is_submodule(p)] + + +logger = logger diff --git a/worlds/_sc2common/bot/bot_ai.py b/worlds/_sc2common/bot/bot_ai.py new file mode 100644 index 000000000000..79c11a5ad4a3 --- /dev/null +++ b/worlds/_sc2common/bot/bot_ai.py @@ -0,0 +1,476 @@ +# pylint: disable=W0212,R0916,R0904 +from __future__ import annotations + +import math +from functools import cached_property +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union + +from .bot_ai_internal import BotAIInternal +from .cache import property_cache_once_per_frame +from .data import Alert, Result +from .position import Point2 +from .unit import Unit +from .units import Units + +if TYPE_CHECKING: + from .game_info import Ramp + + +class BotAI(BotAIInternal): + """Base class for bots.""" + + EXPANSION_GAP_THRESHOLD = 15 + + @property + def time(self) -> float: + """ Returns time in seconds, assumes the game is played on 'faster' """ + return self.state.game_loop / 22.4 # / (1/1.4) * (1/16) + + @property + def time_formatted(self) -> str: + """ Returns time as string in min:sec format """ + t = self.time + return f"{int(t // 60):02}:{int(t % 60):02}" + + @property + def step_time(self) -> Tuple[float, float, float, float]: + """Returns a tuple of step duration in milliseconds. + First value is the minimum step duration - the shortest the bot ever took + Second value is the average step duration + Third value is the maximum step duration - the longest the bot ever took (including on_start()) + Fourth value is the step duration the bot took last iteration + If called in the first iteration, it returns (inf, 0, 0, 0)""" + avg_step_duration = ( + (self._total_time_in_on_step / self._total_steps_iterations) if self._total_steps_iterations else 0 + ) + return ( + self._min_step_time * 1000, + avg_step_duration * 1000, + self._max_step_time * 1000, + self._last_step_step_time * 1000, + ) + + def alert(self, alert_code: Alert) -> bool: + """ + Check if alert is triggered in the current step. + Possible alerts are listed here https://github.com/Blizzard/s2client-proto/blob/e38efed74c03bec90f74b330ea1adda9215e655f/s2clientprotocol/sc2api.proto#L679-L702 + + Example use:: + + from sc2.data import Alert + if self.alert(Alert.AddOnComplete): + print("Addon Complete") + + Alert codes:: + + AlertError + AddOnComplete + BuildingComplete + BuildingUnderAttack + LarvaHatched + MergeComplete + MineralsExhausted + MorphComplete + MothershipComplete + MULEExpired + NuclearLaunchDetected + NukeComplete + NydusWormDetected + ResearchComplete + TrainError + TrainUnitComplete + TrainWorkerComplete + TransformationComplete + UnitUnderAttack + UpgradeComplete + VespeneExhausted + WarpInComplete + + :param alert_code: + """ + assert isinstance(alert_code, Alert), f"alert_code {alert_code} is no Alert" + return alert_code.value in self.state.alerts + + @property + def start_location(self) -> Point2: + """ + Returns the spawn location of the bot, using the position of the first created townhall. + This will be None if the bot is run on an arcade or custom map that does not feature townhalls at game start. + """ + return self.game_info.player_start_location + + @property + def enemy_start_locations(self) -> List[Point2]: + """Possible start locations for enemies.""" + return self.game_info.start_locations + + @cached_property + def main_base_ramp(self) -> Ramp: + """Returns the Ramp instance of the closest main-ramp to start location. + Look in game_info.py for more information about the Ramp class + + Example: See terran ramp wall bot + """ + # The reason for len(ramp.upper) in {2, 5} is: + # ParaSite map has 5 upper points, and most other maps have 2 upper points at the main ramp. + # The map Acolyte has 4 upper points at the wrong ramp (which is closest to the start position). + try: + found_main_base_ramp = min( + (ramp for ramp in self.game_info.map_ramps if len(ramp.upper) in {2, 5}), + key=lambda r: self.start_location.distance_to(r.top_center), + ) + except ValueError: + # Hardcoded hotfix for Honorgrounds LE map, as that map has a large main base ramp with inbase natural + found_main_base_ramp = min( + (ramp for ramp in self.game_info.map_ramps if len(ramp.upper) in {4, 9}), + key=lambda r: self.start_location.distance_to(r.top_center), + ) + return found_main_base_ramp + + @property_cache_once_per_frame + def expansion_locations_list(self) -> List[Point2]: + """ Returns a list of expansion positions, not sorted in any way. """ + assert ( + self._expansion_positions_list + ), "self._find_expansion_locations() has not been run yet, so accessing the list of expansion locations is pointless." + return self._expansion_positions_list + + @property_cache_once_per_frame + def expansion_locations_dict(self) -> Dict[Point2, Units]: + """ + Returns dict with the correct expansion position Point2 object as key, + resources as Units (mineral fields and vespene geysers) as value. + + Caution: This function is slow. If you only need the expansion locations, use the property above. + """ + assert ( + self._expansion_positions_list + ), "self._find_expansion_locations() has not been run yet, so accessing the list of expansion locations is pointless." + expansion_locations: Dict[Point2, Units] = {pos: Units([], self) for pos in self._expansion_positions_list} + for resource in self.resources: + # It may be that some resources are not mapped to an expansion location + exp_position: Point2 = self._resource_location_to_expansion_position_dict.get(resource.position, None) + if exp_position: + assert exp_position in expansion_locations + expansion_locations[exp_position].append(resource) + return expansion_locations + + async def get_next_expansion(self) -> Optional[Point2]: + """Find next expansion location.""" + + closest = None + distance = math.inf + for el in self.expansion_locations_list: + + def is_near_to_expansion(t): + return t.distance_to(el) < self.EXPANSION_GAP_THRESHOLD + + if any(map(is_near_to_expansion, self.townhalls)): + # already taken + continue + + startp = self.game_info.player_start_location + d = await self.client.query_pathing(startp, el) + if d is None: + continue + + if d < distance: + distance = d + closest = el + + return closest + + # pylint: disable=R0912 + async def distribute_workers(self, resource_ratio: float = 2): + """ + Distributes workers across all the bases taken. + Keyword `resource_ratio` takes a float. If the current minerals to gas + ratio is bigger than `resource_ratio`, this function prefer filling gas_buildings + first, if it is lower, it will prefer sending workers to minerals first. + + NOTE: This function is far from optimal, if you really want to have + refined worker control, you should write your own distribution function. + For example long distance mining control and moving workers if a base was killed + are not being handled. + + WARNING: This is quite slow when there are lots of workers or multiple bases. + + :param resource_ratio:""" + if not self.mineral_field or not self.workers or not self.townhalls.ready: + return + worker_pool = self.workers.idle + bases = self.townhalls.ready + gas_buildings = self.gas_buildings.ready + + # list of places that need more workers + deficit_mining_places = [] + + for mining_place in bases | gas_buildings: + difference = mining_place.surplus_harvesters + # perfect amount of workers, skip mining place + if not difference: + continue + if mining_place.has_vespene: + # get all workers that target the gas extraction site + # or are on their way back from it + local_workers = self.workers.filter( + lambda unit: unit.order_target == mining_place.tag or + (unit.is_carrying_vespene and unit.order_target == bases.closest_to(mining_place).tag) + ) + else: + # get tags of minerals around expansion + local_minerals_tags = { + mineral.tag + for mineral in self.mineral_field if mineral.distance_to(mining_place) <= 8 + } + # get all target tags a worker can have + # tags of the minerals he could mine at that base + # get workers that work at that gather site + local_workers = self.workers.filter( + lambda unit: unit.order_target in local_minerals_tags or + (unit.is_carrying_minerals and unit.order_target == mining_place.tag) + ) + # too many workers + if difference > 0: + for worker in local_workers[:difference]: + worker_pool.append(worker) + # too few workers + # add mining place to deficit bases for every missing worker + else: + deficit_mining_places += [mining_place for _ in range(-difference)] + + # prepare all minerals near a base if we have too many workers + # and need to send them to the closest patch + if len(worker_pool) > len(deficit_mining_places): + all_minerals_near_base = [ + mineral for mineral in self.mineral_field + if any(mineral.distance_to(base) <= 8 for base in self.townhalls.ready) + ] + # distribute every worker in the pool + for worker in worker_pool: + # as long as have workers and mining places + if deficit_mining_places: + # choose only mineral fields first if current mineral to gas ratio is less than target ratio + if self.vespene and self.minerals / self.vespene < resource_ratio: + possible_mining_places = [place for place in deficit_mining_places if not place.vespene_contents] + # else prefer gas + else: + possible_mining_places = [place for place in deficit_mining_places if place.vespene_contents] + # if preferred type is not available any more, get all other places + if not possible_mining_places: + possible_mining_places = deficit_mining_places + # find closest mining place + current_place = min(deficit_mining_places, key=lambda place: place.distance_to(worker)) + # remove it from the list + deficit_mining_places.remove(current_place) + # if current place is a gas extraction site, go there + if current_place.vespene_contents: + worker.gather(current_place) + # if current place is a gas extraction site, + # go to the mineral field that is near and has the most minerals left + else: + local_minerals = ( + mineral for mineral in self.mineral_field if mineral.distance_to(current_place) <= 8 + ) + # local_minerals can be empty if townhall is misplaced + target_mineral = max(local_minerals, key=lambda mineral: mineral.mineral_contents, default=None) + if target_mineral: + worker.gather(target_mineral) + # more workers to distribute than free mining spots + # send to closest if worker is doing nothing + elif worker.is_idle and all_minerals_near_base: + target_mineral = min(all_minerals_near_base, key=lambda mineral: mineral.distance_to(worker)) + worker.gather(target_mineral) + else: + # there are no deficit mining places and worker is not idle + # so dont move him + pass + + @property_cache_once_per_frame + def owned_expansions(self) -> Dict[Point2, Unit]: + """Dict of expansions owned by the player with mapping {expansion_location: townhall_structure}.""" + owned = {} + for el in self.expansion_locations_list: + + def is_near_to_expansion(t): + return t.distance_to(el) < self.EXPANSION_GAP_THRESHOLD + + th = next((x for x in self.townhalls if is_near_to_expansion(x)), None) + if th: + owned[el] = th + return owned + + async def chat_send(self, message: str, team_only: bool = False): + """Send a chat message to the SC2 Client. + + Example:: + + await self.chat_send("Hello, this is a message from my bot!") + + :param message: + :param team_only:""" + assert isinstance(message, str), f"{message} is not a string" + await self.client.chat_send(message, team_only) + + def in_map_bounds(self, pos: Union[Point2, tuple, list]) -> bool: + """Tests if a 2 dimensional point is within the map boundaries of the pixelmaps. + + :param pos:""" + return ( + self.game_info.playable_area.x <= pos[0] < + self.game_info.playable_area.x + self.game_info.playable_area.width and self.game_info.playable_area.y <= + pos[1] < self.game_info.playable_area.y + self.game_info.playable_area.height + ) + + # For the functions below, make sure you are inside the boundaries of the map size. + def get_terrain_height(self, pos: Union[Point2, Unit]) -> int: + """Returns terrain height at a position. + Caution: terrain height is different from a unit's z-coordinate. + + :param pos:""" + assert isinstance(pos, (Point2, Unit)), "pos is not of type Point2 or Unit" + pos = pos.position.rounded + return self.game_info.terrain_height[pos] + + def get_terrain_z_height(self, pos: Union[Point2, Unit]) -> float: + """Returns terrain z-height at a position. + + :param pos:""" + assert isinstance(pos, (Point2, Unit)), "pos is not of type Point2 or Unit" + pos = pos.position.rounded + return -16 + 32 * self.game_info.terrain_height[pos] / 255 + + def in_placement_grid(self, pos: Union[Point2, Unit]) -> bool: + """Returns True if you can place something at a position. + Remember, buildings usually use 2x2, 3x3 or 5x5 of these grid points. + Caution: some x and y offset might be required, see ramp code in game_info.py + + :param pos:""" + assert isinstance(pos, (Point2, Unit)), "pos is not of type Point2 or Unit" + pos = pos.position.rounded + return self.game_info.placement_grid[pos] == 1 + + def in_pathing_grid(self, pos: Union[Point2, Unit]) -> bool: + """Returns True if a ground unit can pass through a grid point. + + :param pos:""" + assert isinstance(pos, (Point2, Unit)), "pos is not of type Point2 or Unit" + pos = pos.position.rounded + return self.game_info.pathing_grid[pos] == 1 + + def is_visible(self, pos: Union[Point2, Unit]) -> bool: + """Returns True if you have vision on a grid point. + + :param pos:""" + # more info: https://github.com/Blizzard/s2client-proto/blob/9906df71d6909511907d8419b33acc1a3bd51ec0/s2clientprotocol/spatial.proto#L19 + assert isinstance(pos, (Point2, Unit)), "pos is not of type Point2 or Unit" + pos = pos.position.rounded + return self.state.visibility[pos] == 2 + + def has_creep(self, pos: Union[Point2, Unit]) -> bool: + """Returns True if there is creep on the grid point. + + :param pos:""" + assert isinstance(pos, (Point2, Unit)), "pos is not of type Point2 or Unit" + pos = pos.position.rounded + return self.state.creep[pos] == 1 + + async def on_unit_destroyed(self, unit_tag: int): + """ + Override this in your bot class. + Note that this function uses unit tags and not the unit objects + because the unit does not exist any more. + This will event will be called when a unit (or structure, friendly or enemy) dies. + For enemy units, this only works if the enemy unit was in vision on death. + + :param unit_tag: + """ + + async def on_unit_created(self, unit: Unit): + """Override this in your bot class. This function is called when a unit is created. + + :param unit:""" + + async def on_building_construction_started(self, unit: Unit): + """ + Override this in your bot class. + This function is called when a building construction has started. + + :param unit: + """ + + async def on_building_construction_complete(self, unit: Unit): + """ + Override this in your bot class. This function is called when a building + construction is completed. + + :param unit: + """ + + async def on_unit_took_damage(self, unit: Unit, amount_damage_taken: float): + """ + Override this in your bot class. This function is called when your own unit (unit or structure) took damage. + It will not be called if the unit died this frame. + + This may be called frequently for terran structures that are burning down, or zerg buildings that are off creep, + or terran bio units that just used stimpack ability. + TODO: If there is a demand for it, then I can add a similar event for when enemy units took damage + + Examples:: + + print(f"My unit took damage: {unit} took {amount_damage_taken} damage") + + :param unit: + :param amount_damage_taken: + """ + + async def on_enemy_unit_entered_vision(self, unit: Unit): + """ + Override this in your bot class. This function is called when an enemy unit (unit or structure) entered vision (which was not visible last frame). + + :param unit: + """ + + async def on_enemy_unit_left_vision(self, unit_tag: int): + """ + Override this in your bot class. This function is called when an enemy unit (unit or structure) left vision (which was visible last frame). + Same as the self.on_unit_destroyed event, this function is called with the unit's tag because the unit is no longer visible anymore. + If you want to store a snapshot of the unit, use self._enemy_units_previous_map[unit_tag] for units or self._enemy_structures_previous_map[unit_tag] for structures. + + Examples:: + + last_known_unit = self._enemy_units_previous_map.get(unit_tag, None) or self._enemy_structures_previous_map[unit_tag] + print(f"Enemy unit left vision, last known location: {last_known_unit.position}") + + :param unit_tag: + """ + + async def on_before_start(self): + """ + Override this in your bot class. This function is called before "on_start" + and before "prepare_first_step" that calculates expansion locations. + Not all data is available yet. + This function is useful in realtime=True mode to split your workers or start producing the first worker. + """ + + async def on_start(self): + """ + Override this in your bot class. + At this point, game_data, game_info and the first iteration of game_state (self.state) are available. + """ + + async def on_step(self, iteration: int): + """ + You need to implement this function! + Override this in your bot class. + This function is called on every game step (looped in realtime mode). + + :param iteration: + """ + raise NotImplementedError + + async def on_end(self, game_result: Result): + """Override this in your bot class. This function is called at the end of a game. + Unsure if this function will be called on the laddermanager client as the bot process may forcefully be terminated. + + :param game_result:""" diff --git a/worlds/_sc2common/bot/bot_ai_internal.py b/worlds/_sc2common/bot/bot_ai_internal.py new file mode 100644 index 000000000000..583c491dc2f6 --- /dev/null +++ b/worlds/_sc2common/bot/bot_ai_internal.py @@ -0,0 +1,490 @@ +# pylint: disable=W0201,W0212,R0912 +from __future__ import annotations + +import math +import time +import warnings +from abc import ABC +from collections import Counter +from typing import TYPE_CHECKING, Any +from typing import Dict, Generator, Iterable, List, Set, Tuple, Union, final + +from s2clientprotocol import sc2api_pb2 as sc_pb + +from .constants import ( + IS_PLACEHOLDER, +) +from .data import Race +from .game_data import GameData +from .game_state import Blip, GameState +from .pixel_map import PixelMap +from .position import Point2 +from .unit import Unit +from .units import Units + +# with warnings.catch_warnings(): +# warnings.simplefilter("ignore") +# from scipy.spatial.distance import cdist, pdist + +if TYPE_CHECKING: + from .client import Client + from .game_info import GameInfo + + +class BotAIInternal(ABC): + """Base class for bots.""" + + @final + def _initialize_variables(self): + """ Called from main.py internally """ + self.cache: Dict[str, Any] = {} + # Specific opponent bot ID used in sc2ai ladder games http://sc2ai.net/ and on ai arena https://aiarena.net + # The bot ID will stay the same each game so your bot can "adapt" to the opponent + if not hasattr(self, "opponent_id"): + # Prevent overwriting the opponent_id which is set here https://github.com/Hannessa/python-sc2-ladderbot/blob/master/__init__.py#L40 + # otherwise set it to None + self.opponent_id: str = None + # Select distance calculation method, see _distances_override_functions function + if not hasattr(self, "distance_calculation_method"): + self.distance_calculation_method: int = 2 + # Select if the Unit.command should return UnitCommand objects. Set this to True if your bot uses 'self.do(unit(ability, target))' + if not hasattr(self, "unit_command_uses_self_do"): + self.unit_command_uses_self_do: bool = False + # This value will be set to True by main.py in self._prepare_start if game is played in realtime (if true, the bot will have limited time per step) + self.realtime: bool = False + self.base_build: int = -1 + self.all_units: Units = Units([], self) + self.units: Units = Units([], self) + self.workers: Units = Units([], self) + self.larva: Units = Units([], self) + self.structures: Units = Units([], self) + self.townhalls: Units = Units([], self) + self.gas_buildings: Units = Units([], self) + self.all_own_units: Units = Units([], self) + self.enemy_units: Units = Units([], self) + self.enemy_structures: Units = Units([], self) + self.all_enemy_units: Units = Units([], self) + self.resources: Units = Units([], self) + self.destructables: Units = Units([], self) + self.watchtowers: Units = Units([], self) + self.mineral_field: Units = Units([], self) + self.vespene_geyser: Units = Units([], self) + self.placeholders: Units = Units([], self) + self.techlab_tags: Set[int] = set() + self.reactor_tags: Set[int] = set() + self.minerals: int = 50 + self.vespene: int = 0 + self.supply_army: float = 0 + self.supply_workers: float = 12 # Doesn't include workers in production + self.supply_cap: float = 15 + self.supply_used: float = 12 + self.supply_left: float = 3 + self.idle_worker_count: int = 0 + self.army_count: int = 0 + self.warp_gate_count: int = 0 + self.blips: Set[Blip] = set() + self.race: Race = None + self.enemy_race: Race = None + self._generated_frame = -100 + self._units_created: Counter = Counter() + self._unit_tags_seen_this_game: Set[int] = set() + self._units_previous_map: Dict[int, Unit] = {} + self._structures_previous_map: Dict[int, Unit] = {} + self._enemy_units_previous_map: Dict[int, Unit] = {} + self._enemy_structures_previous_map: Dict[int, Unit] = {} + self._all_units_previous_map: Dict[int, Unit] = {} + self._expansion_positions_list: List[Point2] = [] + self._resource_location_to_expansion_position_dict: Dict[Point2, Point2] = {} + self._time_before_step: float = None + self._time_after_step: float = None + self._min_step_time: float = math.inf + self._max_step_time: float = 0 + self._last_step_step_time: float = 0 + self._total_time_in_on_step: float = 0 + self._total_steps_iterations: int = 0 + # Internally used to keep track which units received an action in this frame, so that self.train() function does not give the same larva two orders - cleared every frame + self.unit_tags_received_action: Set[int] = set() + + @final + @property + def _game_info(self) -> GameInfo: + """ See game_info.py """ + warnings.warn( + "Using self._game_info is deprecated and may be removed soon. Please use self.game_info directly.", + DeprecationWarning, + stacklevel=2, + ) + return self.game_info + + @final + @property + def _game_data(self) -> GameData: + """ See game_data.py """ + warnings.warn( + "Using self._game_data is deprecated and may be removed soon. Please use self.game_data directly.", + DeprecationWarning, + stacklevel=2, + ) + return self.game_data + + @final + @property + def _client(self) -> Client: + """ See client.py """ + warnings.warn( + "Using self._client is deprecated and may be removed soon. Please use self.client directly.", + DeprecationWarning, + stacklevel=2, + ) + return self.client + + @final + def _prepare_start(self, client, player_id, game_info, game_data, realtime: bool = False, base_build: int = -1): + """ + Ran until game start to set game and player data. + + :param client: + :param player_id: + :param game_info: + :param game_data: + :param realtime: + """ + self.client: Client = client + self.player_id: int = player_id + self.game_info: GameInfo = game_info + self.game_data: GameData = game_data + self.realtime: bool = realtime + self.base_build: int = base_build + + self.race: Race = Race(self.game_info.player_races[self.player_id]) + + if len(self.game_info.player_races) == 2: + self.enemy_race: Race = Race(self.game_info.player_races[3 - self.player_id]) + + + @final + def _prepare_first_step(self): + """First step extra preparations. Must not be called before _prepare_step.""" + if self.townhalls: + self.game_info.player_start_location = self.townhalls.first.position + # Calculate and cache expansion locations forever inside 'self._cache_expansion_locations', this is done to prevent a bug when this is run and cached later in the game + self._time_before_step: float = time.perf_counter() + + @final + def _prepare_step(self, state, proto_game_info): + """ + :param state: + :param proto_game_info: + """ + # Set attributes from new state before on_step.""" + self.state: GameState = state # See game_state.py + # update pathing grid, which unfortunately is in GameInfo instead of GameState + self.game_info.pathing_grid = PixelMap(proto_game_info.game_info.start_raw.pathing_grid, in_bits=True) + # Required for events, needs to be before self.units are initialized so the old units are stored + self._units_previous_map: Dict[int, Unit] = {unit.tag: unit for unit in self.units} + self._structures_previous_map: Dict[int, Unit] = {structure.tag: structure for structure in self.structures} + self._enemy_units_previous_map: Dict[int, Unit] = {unit.tag: unit for unit in self.enemy_units} + self._enemy_structures_previous_map: Dict[int, Unit] = { + structure.tag: structure + for structure in self.enemy_structures + } + self._all_units_previous_map: Dict[int, Unit] = {unit.tag: unit for unit in self.all_units} + + self._prepare_units() + self.minerals: int = state.common.minerals + self.vespene: int = state.common.vespene + self.supply_army: int = state.common.food_army + self.supply_workers: int = state.common.food_workers # Doesn't include workers in production + self.supply_cap: int = state.common.food_cap + self.supply_used: int = state.common.food_used + self.supply_left: int = self.supply_cap - self.supply_used + + if self.race == Race.Zerg: + # Workaround Zerg supply rounding bug + pass + # self._correct_zerg_supply() + elif self.race == Race.Protoss: + self.warp_gate_count: int = state.common.warp_gate_count + + self.idle_worker_count: int = state.common.idle_worker_count + self.army_count: int = state.common.army_count + self._time_before_step: float = time.perf_counter() + + if self.enemy_race == Race.Random and self.all_enemy_units: + self.enemy_race = Race(self.all_enemy_units.first.race) + + @final + def _prepare_units(self): + # Set of enemy units detected by own sensor tower, as blips have less unit information than normal visible units + self.blips: Set[Blip] = set() + self.all_units: Units = Units([], self) + self.units: Units = Units([], self) + self.workers: Units = Units([], self) + self.larva: Units = Units([], self) + self.structures: Units = Units([], self) + self.townhalls: Units = Units([], self) + self.gas_buildings: Units = Units([], self) + self.all_own_units: Units = Units([], self) + self.enemy_units: Units = Units([], self) + self.enemy_structures: Units = Units([], self) + self.all_enemy_units: Units = Units([], self) + self.resources: Units = Units([], self) + self.destructables: Units = Units([], self) + self.watchtowers: Units = Units([], self) + self.mineral_field: Units = Units([], self) + self.vespene_geyser: Units = Units([], self) + self.placeholders: Units = Units([], self) + self.techlab_tags: Set[int] = set() + self.reactor_tags: Set[int] = set() + + index: int = 0 + for unit in self.state.observation_raw.units: + if unit.is_blip: + self.blips.add(Blip(unit)) + else: + unit_type: int = unit.unit_type + # Convert these units to effects: reaper grenade, parasitic bomb dummy, forcefield + unit_obj = Unit(unit, self, distance_calculation_index=index, base_build=self.base_build) + index += 1 + self.all_units.append(unit_obj) + if unit.display_type == IS_PLACEHOLDER: + self.placeholders.append(unit_obj) + continue + alliance = unit.alliance + # Alliance.Neutral.value = 3 + if alliance == 3: + # XELNAGATOWER = 149 + if unit_type == 149: + self.watchtowers.append(unit_obj) + # all destructable rocks + else: + self.destructables.append(unit_obj) + # Alliance.Self.value = 1 + elif alliance == 1: + self.all_own_units.append(unit_obj) + if unit_obj.is_structure: + self.structures.append(unit_obj) + # Alliance.Enemy.value = 4 + elif alliance == 4: + self.all_enemy_units.append(unit_obj) + if unit_obj.is_structure: + self.enemy_structures.append(unit_obj) + else: + self.enemy_units.append(unit_obj) + + @final + async def _after_step(self) -> int: + """ Executed by main.py after each on_step function. """ + # Keep track of the bot on_step duration + self._time_after_step: float = time.perf_counter() + step_duration = self._time_after_step - self._time_before_step + self._min_step_time = min(step_duration, self._min_step_time) + self._max_step_time = max(step_duration, self._max_step_time) + self._last_step_step_time = step_duration + self._total_time_in_on_step += step_duration + self._total_steps_iterations += 1 + # Clear set of unit tags that were given an order this frame by self.do() + self.unit_tags_received_action.clear() + # Commit debug queries + await self.client._send_debug() + + return self.state.game_loop + + @final + async def _advance_steps(self, steps: int): + """Advances the game loop by amount of 'steps'. This function is meant to be used as a debugging and testing tool only. + If you are using this, please be aware of the consequences, e.g. 'self.units' will be filled with completely new data.""" + await self._after_step() + # Advance simulation by exactly "steps" frames + await self.client.step(steps) + state = await self.client.observation() + gs = GameState(state.observation) + proto_game_info = await self.client._execute(game_info=sc_pb.RequestGameInfo()) + self._prepare_step(gs, proto_game_info) + await self.issue_events() + + @final + async def issue_events(self): + """This function will be automatically run from main.py and triggers the following functions: + - on_unit_created + - on_unit_destroyed + - on_building_construction_started + - on_building_construction_complete + - on_upgrade_complete + """ + await self._issue_unit_dead_events() + await self._issue_unit_added_events() + await self._issue_building_events() + await self._issue_upgrade_events() + await self._issue_vision_events() + + @final + async def _issue_unit_added_events(self): + pass + # for unit in self.units: + # if unit.tag not in self._units_previous_map and unit.tag not in self._unit_tags_seen_this_game: + # self._unit_tags_seen_this_game.add(unit.tag) + # self._units_created[unit.type_id] += 1 + # await self.on_unit_created(unit) + # elif unit.tag in self._units_previous_map: + # previous_frame_unit: Unit = self._units_previous_map[unit.tag] + # # Check if a unit took damage this frame and then trigger event + # if unit.health < previous_frame_unit.health or unit.shield < previous_frame_unit.shield: + # damage_amount = previous_frame_unit.health - unit.health + previous_frame_unit.shield - unit.shield + # await self.on_unit_took_damage(unit, damage_amount) + # # Check if a unit type has changed + # if previous_frame_unit.type_id != unit.type_id: + # await self.on_unit_type_changed(unit, previous_frame_unit.type_id) + + @final + async def _issue_upgrade_events(self): + pass + # difference = self.state.upgrades - self._previous_upgrades + # for upgrade_completed in difference: + # await self.on_upgrade_complete(upgrade_completed) + # self._previous_upgrades = self.state.upgrades + + @final + async def _issue_building_events(self): + pass + # for structure in self.structures: + # if structure.tag not in self._structures_previous_map: + # if structure.build_progress < 1: + # await self.on_building_construction_started(structure) + # else: + # # Include starting townhall + # self._units_created[structure.type_id] += 1 + # await self.on_building_construction_complete(structure) + # elif structure.tag in self._structures_previous_map: + # # Check if a structure took damage this frame and then trigger event + # previous_frame_structure: Unit = self._structures_previous_map[structure.tag] + # if ( + # structure.health < previous_frame_structure.health + # or structure.shield < previous_frame_structure.shield + # ): + # damage_amount = ( + # previous_frame_structure.health - structure.health + previous_frame_structure.shield - + # structure.shield + # ) + # await self.on_unit_took_damage(structure, damage_amount) + # # Check if a structure changed its type + # if previous_frame_structure.type_id != structure.type_id: + # await self.on_unit_type_changed(structure, previous_frame_structure.type_id) + # # Check if structure completed + # if structure.build_progress == 1 and previous_frame_structure.build_progress < 1: + # self._units_created[structure.type_id] += 1 + # await self.on_building_construction_complete(structure) + + @final + async def _issue_vision_events(self): + pass + # # Call events for enemy unit entered vision + # for enemy_unit in self.enemy_units: + # if enemy_unit.tag not in self._enemy_units_previous_map: + # await self.on_enemy_unit_entered_vision(enemy_unit) + # for enemy_structure in self.enemy_structures: + # if enemy_structure.tag not in self._enemy_structures_previous_map: + # await self.on_enemy_unit_entered_vision(enemy_structure) + + # # Call events for enemy unit left vision + # enemy_units_left_vision: Set[int] = set(self._enemy_units_previous_map) - self.enemy_units.tags + # for enemy_unit_tag in enemy_units_left_vision: + # await self.on_enemy_unit_left_vision(enemy_unit_tag) + # enemy_structures_left_vision: Set[int] = (set(self._enemy_structures_previous_map) - self.enemy_structures.tags) + # for enemy_structure_tag in enemy_structures_left_vision: + # await self.on_enemy_unit_left_vision(enemy_structure_tag) + + @final + async def _issue_unit_dead_events(self): + pass + # for unit_tag in self.state.dead_units & set(self._all_units_previous_map): + # await self.on_unit_destroyed(unit_tag) + + # DISTANCE CALCULATION + + @final + @property + def _units_count(self) -> int: + return len(self.all_units) + + # Helper functions + + @final + def square_to_condensed(self, i, j) -> int: + # Converts indices of a square matrix to condensed matrix + # https://stackoverflow.com/a/36867493/10882657 + assert i != j, "No diagonal elements in condensed matrix! Diagonal elements are zero" + if i < j: + i, j = j, i + return self._units_count * j - j * (j + 1) // 2 + i - 1 - j + + # Fast and simple calculation functions + + @final + @staticmethod + def distance_math_hypot( + p1: Union[Tuple[float, float], Point2], + p2: Union[Tuple[float, float], Point2], + ) -> float: + return math.hypot(p1[0] - p2[0], p1[1] - p2[1]) + + @final + @staticmethod + def distance_math_hypot_squared( + p1: Union[Tuple[float, float], Point2], + p2: Union[Tuple[float, float], Point2], + ) -> float: + return pow(p1[0] - p2[0], 2) + pow(p1[1] - p2[1], 2) + + @final + def _distance_squared_unit_to_unit_method0(self, unit1: Unit, unit2: Unit) -> float: + return self.distance_math_hypot_squared(unit1.position_tuple, unit2.position_tuple) + + # Distance calculation using the pre-calculated matrix above + + @final + def _distance_squared_unit_to_unit_method1(self, unit1: Unit, unit2: Unit) -> float: + # If checked on units if they have the same tag, return distance 0 as these are not in the 1 dimensional pdist array - would result in an error otherwise + if unit1.tag == unit2.tag: + return 0 + # Calculate index, needs to be after pdist has been calculated and cached + condensed_index = self.square_to_condensed(unit1.distance_calculation_index, unit2.distance_calculation_index) + assert condensed_index < len( + self._cached_pdist + ), f"Condensed index is larger than amount of calculated distances: {condensed_index} < {len(self._cached_pdist)}, units that caused the assert error: {unit1} and {unit2}" + distance = self._pdist[condensed_index] + return distance + + @final + def _distance_squared_unit_to_unit_method2(self, unit1: Unit, unit2: Unit) -> float: + # Calculate index, needs to be after cdist has been calculated and cached + return self._cdist[unit1.distance_calculation_index, unit2.distance_calculation_index] + + # Distance calculation using the fastest distance calculation functions + + @final + def _distance_pos_to_pos( + self, + pos1: Union[Tuple[float, float], Point2], + pos2: Union[Tuple[float, float], Point2], + ) -> float: + return self.distance_math_hypot(pos1, pos2) + + @final + def _distance_units_to_pos( + self, + units: Units, + pos: Union[Tuple[float, float], Point2], + ) -> Generator[float, None, None]: + """ This function does not scale well, if len(units) > 100 it gets fairly slow """ + return (self.distance_math_hypot(u.position_tuple, pos) for u in units) + + @final + def _distance_unit_to_points( + self, + unit: Unit, + points: Iterable[Tuple[float, float]], + ) -> Generator[float, None, None]: + """ This function does not scale well, if len(points) > 100 it gets fairly slow """ + pos = unit.position_tuple + return (self.distance_math_hypot(p, pos) for p in points) diff --git a/worlds/_sc2common/bot/cache.py b/worlds/_sc2common/bot/cache.py new file mode 100644 index 000000000000..0aa460ba8088 --- /dev/null +++ b/worlds/_sc2common/bot/cache.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Hashable, TypeVar + +if TYPE_CHECKING: + from .bot_ai import BotAI + +T = TypeVar("T") + + +class CacheDict(dict): + + def retrieve_and_set(self, key: Hashable, func: Callable[[], T]) -> T: + """ Either return the value at a certain key, + or set the return value of a function to that key, then return that value. """ + if key not in self: + self[key] = func() + return self[key] + + +class property_cache_once_per_frame(property): + """This decorator caches the return value for one game loop, + then clears it if it is accessed in a different game loop. + Only works on properties of the bot object, because it requires + access to self.state.game_loop + + This decorator compared to the above runs a little faster, however you should only use this decorator if you are sure that you do not modify the mutable once it is calculated and cached. + + Copied and modified from https://tedboy.github.io/flask/_modules/werkzeug/utils.html#cached_property + # """ + + def __init__(self, func: Callable[[BotAI], T], name=None): + # pylint: disable=W0231 + self.__name__ = name or func.__name__ + self.__frame__ = f"__frame__{self.__name__}" + self.func = func + + def __set__(self, obj: BotAI, value: T): + obj.cache[self.__name__] = value + obj.cache[self.__frame__] = obj.state.game_loop + + def __get__(self, obj: BotAI, _type=None) -> T: + value = obj.cache.get(self.__name__, None) + bot_frame = obj.state.game_loop + if value is None or obj.cache[self.__frame__] < bot_frame: + value = self.func(obj) + obj.cache[self.__name__] = value + obj.cache[self.__frame__] = bot_frame + return value diff --git a/worlds/_sc2common/bot/client.py b/worlds/_sc2common/bot/client.py new file mode 100644 index 000000000000..a902c99d5594 --- /dev/null +++ b/worlds/_sc2common/bot/client.py @@ -0,0 +1,720 @@ +from __future__ import annotations + +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union + +from worlds._sc2common.bot import logger + +from s2clientprotocol import debug_pb2 as debug_pb +from s2clientprotocol import query_pb2 as query_pb +from s2clientprotocol import raw_pb2 as raw_pb +from s2clientprotocol import sc2api_pb2 as sc_pb +from s2clientprotocol import spatial_pb2 as spatial_pb + +from .data import ActionResult, ChatChannel, Race, Result, Status +from .game_data import AbilityData, GameData +from .game_info import GameInfo +from .position import Point2, Point3 +from .protocol import ConnectionAlreadyClosed, Protocol, ProtocolError +from .renderer import Renderer +from .unit import Unit +from .units import Units + + +# pylint: disable=R0904 +class Client(Protocol): + + def __init__(self, ws, save_replay_path: str = None): + """ + :param ws: + """ + super().__init__(ws) + # How many frames will be waited between iterations before the next one is called + self.game_step: int = 4 + self.save_replay_path: Optional[str] = save_replay_path + self._player_id = None + self._game_result = None + # Store a hash value of all the debug requests to prevent sending the same ones again if they haven't changed last frame + self._debug_hash_tuple_last_iteration: Tuple[int, int, int, int] = (0, 0, 0, 0) + self._debug_draw_last_frame = False + self._debug_texts = [] + self._debug_lines = [] + self._debug_boxes = [] + self._debug_spheres = [] + + self._renderer = None + self.raw_affects_selection = False + + @property + def in_game(self) -> bool: + return self._status in {Status.in_game, Status.in_replay} + + async def join_game(self, name=None, race=None, observed_player_id=None, portconfig=None, rgb_render_config=None): + ifopts = sc_pb.InterfaceOptions( + raw=True, + score=True, + show_cloaked=True, + show_burrowed_shadows=True, + raw_affects_selection=self.raw_affects_selection, + raw_crop_to_playable_area=False, + show_placeholders=True, + ) + + if rgb_render_config: + assert isinstance(rgb_render_config, dict) + assert "window_size" in rgb_render_config and "minimap_size" in rgb_render_config + window_size = rgb_render_config["window_size"] + minimap_size = rgb_render_config["minimap_size"] + self._renderer = Renderer(self, window_size, minimap_size) + map_width, map_height = window_size + minimap_width, minimap_height = minimap_size + + ifopts.render.resolution.x = map_width + ifopts.render.resolution.y = map_height + ifopts.render.minimap_resolution.x = minimap_width + ifopts.render.minimap_resolution.y = minimap_height + + if race is None: + assert isinstance(observed_player_id, int), f"observed_player_id is of type {type(observed_player_id)}" + # join as observer + req = sc_pb.RequestJoinGame(observed_player_id=observed_player_id, options=ifopts) + else: + assert isinstance(race, Race) + req = sc_pb.RequestJoinGame(race=race.value, options=ifopts) + + if portconfig: + req.server_ports.game_port = portconfig.server[0] + req.server_ports.base_port = portconfig.server[1] + + for ppc in portconfig.players: + p = req.client_ports.add() + p.game_port = ppc[0] + p.base_port = ppc[1] + + if name is not None: + assert isinstance(name, str), f"name is of type {type(name)}" + req.player_name = name + + result = await self._execute(join_game=req) + self._game_result = None + self._player_id = result.join_game.player_id + return result.join_game.player_id + + async def leave(self): + """ You can use 'await self.client.leave()' to surrender midst game. """ + is_resign = self._game_result is None + + if is_resign: + # For all clients that can leave, result of leaving the game either + # loss, or the client will ignore the result + self._game_result = {self._player_id: Result.Defeat} + + try: + if self.save_replay_path is not None: + await self.save_replay(self.save_replay_path) + self.save_replay_path = None + await self._execute(leave_game=sc_pb.RequestLeaveGame()) + except (ProtocolError, ConnectionAlreadyClosed): + if is_resign: + raise + + async def save_replay(self, path): + logger.debug("Requesting replay from server") + result = await self._execute(save_replay=sc_pb.RequestSaveReplay()) + with open(path, "wb") as f: + f.write(result.save_replay.data) + logger.info(f"Saved replay to {path}") + + async def observation(self, game_loop: int = None): + if game_loop is not None: + result = await self._execute(observation=sc_pb.RequestObservation(game_loop=game_loop)) + else: + result = await self._execute(observation=sc_pb.RequestObservation()) + assert result.HasField("observation") + + if not self.in_game or result.observation.player_result: + # Sometimes game ends one step before results are available + if not result.observation.player_result: + result = await self._execute(observation=sc_pb.RequestObservation()) + assert result.observation.player_result + + player_id_to_result = {} + for pr in result.observation.player_result: + player_id_to_result[pr.player_id] = Result(pr.result) + self._game_result = player_id_to_result + self._game_result = None + + # if render_data is available, then RGB rendering was requested + if self._renderer and result.observation.observation.HasField("render_data"): + await self._renderer.render(result.observation) + + return result + + async def step(self, step_size: int = None): + """ EXPERIMENTAL: Change self._client.game_step during the step function to increase or decrease steps per second """ + step_size = step_size or self.game_step + return await self._execute(step=sc_pb.RequestStep(count=step_size)) + + async def get_game_data(self) -> GameData: + result = await self._execute( + data=sc_pb.RequestData(ability_id=True, unit_type_id=True, upgrade_id=True, buff_id=True, effect_id=True) + ) + return GameData(result.data) + + async def dump_data(self, ability_id=True, unit_type_id=True, upgrade_id=True, buff_id=True, effect_id=True): + """ + Dump the game data files + choose what data to dump in the keywords + this function writes to a text file + call it one time in on_step with: + await self._client.dump_data() + """ + result = await self._execute( + data=sc_pb.RequestData( + ability_id=ability_id, + unit_type_id=unit_type_id, + upgrade_id=upgrade_id, + buff_id=buff_id, + effect_id=effect_id, + ) + ) + with open("data_dump.txt", "a") as file: + file.write(str(result.data)) + + async def get_game_info(self) -> GameInfo: + result = await self._execute(game_info=sc_pb.RequestGameInfo()) + return GameInfo(result.game_info) + + async def query_pathing(self, start: Union[Unit, Point2, Point3], + end: Union[Point2, Point3]) -> Optional[Union[int, float]]: + """Caution: returns "None" when path not found + Try to combine queries with the function below because the pathing query is generally slow. + + :param start: + :param end:""" + assert isinstance(start, (Point2, Unit)) + assert isinstance(end, Point2) + if isinstance(start, Point2): + path = [query_pb.RequestQueryPathing(start_pos=start.as_Point2D, end_pos=end.as_Point2D)] + else: + path = [query_pb.RequestQueryPathing(unit_tag=start.tag, end_pos=end.as_Point2D)] + result = await self._execute(query=query_pb.RequestQuery(pathing=path)) + distance = float(result.query.pathing[0].distance) + if distance <= 0.0: + return None + return distance + + async def query_pathings(self, zipped_list: List[List[Union[Unit, Point2, Point3]]]) -> List[float]: + """Usage: await self.query_pathings([[unit1, target2], [unit2, target2]]) + -> returns [distance1, distance2] + Caution: returns 0 when path not found + + :param zipped_list: + """ + assert zipped_list, "No zipped_list" + assert isinstance(zipped_list, list), f"{type(zipped_list)}" + assert isinstance(zipped_list[0], list), f"{type(zipped_list[0])}" + assert len(zipped_list[0]) == 2, f"{len(zipped_list[0])}" + assert isinstance(zipped_list[0][0], (Point2, Unit)), f"{type(zipped_list[0][0])}" + assert isinstance(zipped_list[0][1], Point2), f"{type(zipped_list[0][1])}" + if isinstance(zipped_list[0][0], Point2): + path = ( + query_pb.RequestQueryPathing(start_pos=p1.as_Point2D, end_pos=p2.as_Point2D) for p1, p2 in zipped_list + ) + else: + path = (query_pb.RequestQueryPathing(unit_tag=p1.tag, end_pos=p2.as_Point2D) for p1, p2 in zipped_list) + results = await self._execute(query=query_pb.RequestQuery(pathing=path)) + return [float(d.distance) for d in results.query.pathing] + + async def query_building_placement( + self, + ability: AbilityData, + positions: List[Union[Point2, Point3]], + ignore_resources: bool = True + ) -> List[ActionResult]: + """This function might be deleted in favor of the function above (_query_building_placement_fast). + + :param ability: + :param positions: + :param ignore_resources:""" + assert isinstance(ability, AbilityData) + result = await self._execute( + query=query_pb.RequestQuery( + placements=( + query_pb.RequestQueryBuildingPlacement(ability_id=ability.id.value, target_pos=position.as_Point2D) + for position in positions + ), + ignore_resource_requirements=ignore_resources, + ) + ) + # Unnecessary converting to ActionResult? + return [ActionResult(p.result) for p in result.query.placements] + + async def chat_send(self, message: str, team_only: bool): + """ Writes a message to the chat """ + ch = ChatChannel.Team if team_only else ChatChannel.Broadcast + await self._execute( + action=sc_pb.RequestAction( + actions=[sc_pb.Action(action_chat=sc_pb.ActionChat(channel=ch.value, message=message))] + ) + ) + + async def debug_kill_unit(self, unit_tags: Union[Unit, Units, List[int], Set[int]]): + """ + :param unit_tags: + """ + if isinstance(unit_tags, Units): + unit_tags = unit_tags.tags + if isinstance(unit_tags, Unit): + unit_tags = [unit_tags.tag] + assert unit_tags + + await self._execute( + debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(kill_unit=debug_pb.DebugKillUnit(tag=unit_tags))]) + ) + + async def move_camera(self, position: Union[Unit, Units, Point2, Point3]): + """Moves camera to the target position + + :param position:""" + assert isinstance(position, (Unit, Units, Point2, Point3)) + if isinstance(position, Units): + position = position.center + if isinstance(position, Unit): + position = position.position + await self._execute( + action=sc_pb.RequestAction( + actions=[ + sc_pb.Action( + action_raw=raw_pb.ActionRaw( + camera_move=raw_pb.ActionRawCameraMove(center_world_space=position.to3.as_Point) + ) + ) + ] + ) + ) + + async def obs_move_camera(self, position: Union[Unit, Units, Point2, Point3]): + """Moves observer camera to the target position. Only works when observing (e.g. watching the replay). + + :param position:""" + assert isinstance(position, (Unit, Units, Point2, Point3)) + if isinstance(position, Units): + position = position.center + if isinstance(position, Unit): + position = position.position + await self._execute( + obs_action=sc_pb.RequestObserverAction( + actions=[ + sc_pb.ObserverAction(camera_move=sc_pb.ActionObserverCameraMove(world_pos=position.as_Point2D)) + ] + ) + ) + + async def move_camera_spatial(self, position: Union[Point2, Point3]): + """Moves camera to the target position using the spatial aciton interface + + :param position:""" + assert isinstance(position, (Point2, Point3)) + action = sc_pb.Action( + action_render=spatial_pb.ActionSpatial( + camera_move=spatial_pb.ActionSpatialCameraMove(center_minimap=position.as_PointI) + ) + ) + await self._execute(action=sc_pb.RequestAction(actions=[action])) + + def debug_text_simple(self, text: str): + """ Draws a text in the top left corner of the screen (up to a max of 6 messages fit there). """ + self._debug_texts.append(DrawItemScreenText(text=text, color=None, start_point=Point2((0, 0)), font_size=8)) + + def debug_text_screen( + self, + text: str, + pos: Union[Point2, Point3, tuple, list], + color: Union[tuple, list, Point3] = None, + size: int = 8, + ): + """ + Draws a text on the screen (monitor / game window) with coordinates 0 <= x, y <= 1. + + :param text: + :param pos: + :param color: + :param size: + """ + assert len(pos) >= 2 + assert 0 <= pos[0] <= 1 + assert 0 <= pos[1] <= 1 + pos = Point2((pos[0], pos[1])) + self._debug_texts.append(DrawItemScreenText(text=text, color=color, start_point=pos, font_size=size)) + + def debug_text_2d( + self, + text: str, + pos: Union[Point2, Point3, tuple, list], + color: Union[tuple, list, Point3] = None, + size: int = 8, + ): + return self.debug_text_screen(text, pos, color, size) + + def debug_text_world( + self, text: str, pos: Union[Unit, Point3], color: Union[tuple, list, Point3] = None, size: int = 8 + ): + """ + Draws a text at Point3 position in the game world. + To grab a unit's 3d position, use unit.position3d + Usually the Z value of a Point3 is between 8 and 14 (except for flying units). Use self.get_terrain_z_height() from bot_ai.py to get the Z value (height) of the terrain at a 2D position. + + :param text: + :param color: + :param size: + """ + if isinstance(pos, Unit): + pos = pos.position3d + assert isinstance(pos, Point3) + self._debug_texts.append(DrawItemWorldText(text=text, color=color, start_point=pos, font_size=size)) + + def debug_text_3d( + self, text: str, pos: Union[Unit, Point3], color: Union[tuple, list, Point3] = None, size: int = 8 + ): + return self.debug_text_world(text, pos, color, size) + + def debug_line_out( + self, p0: Union[Unit, Point3], p1: Union[Unit, Point3], color: Union[tuple, list, Point3] = None + ): + """ + Draws a line from p0 to p1. + + :param p0: + :param p1: + :param color: + """ + if isinstance(p0, Unit): + p0 = p0.position3d + assert isinstance(p0, Point3) + if isinstance(p1, Unit): + p1 = p1.position3d + assert isinstance(p1, Point3) + self._debug_lines.append(DrawItemLine(color=color, start_point=p0, end_point=p1)) + + def debug_box_out( + self, + p_min: Union[Unit, Point3], + p_max: Union[Unit, Point3], + color: Union[tuple, list, Point3] = None, + ): + """ + Draws a box with p_min and p_max as corners of the box. + + :param p_min: + :param p_max: + :param color: + """ + if isinstance(p_min, Unit): + p_min = p_min.position3d + assert isinstance(p_min, Point3) + if isinstance(p_max, Unit): + p_max = p_max.position3d + assert isinstance(p_max, Point3) + self._debug_boxes.append(DrawItemBox(start_point=p_min, end_point=p_max, color=color)) + + def debug_box2_out( + self, + pos: Union[Unit, Point3], + half_vertex_length: float = 0.25, + color: Union[tuple, list, Point3] = None, + ): + """ + Draws a box center at a position 'pos', with box side lengths (vertices) of two times 'half_vertex_length'. + + :param pos: + :param half_vertex_length: + :param color: + """ + if isinstance(pos, Unit): + pos = pos.position3d + assert isinstance(pos, Point3) + p0 = pos + Point3((-half_vertex_length, -half_vertex_length, -half_vertex_length)) + p1 = pos + Point3((half_vertex_length, half_vertex_length, half_vertex_length)) + self._debug_boxes.append(DrawItemBox(start_point=p0, end_point=p1, color=color)) + + def debug_sphere_out(self, p: Union[Unit, Point3], r: float, color: Union[tuple, list, Point3] = None): + """ + Draws a sphere at point p with radius r. + + :param p: + :param r: + :param color: + """ + if isinstance(p, Unit): + p = p.position3d + assert isinstance(p, Point3) + self._debug_spheres.append(DrawItemSphere(start_point=p, radius=r, color=color)) + + async def _send_debug(self): + """Sends the debug draw execution. This is run by main.py now automatically, if there is any items in the list. You do not need to run this manually any longer. + Check examples/terran/ramp_wall.py for example drawing. Each draw request needs to be sent again in every single on_step iteration. + """ + debug_hash = ( + sum(hash(item) for item in self._debug_texts), + sum(hash(item) for item in self._debug_lines), + sum(hash(item) for item in self._debug_boxes), + sum(hash(item) for item in self._debug_spheres), + ) + if debug_hash != (0, 0, 0, 0): + if debug_hash != self._debug_hash_tuple_last_iteration: + # Something has changed, either more or less is to be drawn, or a position of a drawing changed (e.g. when drawing on a moving unit) + self._debug_hash_tuple_last_iteration = debug_hash + try: + await self._execute( + debug=sc_pb.RequestDebug( + debug=[ + debug_pb.DebugCommand( + draw=debug_pb.DebugDraw( + text=[text.to_proto() + for text in self._debug_texts] if self._debug_texts else None, + lines=[line.to_proto() + for line in self._debug_lines] if self._debug_lines else None, + boxes=[box.to_proto() + for box in self._debug_boxes] if self._debug_boxes else None, + spheres=[sphere.to_proto() + for sphere in self._debug_spheres] if self._debug_spheres else None, + ) + ) + ] + ) + ) + except ProtocolError: + return + self._debug_draw_last_frame = True + self._debug_texts.clear() + self._debug_lines.clear() + self._debug_boxes.clear() + self._debug_spheres.clear() + elif self._debug_draw_last_frame: + # Clear drawing if we drew last frame but nothing to draw this frame + self._debug_hash_tuple_last_iteration = (0, 0, 0, 0) + await self._execute( + debug=sc_pb.RequestDebug( + debug=[ + debug_pb.DebugCommand(draw=debug_pb.DebugDraw(text=None, lines=None, boxes=None, spheres=None)) + ] + ) + ) + self._debug_draw_last_frame = False + + async def debug_leave(self): + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(end_game=debug_pb.DebugEndGame())])) + + async def debug_set_unit_value(self, unit_tags: Union[Iterable[int], Units, Unit], unit_value: int, value: float): + """Sets a "unit value" (Energy, Life or Shields) of the given units to the given value. + Can't set the life of a unit to 0, use "debug_kill_unit" for that. Also can't set the life above the unit's maximum. + The following example sets the health of all your workers to 1: + await self.debug_set_unit_value(self.workers, 2, value=1)""" + if isinstance(unit_tags, Units): + unit_tags = unit_tags.tags + if isinstance(unit_tags, Unit): + unit_tags = [unit_tags.tag] + assert hasattr( + unit_tags, "__iter__" + ), f"unit_tags argument needs to be an iterable (list, dict, set, Units), given argument is {type(unit_tags).__name__}" + assert ( + 1 <= unit_value <= 3 + ), f"unit_value needs to be between 1 and 3 (1 for energy, 2 for life, 3 for shields), given argument is {unit_value}" + assert all(tag > 0 for tag in unit_tags), f"Unit tags have invalid value: {unit_tags}" + assert isinstance(value, (int, float)), "Value needs to be of type int or float" + assert value >= 0, "Value can't be negative" + await self._execute( + debug=sc_pb.RequestDebug( + debug=( + debug_pb.DebugCommand( + unit_value=debug_pb. + DebugSetUnitValue(unit_value=unit_value, value=float(value), unit_tag=unit_tag) + ) for unit_tag in unit_tags + ) + ) + ) + + async def debug_hang(self, delay_in_seconds: float): + """ Freezes the SC2 client. Not recommended to be used. """ + delay_in_ms = int(round(delay_in_seconds * 1000)) + await self._execute( + debug=sc_pb.RequestDebug( + debug=[debug_pb.DebugCommand(test_process=debug_pb.DebugTestProcess(test=1, delay_ms=delay_in_ms))] + ) + ) + + async def debug_show_map(self): + """ Reveals the whole map for the bot. Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=1)])) + + async def debug_control_enemy(self): + """ Allows control over enemy units and structures similar to team games control - does not allow the bot to spend the opponent's ressources. Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=2)])) + + async def debug_food(self): + """ Should disable food usage (does not seem to work?). Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=3)])) + + async def debug_free(self): + """ Units, structures and upgrades are free of mineral and gas cost. Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=4)])) + + async def debug_all_resources(self): + """ Gives 5000 minerals and 5000 vespene to the bot. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=5)])) + + async def debug_god(self): + """ Your units and structures no longer take any damage. Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=6)])) + + async def debug_minerals(self): + """ Gives 5000 minerals to the bot. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=7)])) + + async def debug_gas(self): + """ Gives 5000 vespene to the bot. This does not seem to be working. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=8)])) + + async def debug_cooldown(self): + """ Disables cooldowns of unit abilities for the bot. Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=9)])) + + async def debug_tech_tree(self): + """ Removes all tech requirements (e.g. can build a factory without having a barracks). Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=10)])) + + async def debug_upgrade(self): + """ Researches all currently available upgrades. E.g. using it once unlocks combat shield, stimpack and 1-1. Using it a second time unlocks 2-2 and all other upgrades stay researched. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=11)])) + + async def debug_fast_build(self): + """ Sets the build time of units and structures and upgrades to zero. Using it a second time disables it again. """ + await self._execute(debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(game_state=12)])) + + async def quick_save(self): + """Saves the current game state to an in-memory bookmark. + See: https://github.com/Blizzard/s2client-proto/blob/eeaf5efaea2259d7b70247211dff98da0a2685a2/s2clientprotocol/sc2api.proto#L93""" + await self._execute(quick_save=sc_pb.RequestQuickSave()) + + async def quick_load(self): + """Loads the game state from the previously stored in-memory bookmark. + Caution: + - The SC2 Client will crash if the game wasn't quicksaved + - The bot step iteration counter will not reset + - self.state.game_loop will be set to zero after the quickload, and self.time is dependant on it""" + await self._execute(quick_load=sc_pb.RequestQuickLoad()) + + +class DrawItem: + + @staticmethod + def to_debug_color(color: Union[tuple, Point3]): + """ Helper function for color conversion """ + if color is None: + return debug_pb.Color(r=255, g=255, b=255) + # Need to check if not of type Point3 because Point3 inherits from tuple + if isinstance(color, (tuple, list)) and not isinstance(color, Point3) and len(color) == 3: + return debug_pb.Color(r=color[0], g=color[1], b=color[2]) + # In case color is of type Point3 + r = getattr(color, "r", getattr(color, "x", 255)) + g = getattr(color, "g", getattr(color, "y", 255)) + b = getattr(color, "b", getattr(color, "z", 255)) + if max(r, g, b) <= 1: + r *= 255 + g *= 255 + b *= 255 + + return debug_pb.Color(r=int(r), g=int(g), b=int(b)) + + +class DrawItemScreenText(DrawItem): + + def __init__(self, start_point: Point2 = None, color: Point3 = None, text: str = "", font_size: int = 8): + self._start_point: Point2 = start_point + self._color: Point3 = color + self._text: str = text + self._font_size: int = font_size + + def to_proto(self): + return debug_pb.DebugText( + color=self.to_debug_color(self._color), + text=self._text, + virtual_pos=self._start_point.to3.as_Point, + world_pos=None, + size=self._font_size, + ) + + def __hash__(self): + return hash((self._start_point, self._color, self._text, self._font_size)) + + +class DrawItemWorldText(DrawItem): + + def __init__(self, start_point: Point3 = None, color: Point3 = None, text: str = "", font_size: int = 8): + self._start_point: Point3 = start_point + self._color: Point3 = color + self._text: str = text + self._font_size: int = font_size + + def to_proto(self): + return debug_pb.DebugText( + color=self.to_debug_color(self._color), + text=self._text, + virtual_pos=None, + world_pos=self._start_point.as_Point, + size=self._font_size, + ) + + def __hash__(self): + return hash((self._start_point, self._text, self._font_size, self._color)) + + +class DrawItemLine(DrawItem): + + def __init__(self, start_point: Point3 = None, end_point: Point3 = None, color: Point3 = None): + self._start_point: Point3 = start_point + self._end_point: Point3 = end_point + self._color: Point3 = color + + def to_proto(self): + return debug_pb.DebugLine( + line=debug_pb.Line(p0=self._start_point.as_Point, p1=self._end_point.as_Point), + color=self.to_debug_color(self._color), + ) + + def __hash__(self): + return hash((self._start_point, self._end_point, self._color)) + + +class DrawItemBox(DrawItem): + + def __init__(self, start_point: Point3 = None, end_point: Point3 = None, color: Point3 = None): + self._start_point: Point3 = start_point + self._end_point: Point3 = end_point + self._color: Point3 = color + + def to_proto(self): + return debug_pb.DebugBox( + min=self._start_point.as_Point, + max=self._end_point.as_Point, + color=self.to_debug_color(self._color), + ) + + def __hash__(self): + return hash((self._start_point, self._end_point, self._color)) + + +class DrawItemSphere(DrawItem): + + def __init__(self, start_point: Point3 = None, radius: float = None, color: Point3 = None): + self._start_point: Point3 = start_point + self._radius: float = radius + self._color: Point3 = color + + def to_proto(self): + return debug_pb.DebugSphere( + p=self._start_point.as_Point, r=self._radius, color=self.to_debug_color(self._color) + ) + + def __hash__(self): + return hash((self._start_point, self._radius, self._color)) diff --git a/worlds/_sc2common/bot/constants.py b/worlds/_sc2common/bot/constants.py new file mode 100644 index 000000000000..2ff14c271e8a --- /dev/null +++ b/worlds/_sc2common/bot/constants.py @@ -0,0 +1,30 @@ +from typing import Set + +from .data import Alliance, Attribute, CloakState, DisplayType, TargetType + +IS_STRUCTURE: int = Attribute.Structure.value +IS_LIGHT: int = Attribute.Light.value +IS_ARMORED: int = Attribute.Armored.value +IS_BIOLOGICAL: int = Attribute.Biological.value +IS_MECHANICAL: int = Attribute.Mechanical.value +IS_MASSIVE: int = Attribute.Massive.value +IS_PSIONIC: int = Attribute.Psionic.value +TARGET_GROUND: Set[int] = {TargetType.Ground.value, TargetType.Any.value} +TARGET_AIR: Set[int] = {TargetType.Air.value, TargetType.Any.value} +TARGET_BOTH = TARGET_GROUND | TARGET_AIR +IS_SNAPSHOT = DisplayType.Snapshot.value +IS_VISIBLE = DisplayType.Visible.value +IS_PLACEHOLDER = DisplayType.Placeholder.value +IS_MINE = Alliance.Self.value +IS_ENEMY = Alliance.Enemy.value +IS_CLOAKED: Set[int] = {CloakState.Cloaked.value, CloakState.CloakedDetected.value, CloakState.CloakedAllied.value} +IS_REVEALED: int = CloakState.CloakedDetected.value +CAN_BE_ATTACKED: Set[int] = {CloakState.NotCloaked.value, CloakState.CloakedDetected.value} + +TARGET_HELPER = { + 1: "no target", + 2: "Point2", + 3: "Unit", + 4: "Point2 or Unit", + 5: "Point2 or no target", +} diff --git a/worlds/_sc2common/bot/controller.py b/worlds/_sc2common/bot/controller.py new file mode 100644 index 000000000000..abb26ef8a9e5 --- /dev/null +++ b/worlds/_sc2common/bot/controller.py @@ -0,0 +1,80 @@ +import platform +from pathlib import Path + +from worlds._sc2common.bot import logger +from s2clientprotocol import sc2api_pb2 as sc_pb + +from .player import Computer +from .protocol import Protocol + + +class Controller(Protocol): + + def __init__(self, ws, process): + super().__init__(ws) + self._process = process + + @property + def running(self): + # pylint: disable=W0212 + return self._process._process is not None + + async def create_game(self, game_map, players, realtime: bool, random_seed=None, disable_fog=None): + req = sc_pb.RequestCreateGame( + local_map=sc_pb.LocalMap(map_path=str(game_map.relative_path)), realtime=realtime, disable_fog=disable_fog + ) + if random_seed is not None: + req.random_seed = random_seed + + for player in players: + p = req.player_setup.add() + p.type = player.type.value + if isinstance(player, Computer): + p.race = player.race.value + p.difficulty = player.difficulty.value + p.ai_build = player.ai_build.value + + logger.info("Creating new game") + logger.info(f"Map: {game_map.name}") + logger.info(f"Players: {', '.join(str(p) for p in players)}") + result = await self._execute(create_game=req) + return result + + async def request_available_maps(self): + req = sc_pb.RequestAvailableMaps() + result = await self._execute(available_maps=req) + return result + + async def request_save_map(self, download_path: str): + """ Not working on linux. """ + req = sc_pb.RequestSaveMap(map_path=download_path) + result = await self._execute(save_map=req) + return result + + async def request_replay_info(self, replay_path: str): + """ Not working on linux. """ + req = sc_pb.RequestReplayInfo(replay_path=replay_path, download_data=False) + result = await self._execute(replay_info=req) + return result + + async def start_replay(self, replay_path: str, realtime: bool, observed_id: int = 0): + ifopts = sc_pb.InterfaceOptions( + raw=True, score=True, show_cloaked=True, raw_affects_selection=True, raw_crop_to_playable_area=False + ) + if platform.system() == "Linux": + replay_name = Path(replay_path).name + home_replay_folder = Path.home() / "Documents" / "StarCraft II" / "Replays" + if str(home_replay_folder / replay_name) != replay_path: + logger.warning( + f"Linux detected, please put your replay in your home directory at {home_replay_folder}. It was detected at {replay_path}" + ) + raise FileNotFoundError + replay_path = replay_name + + req = sc_pb.RequestStartReplay( + replay_path=replay_path, observed_player_id=observed_id, realtime=realtime, options=ifopts + ) + + result = await self._execute(start_replay=req) + assert result.status == 4, f"{result.start_replay.error} - {result.start_replay.error_details}" + return result diff --git a/worlds/_sc2common/bot/data.py b/worlds/_sc2common/bot/data.py new file mode 100644 index 000000000000..4c9b3b94a046 --- /dev/null +++ b/worlds/_sc2common/bot/data.py @@ -0,0 +1,36 @@ +""" For the list of enums, see here + +https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_gametypes.h +https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_action.h +https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_unit.h +https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_data.h +""" +import enum + +from s2clientprotocol import common_pb2 as common_pb +from s2clientprotocol import data_pb2 as data_pb +from s2clientprotocol import error_pb2 as error_pb +from s2clientprotocol import raw_pb2 as raw_pb +from s2clientprotocol import sc2api_pb2 as sc_pb + +CreateGameError = enum.Enum("CreateGameError", sc_pb.ResponseCreateGame.Error.items()) + +PlayerType = enum.Enum("PlayerType", sc_pb.PlayerType.items()) +Difficulty = enum.Enum("Difficulty", sc_pb.Difficulty.items()) +AIBuild = enum.Enum("AIBuild", sc_pb.AIBuild.items()) +Status = enum.Enum("Status", sc_pb.Status.items()) +Result = enum.Enum("Result", sc_pb.Result.items()) +Alert = enum.Enum("Alert", sc_pb.Alert.items()) +ChatChannel = enum.Enum("ChatChannel", sc_pb.ActionChat.Channel.items()) + +Race = enum.Enum("Race", common_pb.Race.items()) + +DisplayType = enum.Enum("DisplayType", raw_pb.DisplayType.items()) +Alliance = enum.Enum("Alliance", raw_pb.Alliance.items()) +CloakState = enum.Enum("CloakState", raw_pb.CloakState.items()) + +Attribute = enum.Enum("Attribute", data_pb.Attribute.items()) +TargetType = enum.Enum("TargetType", data_pb.Weapon.TargetType.items()) +Target = enum.Enum("Target", data_pb.AbilityData.Target.items()) + +ActionResult = enum.Enum("ActionResult", error_pb.ActionResult.items()) diff --git a/worlds/_sc2common/bot/expiring_dict.py b/worlds/_sc2common/bot/expiring_dict.py new file mode 100644 index 000000000000..c60bb9e9640d --- /dev/null +++ b/worlds/_sc2common/bot/expiring_dict.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from collections import OrderedDict +from threading import RLock +from typing import TYPE_CHECKING, Any, Iterable, Union + +if TYPE_CHECKING: + from .bot_ai import BotAI + + +class ExpiringDict(OrderedDict): + """ + An expiring dict that uses the bot.state.game_loop to only return items that are valid for a specific amount of time. + + Example usages:: + + async def on_step(iteration: int): + # This dict will hold up to 10 items and only return values that have been added up to 20 frames ago + my_dict = ExpiringDict(self, max_age_frames=20) + if iteration == 0: + # Add item + my_dict["test"] = "something" + if iteration == 2: + # On default, one iteration is called every 8 frames + if "test" in my_dict: + print("test is in dict") + if iteration == 20: + if "test" not in my_dict: + print("test is not anymore in dict") + """ + + def __init__(self, bot: BotAI, max_age_frames: int = 1): + assert max_age_frames >= -1 + assert bot + + OrderedDict.__init__(self) + self.bot: BotAI = bot + self.max_age: Union[int, float] = max_age_frames + self.lock: RLock = RLock() + + @property + def frame(self) -> int: + return self.bot.state.game_loop + + def __contains__(self, key) -> bool: + """ Return True if dict has key, else False, e.g. 'key in dict' """ + with self.lock: + if OrderedDict.__contains__(self, key): + # Each item is a list of [value, frame time] + item = OrderedDict.__getitem__(self, key) + if self.frame - item[1] < self.max_age: + return True + del self[key] + return False + + def __getitem__(self, key, with_age=False) -> Any: + """ Return the item of the dict using d[key] """ + with self.lock: + # Each item is a list of [value, frame time] + item = OrderedDict.__getitem__(self, key) + if self.frame - item[1] < self.max_age: + if with_age: + return item[0], item[1] + return item[0] + OrderedDict.__delitem__(self, key) + raise KeyError(key) + + def __setitem__(self, key, value): + """ Set d[key] = value """ + with self.lock: + OrderedDict.__setitem__(self, key, (value, self.frame)) + + def __repr__(self): + """ Printable version of the dict instead of getting memory adress """ + print_list = [] + with self.lock: + for key, value in OrderedDict.items(self): + if self.frame - value[1] < self.max_age: + print_list.append(f"{repr(key)}: {repr(value)}") + print_str = ", ".join(print_list) + return f"ExpiringDict({print_str})" + + def __str__(self): + return self.__repr__() + + def __iter__(self): + """ Override 'for key in dict:' """ + with self.lock: + return self.keys() + + # TODO find a way to improve len + def __len__(self): + """Override len method as key value pairs aren't instantly being deleted, but only on __get__(item). + This function is slow because it has to check if each element is not expired yet.""" + with self.lock: + count = 0 + for _ in self.values(): + count += 1 + return count + + def pop(self, key, default=None, with_age=False): + """ Return the item and remove it """ + with self.lock: + if OrderedDict.__contains__(self, key): + item = OrderedDict.__getitem__(self, key) + if self.frame - item[1] < self.max_age: + del self[key] + if with_age: + return item[0], item[1] + return item[0] + del self[key] + if default is None: + raise KeyError(key) + if with_age: + return default, self.frame + return default + + def get(self, key, default=None, with_age=False): + """ Return the value for key if key is in dict, else default """ + with self.lock: + if OrderedDict.__contains__(self, key): + item = OrderedDict.__getitem__(self, key) + if self.frame - item[1] < self.max_age: + if with_age: + return item[0], item[1] + return item[0] + if default is None: + raise KeyError(key) + if with_age: + return default, self.frame + return None + return None + + def update(self, other_dict: dict): + with self.lock: + for key, value in other_dict.items(): + self[key] = value + + def items(self) -> Iterable: + """ Return iterator of zipped list [keys, values] """ + with self.lock: + for key, value in OrderedDict.items(self): + if self.frame - value[1] < self.max_age: + yield key, value[0] + + def keys(self) -> Iterable: + """ Return iterator of keys """ + with self.lock: + for key, value in OrderedDict.items(self): + if self.frame - value[1] < self.max_age: + yield key + + def values(self) -> Iterable: + """ Return iterator of values """ + with self.lock: + for value in OrderedDict.values(self): + if self.frame - value[1] < self.max_age: + yield value[0] diff --git a/worlds/_sc2common/bot/game_data.py b/worlds/_sc2common/bot/game_data.py new file mode 100644 index 000000000000..50f10bd6692e --- /dev/null +++ b/worlds/_sc2common/bot/game_data.py @@ -0,0 +1,209 @@ +# pylint: disable=W0212 +from __future__ import annotations + +from bisect import bisect_left +from dataclasses import dataclass +from functools import lru_cache +from typing import Dict, List, Optional, Union + +from .data import Attribute, Race + +# Set of parts of names of abilities that have no cost +# E.g every ability that has 'Hold' in its name is free +FREE_ABILITIES = {"Lower", "Raise", "Land", "Lift", "Hold", "Harvest"} + + +class GameData: + + def __init__(self, data): + """ + :param data: + """ + self.abilities: Dict[int, AbilityData] = {} + self.units: Dict[int, UnitTypeData] = {u.unit_id: UnitTypeData(self, u) for u in data.units if u.available} + self.upgrades: Dict[int, UpgradeData] = {u.upgrade_id: UpgradeData(self, u) for u in data.upgrades} + # Cached UnitTypeIds so that conversion does not take long. This needs to be moved elsewhere if a new GameData object is created multiple times per game + + +class AbilityData: + + @classmethod + def id_exists(cls, ability_id): + assert isinstance(ability_id, int), f"Wrong type: {ability_id} is not int" + if ability_id == 0: + return False + i = bisect_left(cls.ability_ids, ability_id) # quick binary search + return i != len(cls.ability_ids) and cls.ability_ids[i] == ability_id + + def __init__(self, game_data, proto): + self._game_data = game_data + self._proto = proto + + # What happens if we comment this out? Should this not be commented out? What is its purpose? + assert self.id != 0 + + def __repr__(self) -> str: + return f"AbilityData(name={self._proto.button_name})" + + @property + def link_name(self) -> str: + """ For Stimpack this returns 'BarracksTechLabResearch' """ + return self._proto.link_name + + @property + def button_name(self) -> str: + """ For Stimpack this returns 'Stimpack' """ + return self._proto.button_name + + @property + def friendly_name(self) -> str: + """ For Stimpack this returns 'Research Stimpack' """ + return self._proto.friendly_name + + @property + def is_free_morph(self) -> bool: + return any(free in self._proto.link_name for free in FREE_ABILITIES) + + @property + def cost(self) -> Cost: + return self._game_data.calculate_ability_cost(self.id) + + +class UnitTypeData: + + def __init__(self, game_data: GameData, proto): + """ + :param game_data: + :param proto: + """ + self._game_data = game_data + self._proto = proto + + def __repr__(self) -> str: + return f"UnitTypeData(name={self.name})" + + @property + def name(self) -> str: + return self._proto.name + + @property + def creation_ability(self) -> Optional[AbilityData]: + if self._proto.ability_id == 0: + return None + if self._proto.ability_id not in self._game_data.abilities: + return None + return self._game_data.abilities[self._proto.ability_id] + + @property + def footprint_radius(self) -> Optional[float]: + """ See unit.py footprint_radius """ + if self.creation_ability is None: + return None + return self.creation_ability._proto.footprint_radius + + @property + def attributes(self) -> List[Attribute]: + return self._proto.attributes + + def has_attribute(self, attr) -> bool: + assert isinstance(attr, Attribute) + return attr in self.attributes + + @property + def has_minerals(self) -> bool: + return self._proto.has_minerals + + @property + def has_vespene(self) -> bool: + return self._proto.has_vespene + + @property + def cargo_size(self) -> int: + """ How much cargo this unit uses up in cargo_space """ + return self._proto.cargo_size + + @property + def race(self) -> Race: + return Race(self._proto.race) + + @property + def cost(self) -> Cost: + return Cost(self._proto.mineral_cost, self._proto.vespene_cost, self._proto.build_time) + + @property + def cost_zerg_corrected(self) -> Cost: + """ This returns 25 for extractor and 200 for spawning pool instead of 75 and 250 respectively """ + if self.race == Race.Zerg and Attribute.Structure.value in self.attributes: + return Cost(self._proto.mineral_cost - 50, self._proto.vespene_cost, self._proto.build_time) + return self.cost + + +class UpgradeData: + + def __init__(self, game_data: GameData, proto): + """ + :param game_data: + :param proto: + """ + self._game_data = game_data + self._proto = proto + + def __repr__(self): + return f"UpgradeData({self.name} - research ability: {self.research_ability}, {self.cost})" + + @property + def name(self) -> str: + return self._proto.name + + @property + def research_ability(self) -> Optional[AbilityData]: + if self._proto.ability_id == 0: + return None + if self._proto.ability_id not in self._game_data.abilities: + return None + return self._game_data.abilities[self._proto.ability_id] + + @property + def cost(self) -> Cost: + return Cost(self._proto.mineral_cost, self._proto.vespene_cost, self._proto.research_time) + + +@dataclass +class Cost: + """ + The cost of an action, a structure, a unit or a research upgrade. + The time is given in frames (22.4 frames per game second). + """ + minerals: int + vespene: int + time: Optional[float] = None + + def __repr__(self) -> str: + return f"Cost({self.minerals}, {self.vespene})" + + def __eq__(self, other: Cost) -> bool: + return self.minerals == other.minerals and self.vespene == other.vespene + + def __ne__(self, other: Cost) -> bool: + return self.minerals != other.minerals or self.vespene != other.vespene + + def __bool__(self) -> bool: + return self.minerals != 0 or self.vespene != 0 + + def __add__(self, other) -> Cost: + if not other: + return self + if not self: + return other + time = (self.time or 0) + (other.time or 0) + return Cost(self.minerals + other.minerals, self.vespene + other.vespene, time=time) + + def __sub__(self, other: Cost) -> Cost: + time = (self.time or 0) + (other.time or 0) + return Cost(self.minerals - other.minerals, self.vespene - other.vespene, time=time) + + def __mul__(self, other: int) -> Cost: + return Cost(self.minerals * other, self.vespene * other, time=self.time) + + def __rmul__(self, other: int) -> Cost: + return Cost(self.minerals * other, self.vespene * other, time=self.time) diff --git a/worlds/_sc2common/bot/game_info.py b/worlds/_sc2common/bot/game_info.py new file mode 100644 index 000000000000..aef5d3d2482d --- /dev/null +++ b/worlds/_sc2common/bot/game_info.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import heapq +from collections import deque +from dataclasses import dataclass +from functools import cached_property +from typing import Deque, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple + +from .pixel_map import PixelMap +from .player import Player, Race +from .position import Point2, Rect, Size + + +@dataclass +class Ramp: + points: FrozenSet[Point2] + game_info: GameInfo + + @property + def x_offset(self) -> float: + # Tested by printing actual building locations vs calculated depot positions + return 0.5 + + @property + def y_offset(self) -> float: + # Tested by printing actual building locations vs calculated depot positions + return 0.5 + + @cached_property + def _height_map(self): + return self.game_info.terrain_height + + @cached_property + def size(self) -> int: + return len(self.points) + + def height_at(self, p: Point2) -> int: + return self._height_map[p] + + @cached_property + def upper(self) -> FrozenSet[Point2]: + """ Returns the upper points of a ramp. """ + current_max = -10000 + result = set() + for p in self.points: + height = self.height_at(p) + if height > current_max: + current_max = height + result = {p} + elif height == current_max: + result.add(p) + return frozenset(result) + + @cached_property + def upper2_for_ramp_wall(self) -> FrozenSet[Point2]: + """ Returns the 2 upper ramp points of the main base ramp required for the supply depot and barracks placement properties used in this file. """ + # From bottom center, find 2 points that are furthest away (within the same ramp) + return frozenset(heapq.nlargest(2, self.upper, key=lambda x: x.distance_to_point2(self.bottom_center))) + + @cached_property + def top_center(self) -> Point2: + length = len(self.upper) + pos = Point2((sum(p.x for p in self.upper) / length, sum(p.y for p in self.upper) / length)) + return pos + + @cached_property + def lower(self) -> FrozenSet[Point2]: + current_min = 10000 + result = set() + for p in self.points: + height = self.height_at(p) + if height < current_min: + current_min = height + result = {p} + elif height == current_min: + result.add(p) + return frozenset(result) + + @cached_property + def bottom_center(self) -> Point2: + length = len(self.lower) + pos = Point2((sum(p.x for p in self.lower) / length, sum(p.y for p in self.lower) / length)) + return pos + + @cached_property + def barracks_in_middle(self) -> Optional[Point2]: + """ Barracks position in the middle of the 2 depots """ + if len(self.upper) not in {2, 5}: + return None + if len(self.upper2_for_ramp_wall) == 2: + points = set(self.upper2_for_ramp_wall) + p1 = points.pop().offset((self.x_offset, self.y_offset)) + p2 = points.pop().offset((self.x_offset, self.y_offset)) + # Offset from top point to barracks center is (2, 1) + intersects = p1.circle_intersection(p2, 5**0.5) + any_lower_point = next(iter(self.lower)) + return max(intersects, key=lambda p: p.distance_to_point2(any_lower_point)) + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + + @cached_property + def depot_in_middle(self) -> Optional[Point2]: + """ Depot in the middle of the 3 depots """ + if len(self.upper) not in {2, 5}: + return None + if len(self.upper2_for_ramp_wall) == 2: + points = set(self.upper2_for_ramp_wall) + p1 = points.pop().offset((self.x_offset, self.y_offset)) + p2 = points.pop().offset((self.x_offset, self.y_offset)) + # Offset from top point to depot center is (1.5, 0.5) + try: + intersects = p1.circle_intersection(p2, 2.5**0.5) + except AssertionError: + # Returns None when no placement was found, this is the case on the map Honorgrounds LE with an exceptionally large main base ramp + return None + any_lower_point = next(iter(self.lower)) + return max(intersects, key=lambda p: p.distance_to_point2(any_lower_point)) + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + + @cached_property + def corner_depots(self) -> FrozenSet[Point2]: + """ Finds the 2 depot positions on the outside """ + if not self.upper2_for_ramp_wall: + return frozenset() + if len(self.upper2_for_ramp_wall) == 2: + points = set(self.upper2_for_ramp_wall) + p1 = points.pop().offset((self.x_offset, self.y_offset)) + p2 = points.pop().offset((self.x_offset, self.y_offset)) + center = p1.towards(p2, p1.distance_to_point2(p2) / 2) + depot_position = self.depot_in_middle + if depot_position is None: + return frozenset() + # Offset from middle depot to corner depots is (2, 1) + intersects = center.circle_intersection(depot_position, 5**0.5) + return intersects + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + + @cached_property + def barracks_can_fit_addon(self) -> bool: + """ Test if a barracks can fit an addon at natural ramp """ + # https://i.imgur.com/4b2cXHZ.png + if len(self.upper2_for_ramp_wall) == 2: + return self.barracks_in_middle.x + 1 > max(self.corner_depots, key=lambda depot: depot.x).x + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + + @cached_property + def barracks_correct_placement(self) -> Optional[Point2]: + """ Corrected placement so that an addon can fit """ + if self.barracks_in_middle is None: + return None + if len(self.upper2_for_ramp_wall) == 2: + if self.barracks_can_fit_addon: + return self.barracks_in_middle + return self.barracks_in_middle.offset((-2, 0)) + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + + @cached_property + def protoss_wall_pylon(self) -> Optional[Point2]: + """ + Pylon position that powers the two wall buildings and the warpin position. + """ + if len(self.upper) not in {2, 5}: + return None + if len(self.upper2_for_ramp_wall) != 2: + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + middle = self.depot_in_middle + # direction up the ramp + direction = self.barracks_in_middle.negative_offset(middle) + return middle + 6 * direction + + @cached_property + def protoss_wall_buildings(self) -> FrozenSet[Point2]: + """ + List of two positions for 3x3 buildings that form a wall with a spot for a one unit block. + These buildings can be powered by a pylon on the protoss_wall_pylon position. + """ + if len(self.upper) not in {2, 5}: + return frozenset() + if len(self.upper2_for_ramp_wall) == 2: + middle = self.depot_in_middle + # direction up the ramp + direction = self.barracks_in_middle.negative_offset(middle) + # sort depots based on distance to start to get wallin orientation + sorted_depots = sorted( + self.corner_depots, key=lambda depot: depot.distance_to(self.game_info.player_start_location) + ) + wall1: Point2 = sorted_depots[1].offset(direction) + wall2 = middle + direction + (middle - wall1) / 1.5 + return frozenset([wall1, wall2]) + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + + @cached_property + def protoss_wall_warpin(self) -> Optional[Point2]: + """ + Position for a unit to block the wall created by protoss_wall_buildings. + Powered by protoss_wall_pylon. + """ + if len(self.upper) not in {2, 5}: + return None + if len(self.upper2_for_ramp_wall) != 2: + raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") + middle = self.depot_in_middle + # direction up the ramp + direction = self.barracks_in_middle.negative_offset(middle) + # sort depots based on distance to start to get wallin orientation + sorted_depots = sorted(self.corner_depots, key=lambda x: x.distance_to(self.game_info.player_start_location)) + return sorted_depots[0].negative_offset(direction) + + +class GameInfo: + + def __init__(self, proto): + self._proto = proto + self.players: List[Player] = [Player.from_proto(p) for p in self._proto.player_info] + self.map_name: str = self._proto.map_name + self.local_map_path: str = self._proto.local_map_path + self.map_size: Size = Size.from_proto(self._proto.start_raw.map_size) + + # self.pathing_grid[point]: if 0, point is not pathable, if 1, point is pathable + self.pathing_grid: PixelMap = PixelMap(self._proto.start_raw.pathing_grid, in_bits=True) + # self.terrain_height[point]: returns the height in range of 0 to 255 at that point + self.terrain_height: PixelMap = PixelMap(self._proto.start_raw.terrain_height) + # self.placement_grid[point]: if 0, point is not placeable, if 1, point is pathable + self.placement_grid: PixelMap = PixelMap(self._proto.start_raw.placement_grid, in_bits=True) + self.playable_area = Rect.from_proto(self._proto.start_raw.playable_area) + self.map_center = self.playable_area.center + self.map_ramps: List[Ramp] = None # Filled later by BotAI._prepare_first_step + self.vision_blockers: FrozenSet[Point2] = None # Filled later by BotAI._prepare_first_step + self.player_races: Dict[int, Race] = { + p.player_id: p.race_actual or p.race_requested + for p in self._proto.player_info + } + self.start_locations: List[Point2] = [ + Point2.from_proto(sl).round(decimals=1) for sl in self._proto.start_raw.start_locations + ] + self.player_start_location: Point2 = None # Filled later by BotAI._prepare_first_step + + def _find_groups(self, points: FrozenSet[Point2], minimum_points_per_group: int = 8) -> Iterable[FrozenSet[Point2]]: + """ + From a set of points, this function will try to group points together by + painting clusters of points in a rectangular map using flood fill algorithm. + Returns groups of points as list, like [{p1, p2, p3}, {p4, p5, p6, p7, p8}] + """ + # TODO do we actually need colors here? the ramps will never touch anyways. + NOT_COLORED_YET = -1 + map_width = self.pathing_grid.width + map_height = self.pathing_grid.height + current_color: int = NOT_COLORED_YET + picture: List[List[int]] = [[-2 for _ in range(map_width)] for _ in range(map_height)] + + def paint(pt: Point2) -> None: + picture[pt.y][pt.x] = current_color + + nearby: List[Tuple[int, int]] = [(a, b) for a in [-1, 0, 1] for b in [-1, 0, 1] if a != 0 or b != 0] + + remaining: Set[Point2] = set(points) + for point in remaining: + paint(point) + current_color = 1 + queue: Deque[Point2] = deque() + while remaining: + current_group: Set[Point2] = set() + if not queue: + start = remaining.pop() + paint(start) + queue.append(start) + current_group.add(start) + while queue: + base: Point2 = queue.popleft() + for offset in nearby: + px, py = base.x + offset[0], base.y + offset[1] + # Do we ever reach out of map bounds? + if not (0 <= px < map_width and 0 <= py < map_height): + continue + if picture[py][px] != NOT_COLORED_YET: + continue + point: Point2 = Point2((px, py)) + remaining.discard(point) + paint(point) + queue.append(point) + current_group.add(point) + if len(current_group) >= minimum_points_per_group: + yield frozenset(current_group) diff --git a/worlds/_sc2common/bot/game_state.py b/worlds/_sc2common/bot/game_state.py new file mode 100644 index 000000000000..61d8f3cacca3 --- /dev/null +++ b/worlds/_sc2common/bot/game_state.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import cached_property +from itertools import chain +from typing import List, Set + +from .constants import IS_ENEMY, IS_MINE +from .data import Alliance, DisplayType +from .pixel_map import PixelMap +from .position import Point2, Point3 +from .power_source import PsionicMatrix +from .score import ScoreDetails + + +class Blip: + + def __init__(self, proto): + """ + :param proto: + """ + self._proto = proto + + @property + def is_blip(self) -> bool: + """Detected by sensor tower.""" + return self._proto.is_blip + + @property + def is_snapshot(self) -> bool: + return self._proto.display_type == DisplayType.Snapshot.value + + @property + def is_visible(self) -> bool: + return self._proto.display_type == DisplayType.Visible.value + + @property + def alliance(self) -> Alliance: + return self._proto.alliance + + @property + def is_mine(self) -> bool: + return self._proto.alliance == Alliance.Self.value + + @property + def is_enemy(self) -> bool: + return self._proto.alliance == Alliance.Enemy.value + + @property + def position(self) -> Point2: + """2d position of the blip.""" + return Point2.from_proto(self._proto.pos) + + @property + def position3d(self) -> Point3: + """3d position of the blip.""" + return Point3.from_proto(self._proto.pos) + + +class Common: + ATTRIBUTES = [ + "player_id", + "minerals", + "vespene", + "food_cap", + "food_used", + "food_army", + "food_workers", + "idle_worker_count", + "army_count", + "warp_gate_count", + "larva_count", + ] + + def __init__(self, proto): + self._proto = proto + + def __getattr__(self, attr): + assert attr in self.ATTRIBUTES, f"'{attr}' is not a valid attribute" + return int(getattr(self._proto, attr)) + + +class EffectData: + + def __init__(self, proto, fake=False): + """ + :param proto: + :param fake: + """ + self._proto = proto + self.fake = fake + + @property + def positions(self) -> Set[Point2]: + if self.fake: + return {Point2.from_proto(self._proto.pos)} + return {Point2.from_proto(p) for p in self._proto.pos} + + @property + def alliance(self) -> Alliance: + return self._proto.alliance + + @property + def is_mine(self) -> bool: + """ Checks if the effect is caused by me. """ + return self._proto.alliance == IS_MINE + + @property + def is_enemy(self) -> bool: + """ Checks if the effect is hostile. """ + return self._proto.alliance == IS_ENEMY + + @property + def owner(self) -> int: + return self._proto.owner + + @property + def radius(self) -> float: + return self._proto.radius + + def __repr__(self) -> str: + return f"{self.id} with radius {self.radius} at {self.positions}" + + +@dataclass +class ChatMessage: + player_id: int + message: str + + +@dataclass +class ActionRawCameraMove: + center_world_space: Point2 + + + +class GameState: + + def __init__(self, response_observation, previous_observation=None): + """ + :param response_observation: + :param previous_observation: + """ + # Only filled in realtime=True in case the bot skips frames + self.previous_observation = previous_observation + self.response_observation = response_observation + + # https://github.com/Blizzard/s2client-proto/blob/51662231c0965eba47d5183ed0a6336d5ae6b640/s2clientprotocol/sc2api.proto#L575 + self.observation = response_observation.observation + self.observation_raw = self.observation.raw_data + self.player_result = response_observation.player_result + self.common: Common = Common(self.observation.player_common) + + # Area covered by Pylons and Warpprisms + self.psionic_matrix: PsionicMatrix = PsionicMatrix.from_proto(self.observation_raw.player.power_sources) + # 22.4 per second on faster game speed + self.game_loop: int = self.observation.game_loop + + # https://github.com/Blizzard/s2client-proto/blob/33f0ecf615aa06ca845ffe4739ef3133f37265a9/s2clientprotocol/score.proto#L31 + self.score: ScoreDetails = ScoreDetails(self.observation.score) + self.abilities = self.observation.abilities # abilities of selected units + self.upgrades = set() + # self.upgrades: Set[UpgradeId] = {UpgradeId(upgrade) for upgrade in self.observation_raw.player.upgrade_ids} + + # self.visibility[point]: 0=Hidden, 1=Fogged, 2=Visible + self.visibility: PixelMap = PixelMap(self.observation_raw.map_state.visibility) + # self.creep[point]: 0=No creep, 1=creep + self.creep: PixelMap = PixelMap(self.observation_raw.map_state.creep, in_bits=True) + + # Effects like ravager bile shot, lurker attack, everything in effect_id.py + # self.effects: Set[EffectData] = {EffectData(effect) for effect in self.observation_raw.effects} + self.effects = set() + """ Usage: + for effect in self.state.effects: + if effect.id == EffectId.RAVAGERCORROSIVEBILECP: + positions = effect.positions + # dodge the ravager biles + """ + + @cached_property + def dead_units(self) -> Set[int]: + """ A set of unit tags that died this frame """ + _dead_units = set(self.observation_raw.event.dead_units) + if self.previous_observation: + return _dead_units | set(self.previous_observation.observation.raw_data.event.dead_units) + return _dead_units + + @cached_property + def chat(self) -> List[ChatMessage]: + """List of chat messages sent this frame (by either player).""" + previous_frame_chat = self.previous_observation.chat if self.previous_observation else [] + return [ + ChatMessage(message.player_id, message.message) + for message in chain(previous_frame_chat, self.response_observation.chat) + ] + + @cached_property + def alerts(self) -> List[int]: + """ + Game alerts, see https://github.com/Blizzard/s2client-proto/blob/01ab351e21c786648e4c6693d4aad023a176d45c/s2clientprotocol/sc2api.proto#L683-L706 + """ + if self.previous_observation: + return list(chain(self.previous_observation.observation.alerts, self.observation.alerts)) + return self.observation.alerts diff --git a/worlds/_sc2common/bot/main.py b/worlds/_sc2common/bot/main.py new file mode 100644 index 000000000000..f18c56836166 --- /dev/null +++ b/worlds/_sc2common/bot/main.py @@ -0,0 +1,646 @@ +# pylint: disable=W0212 +from __future__ import annotations + +import asyncio +import json +import platform +import signal +from contextlib import suppress +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import mpyq +import portpicker +from aiohttp import ClientSession, ClientWebSocketResponse +from worlds._sc2common.bot import logger +from s2clientprotocol import sc2api_pb2 as sc_pb + +from .bot_ai import BotAI +from .client import Client +from .controller import Controller +from .data import CreateGameError, Result, Status +from .game_state import GameState +from .maps import Map +from .player import AbstractPlayer, Bot, BotProcess, Human +from .portconfig import Portconfig +from .protocol import ConnectionAlreadyClosed, ProtocolError +from .proxy import Proxy +from .sc2process import SC2Process, kill_switch + + +@dataclass +class GameMatch: + """Dataclass for hosting a match of SC2. + This contains all of the needed information for RequestCreateGame. + :param sc2_config: dicts of arguments to unpack into sc2process's construction, one per player + second sc2_config will be ignored if only one sc2_instance is spawned + e.g. sc2_args=[{"fullscreen": True}, {}]: only player 1's sc2instance will be fullscreen + :param game_time_limit: The time (in seconds) until a match is artificially declared a Tie + """ + + map_sc2: Map + players: List[AbstractPlayer] + realtime: bool = False + random_seed: int = None + disable_fog: bool = None + sc2_config: List[Dict] = None + game_time_limit: int = None + + def __post_init__(self): + # avoid players sharing names + if len(self.players) > 1 and self.players[0].name is not None and self.players[0].name == self.players[1].name: + self.players[1].name += "2" + + if self.sc2_config is not None: + if isinstance(self.sc2_config, dict): + self.sc2_config = [self.sc2_config] + if len(self.sc2_config) == 0: + self.sc2_config = [{}] + while len(self.sc2_config) < len(self.players): + self.sc2_config += self.sc2_config + self.sc2_config = self.sc2_config[:len(self.players)] + + @property + def needed_sc2_count(self) -> int: + return sum(player.needs_sc2 for player in self.players) + + @property + def host_game_kwargs(self) -> Dict: + return { + "map_settings": self.map_sc2, + "players": self.players, + "realtime": self.realtime, + "random_seed": self.random_seed, + "disable_fog": self.disable_fog, + } + + def __repr__(self): + p1 = self.players[0] + p1 = p1.name if p1.name else p1 + p2 = self.players[1] + p2 = p2.name if p2.name else p2 + return f"Map: {self.map_sc2.name}, {p1} vs {p2}, realtime={self.realtime}, seed={self.random_seed}" + + +async def _play_game_human(client, player_id, realtime, game_time_limit): + while True: + state = await client.observation() + if client._game_result: + return client._game_result[player_id] + + if game_time_limit and state.observation.observation.game_loop / 22.4 > game_time_limit: + logger.info(state.observation.game_loop, state.observation.game_loop / 22.4) + return Result.Tie + + if not realtime: + await client.step() + + +# pylint: disable=R0912,R0911,R0914 +async def _play_game_ai( + client: Client, player_id: int, ai: BotAI, realtime: bool, game_time_limit: Optional[int] +) -> Result: + gs: GameState = None + + async def initialize_first_step() -> Optional[Result]: + nonlocal gs + ai._initialize_variables() + + game_data = await client.get_game_data() + game_info = await client.get_game_info() + ping_response = await client.ping() + + # This game_data will become self.game_data in botAI + ai._prepare_start( + client, player_id, game_info, game_data, realtime=realtime, base_build=ping_response.ping.base_build + ) + state = await client.observation() + # check game result every time we get the observation + if client._game_result: + await ai.on_end(client._game_result[player_id]) + return client._game_result[player_id] + gs = GameState(state.observation) + proto_game_info = await client._execute(game_info=sc_pb.RequestGameInfo()) + try: + ai._prepare_step(gs, proto_game_info) + await ai.on_before_start() + ai._prepare_first_step() + await ai.on_start() + # TODO Catching too general exception Exception (broad-except) + # pylint: disable=W0703 + except Exception as e: + logger.exception(f"Caught unknown exception in AI on_start: {e}") + logger.error("Resigning due to previous error") + await ai.on_end(Result.Defeat) + return Result.Defeat + + result = await initialize_first_step() + if result is not None: + return result + + async def run_bot_iteration(iteration: int): + nonlocal gs + logger.debug(f"Running AI step, it={iteration} {gs.game_loop / 22.4:.2f}s") + # Issue event like unit created or unit destroyed + await ai.issue_events() + # In on_step various errors can occur - log properly + try: + await ai.on_step(iteration) + except (AttributeError, ) as e: + logger.exception(f"Caught exception: {e}") + raise + except Exception as e: + logger.exception(f"Caught unknown exception: {e}") + raise + await ai._after_step() + logger.debug("Running AI step: done") + + # Only used in realtime=True + previous_state_observation = None + for iteration in range(10**10): + if realtime and gs: + # On realtime=True, might get an error here: sc2.protocol.ProtocolError: ['Not in a game'] + with suppress(ProtocolError): + requested_step = gs.game_loop + client.game_step + state = await client.observation(requested_step) + # If the bot took too long in the previous observation, request another observation one frame after + if state.observation.observation.game_loop > requested_step: + logger.debug("Skipped a step in realtime=True") + previous_state_observation = state.observation + state = await client.observation(state.observation.observation.game_loop + 1) + else: + state = await client.observation() + + # check game result every time we get the observation + if client._game_result: + await ai.on_end(client._game_result[player_id]) + return client._game_result[player_id] + gs = GameState(state.observation, previous_state_observation) + previous_state_observation = None + logger.debug(f"Score: {gs.score.score}") + + if game_time_limit and gs.game_loop / 22.4 > game_time_limit: + await ai.on_end(Result.Tie) + return Result.Tie + proto_game_info = await client._execute(game_info=sc_pb.RequestGameInfo()) + ai._prepare_step(gs, proto_game_info) + + await run_bot_iteration(iteration) # Main bot loop + + if not realtime: + if not client.in_game: # Client left (resigned) the game + await ai.on_end(client._game_result[player_id]) + return client._game_result[player_id] + + # TODO: In bot vs bot, if the other bot ends the game, this bot gets stuck in requesting an observation when using main.py:run_multiple_games + await client.step() + return Result.Undecided + + +async def _play_game( + player: AbstractPlayer, + client: Client, + realtime, + portconfig, + game_time_limit=None, + rgb_render_config=None +) -> Result: + assert isinstance(realtime, bool), repr(realtime) + + player_id = await client.join_game( + player.name, player.race, portconfig=portconfig, rgb_render_config=rgb_render_config + ) + logger.info(f"Player {player_id} - {player.name if player.name else str(player)}") + + if isinstance(player, Human): + result = await _play_game_human(client, player_id, realtime, game_time_limit) + else: + result = await _play_game_ai(client, player_id, player.ai, realtime, game_time_limit) + + logger.info( + f"Result for player {player_id} - {player.name if player.name else str(player)}: " + f"{result._name_ if isinstance(result, Result) else result}" + ) + + return result + +async def _setup_host_game( + server: Controller, map_settings, players, realtime, random_seed=None, disable_fog=None, save_replay_as=None +): + r = await server.create_game(map_settings, players, realtime, random_seed, disable_fog) + if r.create_game.HasField("error"): + err = f"Could not create game: {CreateGameError(r.create_game.error)}" + if r.create_game.HasField("error_details"): + err += f": {r.create_game.error_details}" + logger.critical(err) + raise RuntimeError(err) + + return Client(server._ws, save_replay_as) + + +async def _host_game( + map_settings, + players, + realtime=False, + portconfig=None, + save_replay_as=None, + game_time_limit=None, + rgb_render_config=None, + random_seed=None, + sc2_version=None, + disable_fog=None, +): + + assert players, "Can't create a game without players" + + assert any(isinstance(p, (Human, Bot)) for p in players) + + async with SC2Process( + fullscreen=players[0].fullscreen, render=rgb_render_config is not None, sc2_version=sc2_version + ) as server: + await server.ping() + + client = await _setup_host_game( + server, map_settings, players, realtime, random_seed, disable_fog, save_replay_as + ) + # Bot can decide if it wants to launch with 'raw_affects_selection=True' + if not isinstance(players[0], Human) and getattr(players[0].ai, "raw_affects_selection", None) is not None: + client.raw_affects_selection = players[0].ai.raw_affects_selection + + result = await _play_game(players[0], client, realtime, portconfig, game_time_limit, rgb_render_config) + if client.save_replay_path is not None: + await client.save_replay(client.save_replay_path) + try: + await client.leave() + except ConnectionAlreadyClosed: + logger.error("Connection was closed before the game ended") + await client.quit() + + return result + + +async def _host_game_aiter( + map_settings, + players, + realtime, + portconfig=None, + save_replay_as=None, + game_time_limit=None, +): + assert players, "Can't create a game without players" + + assert any(isinstance(p, (Human, Bot)) for p in players) + + async with SC2Process() as server: + while True: + await server.ping() + + client = await _setup_host_game(server, map_settings, players, realtime) + if not isinstance(players[0], Human) and getattr(players[0].ai, "raw_affects_selection", None) is not None: + client.raw_affects_selection = players[0].ai.raw_affects_selection + + try: + result = await _play_game(players[0], client, realtime, portconfig, game_time_limit) + + if save_replay_as is not None: + await client.save_replay(save_replay_as) + await client.leave() + except ConnectionAlreadyClosed: + logger.error("Connection was closed before the game ended") + return + + new_players = yield result + if new_players is not None: + players = new_players + + +def _host_game_iter(*args, **kwargs): + game = _host_game_aiter(*args, **kwargs) + new_playerconfig = None + while True: + new_playerconfig = yield asyncio.get_event_loop().run_until_complete(game.asend(new_playerconfig)) + + +async def _join_game( + players, + realtime, + portconfig, + save_replay_as=None, + game_time_limit=None, +): + async with SC2Process(fullscreen=players[1].fullscreen) as server: + await server.ping() + + client = Client(server._ws) + # Bot can decide if it wants to launch with 'raw_affects_selection=True' + if not isinstance(players[1], Human) and getattr(players[1].ai, "raw_affects_selection", None) is not None: + client.raw_affects_selection = players[1].ai.raw_affects_selection + + result = await _play_game(players[1], client, realtime, portconfig, game_time_limit) + if save_replay_as is not None: + await client.save_replay(save_replay_as) + try: + await client.leave() + except ConnectionAlreadyClosed: + logger.error("Connection was closed before the game ended") + await client.quit() + + return result + + +def get_replay_version(replay_path: Union[str, Path]) -> Tuple[str, str]: + with open(replay_path, 'rb') as f: + replay_data = f.read() + replay_io = BytesIO() + replay_io.write(replay_data) + replay_io.seek(0) + archive = mpyq.MPQArchive(replay_io).extract() + metadata = json.loads(archive[b"replay.gamemetadata.json"].decode("utf-8")) + return metadata["BaseBuild"], metadata["DataVersion"] + + +# TODO Deprecate run_game function in favor of run_multiple_games +def run_game(map_settings, players, **kwargs) -> Union[Result, List[Optional[Result]]]: + """ + Returns a single Result enum if the game was against the built-in computer. + Returns a list of two Result enums if the game was "Human vs Bot" or "Bot vs Bot". + """ + if sum(isinstance(p, (Human, Bot)) for p in players) > 1: + host_only_args = ["save_replay_as", "rgb_render_config", "random_seed", "sc2_version", "disable_fog"] + join_kwargs = {k: v for k, v in kwargs.items() if k not in host_only_args} + + portconfig = Portconfig() + + async def run_host_and_join(): + return await asyncio.gather( + _host_game(map_settings, players, **kwargs, portconfig=portconfig), + _join_game(players, **join_kwargs, portconfig=portconfig), + return_exceptions=True + ) + + result: List[Result] = asyncio.run(run_host_and_join()) + assert isinstance(result, list) + assert all(isinstance(r, Result) for r in result) + else: + result: Result = asyncio.run(_host_game(map_settings, players, **kwargs)) + assert isinstance(result, Result) + return result + + +async def play_from_websocket( + ws_connection: Union[str, ClientWebSocketResponse], + player: AbstractPlayer, + realtime: bool = False, + portconfig: Portconfig = None, + save_replay_as=None, + game_time_limit: int = None, + should_close=True, +): + """Use this to play when the match is handled externally e.g. for bot ladder games. + Portconfig MUST be specified if not playing vs Computer. + :param ws_connection: either a string("ws://{address}:{port}/sc2api") or a ClientWebSocketResponse object + :param should_close: closes the connection if True. Use False if something else will reuse the connection + + e.g. ladder usage: play_from_websocket("ws://127.0.0.1:5162/sc2api", MyBot, False, portconfig=my_PC) + """ + session = None + try: + if isinstance(ws_connection, str): + session = ClientSession() + ws_connection = await session.ws_connect(ws_connection, timeout=120) + should_close = True + client = Client(ws_connection) + result = await _play_game(player, client, realtime, portconfig, game_time_limit=game_time_limit) + if save_replay_as is not None: + await client.save_replay(save_replay_as) + except ConnectionAlreadyClosed: + logger.error("Connection was closed before the game ended") + return None + finally: + if should_close: + await ws_connection.close() + if session: + await session.close() + + return result + + +async def run_match(controllers: List[Controller], match: GameMatch, close_ws=True): + await _setup_host_game(controllers[0], **match.host_game_kwargs) + + # Setup portconfig beforehand, so all players use the same ports + startport = None + portconfig = None + if match.needed_sc2_count > 1: + if any(isinstance(player, BotProcess) for player in match.players): + portconfig = Portconfig.contiguous_ports() + # Most ladder bots generate their server and client ports as [s+2, s+3], [s+4, s+5] + startport = portconfig.server[0] - 2 + else: + portconfig = Portconfig() + + proxies = [] + coros = [] + players_that_need_sc2 = filter(lambda lambda_player: lambda_player.needs_sc2, match.players) + for i, player in enumerate(players_that_need_sc2): + if isinstance(player, BotProcess): + pport = portpicker.pick_unused_port() + p = Proxy(controllers[i], player, pport, match.game_time_limit, match.realtime) + proxies.append(p) + coros.append(p.play_with_proxy(startport)) + else: + coros.append( + play_from_websocket( + controllers[i]._ws, + player, + match.realtime, + portconfig, + should_close=close_ws, + game_time_limit=match.game_time_limit, + ) + ) + + async_results = await asyncio.gather(*coros, return_exceptions=True) + + if not isinstance(async_results, list): + async_results = [async_results] + for i, a in enumerate(async_results): + if isinstance(a, Exception): + logger.error(f"Exception[{a}] thrown by {[p for p in match.players if p.needs_sc2][i]}") + + return process_results(match.players, async_results) + + +def process_results(players: List[AbstractPlayer], async_results: List[Result]) -> Dict[AbstractPlayer, Result]: + opp_res = {Result.Victory: Result.Defeat, Result.Defeat: Result.Victory, Result.Tie: Result.Tie} + result: Dict[AbstractPlayer, Result] = {} + i = 0 + for player in players: + if player.needs_sc2: + if sum(r == Result.Victory for r in async_results) <= 1: + result[player] = async_results[i] + else: + result[player] = Result.Undecided + i += 1 + else: # computer + other_result = async_results[0] + result[player] = None + if other_result in opp_res: + result[player] = opp_res[other_result] + + return result + + +# pylint: disable=R0912 +async def maintain_SCII_count(count: int, controllers: List[Controller], proc_args: List[Dict] = None): + """Modifies the given list of controllers to reflect the desired amount of SCII processes""" + # kill unhealthy ones. + if controllers: + to_remove = [] + alive = await asyncio.wait_for( + asyncio.gather(*(c.ping() for c in controllers if not c._ws.closed), return_exceptions=True), timeout=20 + ) + i = 0 # for alive + for controller in controllers: + if controller._ws.closed: + if not controller._process._session.closed: + await controller._process._session.close() + to_remove.append(controller) + else: + if not isinstance(alive[i], sc_pb.Response): + try: + await controller._process._close_connection() + finally: + to_remove.append(controller) + i += 1 + for c in to_remove: + c._process._clean(verbose=False) + if c._process in kill_switch._to_kill: + kill_switch._to_kill.remove(c._process) + controllers.remove(c) + + # spawn more + if len(controllers) < count: + needed = count - len(controllers) + if proc_args: + index = len(controllers) % len(proc_args) + else: + proc_args = [{} for _ in range(needed)] + index = 0 + extra = [SC2Process(**proc_args[(index + _) % len(proc_args)]) for _ in range(needed)] + logger.info(f"Creating {needed} more SC2 Processes") + for _ in range(3): + if platform.system() == "Linux": + # Works on linux: start one client after the other + # pylint: disable=C2801 + new_controllers = [await asyncio.wait_for(sc.__aenter__(), timeout=50) for sc in extra] + else: + # Doesnt seem to work on linux: starting 2 clients nearly at the same time + new_controllers = await asyncio.wait_for( + # pylint: disable=C2801 + asyncio.gather(*[sc.__aenter__() for sc in extra], return_exceptions=True), + timeout=50 + ) + + controllers.extend(c for c in new_controllers if isinstance(c, Controller)) + if len(controllers) == count: + await asyncio.wait_for(asyncio.gather(*(c.ping() for c in controllers)), timeout=20) + break + extra = [ + extra[i] for i, result in enumerate(new_controllers) if not isinstance(new_controllers, Controller) + ] + else: + logger.critical("Could not launch sufficient SC2") + raise RuntimeError + + # kill excess + while len(controllers) > count: + proc = controllers.pop() + proc = proc._process + logger.info(f"Removing SCII listening to {proc._port}") + await proc._close_connection() + proc._clean(verbose=False) + if proc in kill_switch._to_kill: + kill_switch._to_kill.remove(proc) + + +def run_multiple_games(matches: List[GameMatch]): + return asyncio.get_event_loop().run_until_complete(a_run_multiple_games(matches)) + + +# TODO Catching too general exception Exception (broad-except) +# pylint: disable=W0703 +async def a_run_multiple_games(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]: + """Run multiple matches. + Non-python bots are supported. + When playing bot vs bot, this is less likely to fatally crash than repeating run_game() + """ + if not matches: + return [] + + results = [] + controllers = [] + for m in matches: + result = None + dont_restart = m.needed_sc2_count == 2 + try: + await maintain_SCII_count(m.needed_sc2_count, controllers, m.sc2_config) + result = await run_match(controllers, m, close_ws=dont_restart) + except SystemExit as e: + logger.info(f"Game exit'ed as {e} during match {m}") + except Exception as e: + logger.exception(f"Caught unknown exception: {e}") + logger.info(f"Exception {e} thrown in match {m}") + finally: + if dont_restart: # Keeping them alive after a non-computer match can cause crashes + await maintain_SCII_count(0, controllers, m.sc2_config) + results.append(result) + kill_switch.kill_all() + return results + + +# TODO Catching too general exception Exception (broad-except) +# pylint: disable=W0703 +async def a_run_multiple_games_nokill(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]: + """Run multiple matches while reusing SCII processes. + Prone to crashes and stalls + """ + # FIXME: check whether crashes between bot-vs-bot are avoidable or not + if not matches: + return [] + + # Start the matches + results = [] + controllers = [] + for m in matches: + logger.info(f"Starting match {1 + len(results)} / {len(matches)}: {m}") + result = None + try: + await maintain_SCII_count(m.needed_sc2_count, controllers, m.sc2_config) + result = await run_match(controllers, m, close_ws=False) + except SystemExit as e: + logger.critical(f"Game sys.exit'ed as {e} during match {m}") + except Exception as e: + logger.exception(f"Caught unknown exception: {e}") + logger.info(f"Exception {e} thrown in match {m}") + finally: + for c in controllers: + try: + await c.ping() + if c._status != Status.launched: + await c._execute(leave_game=sc_pb.RequestLeaveGame()) + except Exception as e: + logger.exception(f"Caught unknown exception: {e}") + if not (isinstance(e, ProtocolError) and e.is_game_over_error): + logger.info(f"controller {c.__dict__} threw {e}") + + results.append(result) + + # Fire the killswitch manually, instead of letting the winning player fire it. + await asyncio.wait_for(asyncio.gather(*(c._process._close_connection() for c in controllers)), timeout=50) + kill_switch.kill_all() + signal.signal(signal.SIGINT, signal.SIG_DFL) + + return results diff --git a/worlds/_sc2common/bot/maps.py b/worlds/_sc2common/bot/maps.py new file mode 100644 index 000000000000..f14b5af9009e --- /dev/null +++ b/worlds/_sc2common/bot/maps.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from pathlib import Path + +from worlds._sc2common.bot import logger + +from .paths import Paths + + +def get(name: str) -> Map: + # Iterate through 2 folder depths + for map_dir in (p for p in Paths.MAPS.iterdir()): + if map_dir.is_dir(): + for map_file in (p for p in map_dir.iterdir()): + if Map.matches_target_map_name(map_file, name): + return Map(map_file) + elif Map.matches_target_map_name(map_dir, name): + return Map(map_dir) + + raise KeyError(f"Map '{name}' was not found. Please put the map file in \"/StarCraft II/Maps/\".") + + +class Map: + + def __init__(self, path: Path): + self.path = path + + if self.path.is_absolute(): + try: + self.relative_path = self.path.relative_to(Paths.MAPS) + except ValueError: # path not relative to basedir + logger.warning(f"Using absolute path: {self.path}") + self.relative_path = self.path + else: + self.relative_path = self.path + + @property + def name(self): + return self.path.stem + + @property + def data(self): + with open(self.path, "rb") as f: + return f.read() + + def __repr__(self): + return f"Map({self.path})" + + @classmethod + def is_map_file(cls, file: Path) -> bool: + return file.is_file() and file.suffix == ".SC2Map" + + @classmethod + def matches_target_map_name(cls, file: Path, name: str) -> bool: + return cls.is_map_file(file) and file.stem == name diff --git a/worlds/_sc2common/bot/observer_ai.py b/worlds/_sc2common/bot/observer_ai.py new file mode 100644 index 000000000000..362012b95995 --- /dev/null +++ b/worlds/_sc2common/bot/observer_ai.py @@ -0,0 +1,155 @@ +""" +This class is very experimental and probably not up to date and needs to be refurbished. +If it works, you can watch replays with it. +""" + +# pylint: disable=W0201,W0212 +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Union + +from .bot_ai_internal import BotAIInternal +from .data import Alert, Result +from .game_data import GameData +from .position import Point2 +from .unit import Unit +from .units import Units + +if TYPE_CHECKING: + from .client import Client + from .game_info import GameInfo + + +class ObserverAI(BotAIInternal): + """Base class for bots.""" + + @property + def time(self) -> float: + """ Returns time in seconds, assumes the game is played on 'faster' """ + return self.state.game_loop / 22.4 # / (1/1.4) * (1/16) + + @property + def time_formatted(self) -> str: + """ Returns time as string in min:sec format """ + t = self.time + return f"{int(t // 60):02}:{int(t % 60):02}" + + @property + def game_info(self) -> GameInfo: + """ See game_info.py """ + return self._game_info + + @property + def game_data(self) -> GameData: + """ See game_data.py """ + return self._game_data + + @property + def client(self) -> Client: + """ See client.py """ + return self._client + + def alert(self, alert_code: Alert) -> bool: + """ + Check if alert is triggered in the current step. + Possible alerts are listed here https://github.com/Blizzard/s2client-proto/blob/e38efed74c03bec90f74b330ea1adda9215e655f/s2clientprotocol/sc2api.proto#L679-L702 + + Example use: + + from sc2.data import Alert + if self.alert(Alert.AddOnComplete): + print("Addon Complete") + + Alert codes:: + + AlertError + AddOnComplete + BuildingComplete + BuildingUnderAttack + LarvaHatched + MergeComplete + MineralsExhausted + MorphComplete + MothershipComplete + MULEExpired + NuclearLaunchDetected + NukeComplete + NydusWormDetected + ResearchComplete + TrainError + TrainUnitComplete + TrainWorkerComplete + TransformationComplete + UnitUnderAttack + UpgradeComplete + VespeneExhausted + WarpInComplete + + :param alert_code: + """ + assert isinstance(alert_code, Alert), f"alert_code {alert_code} is no Alert" + return alert_code.value in self.state.alerts + + @property + def start_location(self) -> Point2: + """ + Returns the spawn location of the bot, using the position of the first created townhall. + This will be None if the bot is run on an arcade or custom map that does not feature townhalls at game start. + """ + return self.game_info.player_start_location + + @property + def enemy_start_locations(self) -> List[Point2]: + """Possible start locations for enemies.""" + return self.game_info.start_locations + + async def on_unit_destroyed(self, unit_tag: int): + """ + Override this in your bot class. + This will event will be called when a unit (or structure, friendly or enemy) dies. + For enemy units, this only works if the enemy unit was in vision on death. + + :param unit_tag: + """ + + async def on_unit_created(self, unit: Unit): + """Override this in your bot class. This function is called when a unit is created. + + :param unit:""" + + async def on_building_construction_started(self, unit: Unit): + """ + Override this in your bot class. + This function is called when a building construction has started. + + :param unit: + """ + + async def on_building_construction_complete(self, unit: Unit): + """ + Override this in your bot class. This function is called when a building + construction is completed. + + :param unit: + """ + + async def on_start(self): + """ + Override this in your bot class. This function is called after "on_start". + At this point, game_data, game_info and the first iteration of game_state (self.state) are available. + """ + + async def on_step(self, iteration: int): + """ + You need to implement this function! + Override this in your bot class. + This function is called on every game step (looped in realtime mode). + + :param iteration: + """ + raise NotImplementedError + + async def on_end(self, game_result: Result): + """Override this in your bot class. This function is called at the end of a game. + + :param game_result:""" diff --git a/worlds/_sc2common/bot/paths.py b/worlds/_sc2common/bot/paths.py new file mode 100644 index 000000000000..5648f80bbaf0 --- /dev/null +++ b/worlds/_sc2common/bot/paths.py @@ -0,0 +1,157 @@ +import os +import platform +import re +import sys +from contextlib import suppress +from pathlib import Path + +from worlds._sc2common.bot import logger + +from . import wsl + +BASEDIR = { + "Windows": "C:/Program Files (x86)/StarCraft II", + "WSL1": "/mnt/c/Program Files (x86)/StarCraft II", + "WSL2": "/mnt/c/Program Files (x86)/StarCraft II", + "Darwin": "/Applications/StarCraft II", + "Linux": "~/StarCraftII", + "WineLinux": "~/.wine/drive_c/Program Files (x86)/StarCraft II", +} + +USERPATH = { + "Windows": "Documents\\StarCraft II\\ExecuteInfo.txt", + "WSL1": "Documents/StarCraft II/ExecuteInfo.txt", + "WSL2": "Documents/StarCraft II/ExecuteInfo.txt", + "Darwin": "Library/Application Support/Blizzard/StarCraft II/ExecuteInfo.txt", + "Linux": None, + "WineLinux": None, +} + +BINPATH = { + "Windows": "SC2_x64.exe", + "WSL1": "SC2_x64.exe", + "WSL2": "SC2_x64.exe", + "Darwin": "SC2.app/Contents/MacOS/SC2", + "Linux": "SC2_x64", + "WineLinux": "SC2_x64.exe", +} + +CWD = { + "Windows": "Support64", + "WSL1": "Support64", + "WSL2": "Support64", + "Darwin": None, + "Linux": None, + "WineLinux": "Support64", +} + + +def platform_detect(): + pf = os.environ.get("SC2PF", platform.system()) + if pf == "Linux": + return wsl.detect() or pf + return pf + + +PF = platform_detect() + + +def get_home(): + """Get home directory of user, using Windows home directory for WSL.""" + if PF in {"WSL1", "WSL2"}: + return wsl.get_wsl_home() or Path.home().expanduser() + return Path.home().expanduser() + + +def get_user_sc2_install(): + """Attempts to find a user's SC2 install if their OS has ExecuteInfo.txt""" + if USERPATH[PF]: + einfo = str(get_home() / Path(USERPATH[PF])) + if os.path.isfile(einfo): + with open(einfo) as f: + content = f.read() + if content: + base = re.search(r" = (.*)Versions", content).group(1) + if PF in {"WSL1", "WSL2"}: + base = str(wsl.win_path_to_wsl_path(base)) + + if os.path.exists(base): + return base + return None + + +def get_env(): + # TODO: Linux env conf from: https://github.com/deepmind/pysc2/blob/master/pysc2/run_configs/platforms.py + return None + + +def get_runner_args(cwd): + if "WINE" in os.environ: + runner_file = Path(os.environ.get("WINE")) + runner_file = runner_file if runner_file.is_file() else runner_file / "wine" + """ + TODO Is converting linux path really necessary? + That would convert + '/home/burny/Games/battlenet/drive_c/Program Files (x86)/StarCraft II/Support64' + to + 'Z:\\home\\burny\\Games\\battlenet\\drive_c\\Program Files (x86)\\StarCraft II\\Support64' + """ + return [runner_file, "start", "/d", cwd, "/unix"] + return [] + + +def latest_executeble(versions_dir, base_build=None): + latest = None + + if base_build is not None: + with suppress(ValueError): + latest = ( + int(base_build[4:]), + max(p for p in versions_dir.iterdir() if p.is_dir() and p.name.startswith(str(base_build))), + ) + + if base_build is None or latest is None: + latest = max((int(p.name[4:]), p) for p in versions_dir.iterdir() if p.is_dir() and p.name.startswith("Base")) + + version, path = latest + + if version < 55958: + logger.critical("Your SC2 binary is too old. Upgrade to 3.16.1 or newer.") + sys.exit(1) + return path / BINPATH[PF] + + +class _MetaPaths(type): + """"Lazily loads paths to allow importing the library even if SC2 isn't installed.""" + + # pylint: disable=C0203 + def __setup(self): + if PF not in BASEDIR: + logger.critical(f"Unsupported platform '{PF}'") + sys.exit(1) + + try: + base = os.environ.get("SC2PATH") or get_user_sc2_install() or BASEDIR[PF] + self.BASE = Path(base).expanduser() + self.EXECUTABLE = latest_executeble(self.BASE / "Versions") + self.CWD = self.BASE / CWD[PF] if CWD[PF] else None + + self.REPLAYS = self.BASE / "Replays" + + if (self.BASE / "maps").exists(): + self.MAPS = self.BASE / "maps" + else: + self.MAPS = self.BASE / "Maps" + except FileNotFoundError as e: + logger.critical(f"SC2 installation not found: File '{e.filename}' does not exist.") + sys.exit(1) + + # pylint: disable=C0203 + def __getattr__(self, attr): + # pylint: disable=E1120 + self.__setup() + return getattr(self, attr) + + +class Paths(metaclass=_MetaPaths): + """Paths for SC2 folders, lazily loaded using the above metaclass.""" diff --git a/worlds/_sc2common/bot/pixel_map.py b/worlds/_sc2common/bot/pixel_map.py new file mode 100644 index 000000000000..bc418c7dff9a --- /dev/null +++ b/worlds/_sc2common/bot/pixel_map.py @@ -0,0 +1,98 @@ +from typing import Callable, FrozenSet, List, Set, Tuple, Union + +from .position import Point2 + + +class PixelMap: + + def __init__(self, proto, in_bits: bool = False): + """ + :param proto: + :param in_bits: + """ + self._proto = proto + # Used for copying pixelmaps + self._in_bits: bool = in_bits + + assert self.width * self.height == (8 if in_bits else 1) * len( + self._proto.data + ), f"{self.width * self.height} {(8 if in_bits else 1)*len(self._proto.data)}" + + @property + def width(self) -> int: + return self._proto.size.x + + @property + def height(self) -> int: + return self._proto.size.y + + @property + def bits_per_pixel(self) -> int: + return self._proto.bits_per_pixel + + @property + def bytes_per_pixel(self) -> int: + return self._proto.bits_per_pixel // 8 + + def __getitem__(self, pos: Tuple[int, int]) -> int: + """ Example usage: is_pathable = self._game_info.pathing_grid[Point2((20, 20))] != 0 """ + assert 0 <= pos[0] < self.width, f"x is {pos[0]}, self.width is {self.width}" + assert 0 <= pos[1] < self.height, f"y is {pos[1]}, self.height is {self.height}" + return int(self.data_numpy[pos[1], pos[0]]) + + def __setitem__(self, pos: Tuple[int, int], value: int): + """ Example usage: self._game_info.pathing_grid[Point2((20, 20))] = 255 """ + assert 0 <= pos[0] < self.width, f"x is {pos[0]}, self.width is {self.width}" + assert 0 <= pos[1] < self.height, f"y is {pos[1]}, self.height is {self.height}" + assert ( + 0 <= value <= 254 * self._in_bits + 1 + ), f"value is {value}, it should be between 0 and {254 * self._in_bits + 1}" + assert isinstance(value, int), f"value is of type {type(value)}, it should be an integer" + self.data_numpy[pos[1], pos[0]] = value + + def is_set(self, p: Tuple[int, int]) -> bool: + return self[p] != 0 + + def is_empty(self, p: Tuple[int, int]) -> bool: + return not self.is_set(p) + + def copy(self) -> "PixelMap": + return PixelMap(self._proto, in_bits=self._in_bits) + + def flood_fill(self, start_point: Point2, pred: Callable[[int], bool]) -> Set[Point2]: + nodes: Set[Point2] = set() + queue: List[Point2] = [start_point] + + while queue: + x, y = queue.pop() + + if not (0 <= x < self.width and 0 <= y < self.height): + continue + + if Point2((x, y)) in nodes: + continue + + if pred(self[x, y]): + nodes.add(Point2((x, y))) + queue += [Point2((x + a, y + b)) for a in [-1, 0, 1] for b in [-1, 0, 1] if not (a == 0 and b == 0)] + return nodes + + def flood_fill_all(self, pred: Callable[[int], bool]) -> Set[FrozenSet[Point2]]: + groups: Set[FrozenSet[Point2]] = set() + + for x in range(self.width): + for y in range(self.height): + if any((x, y) in g for g in groups): + continue + + if pred(self[x, y]): + groups.add(frozenset(self.flood_fill(Point2((x, y)), pred))) + + return groups + + def print(self, wide: bool = False) -> None: + for y in range(self.height): + for x in range(self.width): + print("#" if self.is_set((x, y)) else " ", end=(" " if wide else "")) + print("") + diff --git a/worlds/_sc2common/bot/player.py b/worlds/_sc2common/bot/player.py new file mode 100644 index 000000000000..3af69de1d1e2 --- /dev/null +++ b/worlds/_sc2common/bot/player.py @@ -0,0 +1,193 @@ +from abc import ABC +from pathlib import Path +from typing import List, Union + +from .bot_ai import BotAI +from .data import AIBuild, Difficulty, PlayerType, Race + + +class AbstractPlayer(ABC): + + def __init__( + self, + p_type: PlayerType, + race: Race = None, + name: str = None, + difficulty=None, + ai_build=None, + fullscreen=False + ): + assert isinstance(p_type, PlayerType), f"p_type is of type {type(p_type)}" + assert name is None or isinstance(name, str), f"name is of type {type(name)}" + + self.name = name + self.type = p_type + self.fullscreen = fullscreen + if race is not None: + self.race = race + if p_type == PlayerType.Computer: + assert isinstance(difficulty, Difficulty), f"difficulty is of type {type(difficulty)}" + # Workaround, proto information does not carry ai_build info + # We cant set that in the Player classmethod + assert ai_build is None or isinstance(ai_build, AIBuild), f"ai_build is of type {type(ai_build)}" + self.difficulty = difficulty + self.ai_build = ai_build + + elif p_type == PlayerType.Observer: + assert race is None + assert difficulty is None + assert ai_build is None + + else: + assert isinstance(race, Race), f"race is of type {type(race)}" + assert difficulty is None + assert ai_build is None + + @property + def needs_sc2(self): + return not isinstance(self, Computer) + + +class Human(AbstractPlayer): + + def __init__(self, race, name=None, fullscreen=False): + super().__init__(PlayerType.Participant, race, name=name, fullscreen=fullscreen) + + def __str__(self): + if self.name is not None: + return f"Human({self.race._name_}, name={self.name !r})" + return f"Human({self.race._name_})" + + +class Bot(AbstractPlayer): + + def __init__(self, race, ai, name=None, fullscreen=False): + """ + AI can be None if this player object is just used to inform the + server about player types. + """ + assert isinstance(ai, BotAI) or ai is None, f"ai is of type {type(ai)}, inherit BotAI from bot_ai.py" + super().__init__(PlayerType.Participant, race, name=name, fullscreen=fullscreen) + self.ai = ai + + def __str__(self): + if self.name is not None: + return f"Bot {self.ai.__class__.__name__}({self.race._name_}), name={self.name !r})" + return f"Bot {self.ai.__class__.__name__}({self.race._name_})" + + +class Computer(AbstractPlayer): + + def __init__(self, race, difficulty=Difficulty.Easy, ai_build=AIBuild.RandomBuild): + super().__init__(PlayerType.Computer, race, difficulty=difficulty, ai_build=ai_build) + + def __str__(self): + return f"Computer {self.difficulty._name_}({self.race._name_}, {self.ai_build.name})" + + +class Observer(AbstractPlayer): + + def __init__(self): + super().__init__(PlayerType.Observer) + + def __str__(self): + return "Observer" + + +class Player(AbstractPlayer): + + def __init__(self, player_id, p_type, requested_race, difficulty=None, actual_race=None, name=None, ai_build=None): + super().__init__(p_type, requested_race, difficulty=difficulty, name=name, ai_build=ai_build) + self.id: int = player_id + self.actual_race: Race = actual_race + + @classmethod + def from_proto(cls, proto): + if PlayerType(proto.type) == PlayerType.Observer: + return cls(proto.player_id, PlayerType(proto.type), None, None, None) + return cls( + proto.player_id, + PlayerType(proto.type), + Race(proto.race_requested), + Difficulty(proto.difficulty) if proto.HasField("difficulty") else None, + Race(proto.race_actual) if proto.HasField("race_actual") else None, + proto.player_name if proto.HasField("player_name") else None, + ) + + +class BotProcess(AbstractPlayer): + """ + Class for handling bots launched externally, including non-python bots. + Default parameters comply with sc2ai and aiarena ladders. + + :param path: the executable file's path + :param launch_list: list of strings that launches the bot e.g. ["python", "run.py"] or ["run.exe"] + :param race: bot's race + :param name: bot's name + :param sc2port_arg: the accepted argument name for the port of the sc2 instance to listen to + :param hostaddress_arg: the accepted argument name for the address of the sc2 instance to listen to + :param match_arg: the accepted argument name for the starting port to generate a portconfig from + :param realtime_arg: the accepted argument name for specifying realtime + :param other_args: anything else that is needed + + e.g. to call a bot capable of running on the bot ladders: + BotProcess(os.getcwd(), "python run.py", Race.Terran, "INnoVation") + """ + + def __init__( + self, + path: Union[str, Path], + launch_list: List[str], + race: Race, + name=None, + sc2port_arg="--GamePort", + hostaddress_arg="--LadderServer", + match_arg="--StartPort", + realtime_arg="--RealTime", + other_args: str = None, + stdout: str = None, + ): + super().__init__(PlayerType.Participant, race, name=name) + assert Path(path).exists() + self.path = path + self.launch_list = launch_list + self.sc2port_arg = sc2port_arg + self.match_arg = match_arg + self.hostaddress_arg = hostaddress_arg + self.realtime_arg = realtime_arg + self.other_args = other_args + self.stdout = stdout + + def __repr__(self): + if self.name is not None: + return f"Bot {self.name}({self.race.name} from {self.launch_list})" + return f"Bot({self.race.name} from {self.launch_list})" + + def cmd_line(self, + sc2port: Union[int, str], + matchport: Union[int, str], + hostaddress: str, + realtime: bool = False) -> List[str]: + """ + + :param sc2port: the port that the launched sc2 instance listens to + :param matchport: some starting port that both bots use to generate identical portconfigs. + Note: This will not be sent if playing vs computer + :param hostaddress: the address the sc2 instances used + :param realtime: 1 or 0, indicating whether the match is played in realtime or not + :return: string that will be used to start the bot's process + """ + cmd_line = [ + *self.launch_list, + self.sc2port_arg, + str(sc2port), + self.hostaddress_arg, + hostaddress, + ] + if matchport is not None: + cmd_line.extend([self.match_arg, str(matchport)]) + if self.other_args is not None: + cmd_line.append(self.other_args) + if realtime: + cmd_line.extend([self.realtime_arg]) + return cmd_line diff --git a/worlds/_sc2common/bot/portconfig.py b/worlds/_sc2common/bot/portconfig.py new file mode 100644 index 000000000000..78011d89b3b4 --- /dev/null +++ b/worlds/_sc2common/bot/portconfig.py @@ -0,0 +1,69 @@ +import json + +import portpicker + + +class Portconfig: + """ + A data class for ports used by participants to join a match. + + EVERY participant joining the match must send the same sets of ports to join successfully. + SC2 needs 2 ports per connection (one for data, one as a 'header'), which is why the ports come in pairs. + + :param guests: number of non-hosting participants in a match (i.e. 1 less than the number of participants) + :param server_ports: [int portA, int portB] + :param player_ports: [[int port1A, int port1B], [int port2A, int port2B], ... ] + + .shared is deprecated, and should TODO be removed soon (once ladderbots' __init__.py doesnt specify them). + + .server contains the pair of ports used by the participant 'hosting' the match + + .players contains a pair of ports for every 'guest' (non-hosting participants) in the match + E.g. for 1v1, there will be only 1 guest. For 2v2 (coming soonTM), there would be 3 guests. + """ + + def __init__(self, guests=1, server_ports=None, player_ports=None): + self.shared = None + self._picked_ports = [] + if server_ports: + self.server = server_ports + else: + self.server = [portpicker.pick_unused_port() for _ in range(2)] + self._picked_ports.extend(self.server) + if player_ports: + self.players = player_ports + else: + self.players = [[portpicker.pick_unused_port() for _ in range(2)] for _ in range(guests)] + self._picked_ports.extend(port for player in self.players for port in player) + + def clean(self): + while self._picked_ports: + portpicker.return_port(self._picked_ports.pop()) + + def __str__(self): + return f"Portconfig(shared={self.shared}, server={self.server}, players={self.players})" + + @property + def as_json(self): + return json.dumps({"shared": self.shared, "server": self.server, "players": self.players}) + + @classmethod + def contiguous_ports(cls, guests=1, attempts=40): + """Returns a Portconfig with adjacent ports""" + for _ in range(attempts): + start = portpicker.pick_unused_port() + others = [start + j for j in range(1, 2 + guests * 2)] + if all(portpicker.is_port_free(p) for p in others): + server_ports = [start, others.pop(0)] + player_ports = [] + while others: + player_ports.append([others.pop(0), others.pop(0)]) + pc = cls(server_ports=server_ports, player_ports=player_ports) + pc._picked_ports.append(start) + return pc + raise portpicker.NoFreePortFoundError() + + @classmethod + def from_json(cls, json_data): + data = json.loads(json_data) + return cls(server_ports=data["server"], player_ports=data["players"]) diff --git a/worlds/_sc2common/bot/position.py b/worlds/_sc2common/bot/position.py new file mode 100644 index 000000000000..aca9a5105cbe --- /dev/null +++ b/worlds/_sc2common/bot/position.py @@ -0,0 +1,411 @@ +from __future__ import annotations + +import itertools +import math +import random +from typing import TYPE_CHECKING, Iterable, List, Set, Tuple, Union + +from s2clientprotocol import common_pb2 as common_pb + +if TYPE_CHECKING: + from .unit import Unit + from .units import Units + +EPSILON = 10**-8 + + +def _sign(num): + return math.copysign(1, num) + + +class Pointlike(tuple): + + @property + def position(self) -> Pointlike: + return self + + def distance_to(self, target: Union[Unit, Point2]) -> float: + """Calculate a single distance from a point or unit to another point or unit + + :param target:""" + p = target.position + return math.hypot(self[0] - p[0], self[1] - p[1]) + + def distance_to_point2(self, p: Union[Point2, Tuple[float, float]]) -> float: + """Same as the function above, but should be a bit faster because of the dropped asserts + and conversion. + + :param p:""" + return math.hypot(self[0] - p[0], self[1] - p[1]) + + def _distance_squared(self, p2: Point2) -> float: + """Function used to not take the square root as the distances will stay proportionally the same. + This is to speed up the sorting process. + + :param p2:""" + return (self[0] - p2[0])**2 + (self[1] - p2[1])**2 + + def sort_by_distance(self, ps: Union[Units, Iterable[Point2]]) -> List[Point2]: + """This returns the target points sorted as list. + You should not pass a set or dict since those are not sortable. + If you want to sort your units towards a point, use 'units.sorted_by_distance_to(point)' instead. + + :param ps:""" + return sorted(ps, key=lambda p: self.distance_to_point2(p.position)) + + def closest(self, ps: Union[Units, Iterable[Point2]]) -> Union[Unit, Point2]: + """This function assumes the 2d distance is meant + + :param ps:""" + assert ps, "ps is empty" + # pylint: disable=W0108 + return min(ps, key=lambda p: self.distance_to(p)) + + def distance_to_closest(self, ps: Union[Units, Iterable[Point2]]) -> float: + """This function assumes the 2d distance is meant + :param ps:""" + assert ps, "ps is empty" + closest_distance = math.inf + for p2 in ps: + p2 = p2.position + distance = self.distance_to(p2) + if distance <= closest_distance: + closest_distance = distance + return closest_distance + + def furthest(self, ps: Union[Units, Iterable[Point2]]) -> Union[Unit, Pointlike]: + """This function assumes the 2d distance is meant + + :param ps: Units object, or iterable of Unit or Point2""" + assert ps, "ps is empty" + # pylint: disable=W0108 + return max(ps, key=lambda p: self.distance_to(p)) + + def distance_to_furthest(self, ps: Union[Units, Iterable[Point2]]) -> float: + """This function assumes the 2d distance is meant + + :param ps:""" + assert ps, "ps is empty" + furthest_distance = -math.inf + for p2 in ps: + p2 = p2.position + distance = self.distance_to(p2) + if distance >= furthest_distance: + furthest_distance = distance + return furthest_distance + + def offset(self, p) -> Pointlike: + """ + + :param p: + """ + return self.__class__(a + b for a, b in itertools.zip_longest(self, p[:len(self)], fillvalue=0)) + + def unit_axes_towards(self, p): + """ + + :param p: + """ + return self.__class__(_sign(b - a) for a, b in itertools.zip_longest(self, p[:len(self)], fillvalue=0)) + + def towards(self, p: Union[Unit, Pointlike], distance: Union[int, float] = 1, limit: bool = False) -> Pointlike: + """ + + :param p: + :param distance: + :param limit: + """ + p = p.position + # assert self != p, f"self is {self}, p is {p}" + # TODO test and fix this if statement + if self == p: + return self + # end of test + d = self.distance_to(p) + if limit: + distance = min(d, distance) + return self.__class__( + a + (b - a) / d * distance for a, b in itertools.zip_longest(self, p[:len(self)], fillvalue=0) + ) + + def __eq__(self, other): + try: + return all(abs(a - b) <= EPSILON for a, b in itertools.zip_longest(self, other, fillvalue=0)) + except TypeError: + return False + + def __hash__(self): + return hash(tuple(self)) + + +# pylint: disable=R0904 +class Point2(Pointlike): + + @classmethod + def from_proto(cls, data) -> Point2: + """ + :param data: + """ + return cls((data.x, data.y)) + + @property + def as_Point2D(self) -> common_pb.Point2D: + return common_pb.Point2D(x=self.x, y=self.y) + + @property + def as_PointI(self) -> common_pb.PointI: + """Represents points on the minimap. Values must be between 0 and 64.""" + return common_pb.PointI(x=self.x, y=self.y) + + @property + def rounded(self) -> Point2: + return Point2((math.floor(self[0]), math.floor(self[1]))) + + @property + def length(self) -> float: + """ This property exists in case Point2 is used as a vector. """ + return math.hypot(self[0], self[1]) + + @property + def normalized(self) -> Point2: + """ This property exists in case Point2 is used as a vector. """ + length = self.length + # Cannot normalize if length is zero + assert length + return self.__class__((self[0] / length, self[1] / length)) + + @property + def x(self) -> float: + return self[0] + + @property + def y(self) -> float: + return self[1] + + @property + def to2(self) -> Point2: + return Point2(self[:2]) + + @property + def to3(self) -> Point3: + return Point3((*self, 0)) + + def round(self, decimals: int) -> Point2: + """Rounds each number in the tuple to the amount of given decimals.""" + return Point2((round(self[0], decimals), round(self[1], decimals))) + + def offset(self, p: Point2) -> Point2: + return Point2((self[0] + p[0], self[1] + p[1])) + + def random_on_distance(self, distance) -> Point2: + if isinstance(distance, (tuple, list)): # interval + distance = distance[0] + random.random() * (distance[1] - distance[0]) + + assert distance > 0, "Distance is not greater than 0" + angle = random.random() * 2 * math.pi + + dx, dy = math.cos(angle), math.sin(angle) + return Point2((self.x + dx * distance, self.y + dy * distance)) + + def towards_with_random_angle( + self, + p: Union[Point2, Point3], + distance: Union[int, float] = 1, + max_difference: Union[int, float] = (math.pi / 4), + ) -> Point2: + tx, ty = self.to2.towards(p.to2, 1) + angle = math.atan2(ty - self.y, tx - self.x) + angle = (angle - max_difference) + max_difference * 2 * random.random() + return Point2((self.x + math.cos(angle) * distance, self.y + math.sin(angle) * distance)) + + def circle_intersection(self, p: Point2, r: Union[int, float]) -> Set[Point2]: + """self is point1, p is point2, r is the radius for circles originating in both points + Used in ramp finding + + :param p: + :param r:""" + assert self != p, "self is equal to p" + distanceBetweenPoints = self.distance_to(p) + assert r >= distanceBetweenPoints / 2 + # remaining distance from center towards the intersection, using pythagoras + remainingDistanceFromCenter = (r**2 - (distanceBetweenPoints / 2)**2)**0.5 + # center of both points + offsetToCenter = Point2(((p.x - self.x) / 2, (p.y - self.y) / 2)) + center = self.offset(offsetToCenter) + + # stretch offset vector in the ratio of remaining distance from center to intersection + vectorStretchFactor = remainingDistanceFromCenter / (distanceBetweenPoints / 2) + v = offsetToCenter + offsetToCenterStretched = Point2((v.x * vectorStretchFactor, v.y * vectorStretchFactor)) + + # rotate vector by 90° and -90° + vectorRotated1 = Point2((offsetToCenterStretched.y, -offsetToCenterStretched.x)) + vectorRotated2 = Point2((-offsetToCenterStretched.y, offsetToCenterStretched.x)) + intersect1 = center.offset(vectorRotated1) + intersect2 = center.offset(vectorRotated2) + return {intersect1, intersect2} + + @property + def neighbors4(self) -> set: + return { + Point2((self.x - 1, self.y)), + Point2((self.x + 1, self.y)), + Point2((self.x, self.y - 1)), + Point2((self.x, self.y + 1)), + } + + @property + def neighbors8(self) -> set: + return self.neighbors4 | { + Point2((self.x - 1, self.y - 1)), + Point2((self.x - 1, self.y + 1)), + Point2((self.x + 1, self.y - 1)), + Point2((self.x + 1, self.y + 1)), + } + + def negative_offset(self, other: Point2) -> Point2: + return self.__class__((self[0] - other[0], self[1] - other[1])) + + def __add__(self, other: Point2) -> Point2: + return self.offset(other) + + def __sub__(self, other: Point2) -> Point2: + return self.negative_offset(other) + + def __neg__(self) -> Point2: + return self.__class__(-a for a in self) + + def __abs__(self) -> float: + return math.hypot(self.x, self.y) + + def __bool__(self) -> bool: + if self.x != 0 or self.y != 0: + return True + return False + + def __mul__(self, other: Union[int, float, Point2]) -> Point2: + try: + return self.__class__((self.x * other.x, self.y * other.y)) + except AttributeError: + return self.__class__((self.x * other, self.y * other)) + + def __rmul__(self, other: Union[int, float, Point2]) -> Point2: + return self.__mul__(other) + + def __truediv__(self, other: Union[int, float, Point2]) -> Point2: + if isinstance(other, self.__class__): + return self.__class__((self.x / other.x, self.y / other.y)) + return self.__class__((self.x / other, self.y / other)) + + def is_same_as(self, other: Point2, dist=0.001) -> bool: + return self.distance_to_point2(other) <= dist + + def direction_vector(self, other: Point2) -> Point2: + """ Converts a vector to a direction that can face vertically, horizontally or diagonal or be zero, e.g. (0, 0), (1, -1), (1, 0) """ + return self.__class__((_sign(other.x - self.x), _sign(other.y - self.y))) + + def manhattan_distance(self, other: Point2) -> float: + """ + :param other: + """ + return abs(other.x - self.x) + abs(other.y - self.y) + + @staticmethod + def center(points: List[Point2]) -> Point2: + """Returns the central point for points in list + + :param points:""" + s = Point2((0, 0)) + for p in points: + s += p + return s / len(points) + + +class Point3(Point2): + + @classmethod + def from_proto(cls, data) -> Point3: + """ + :param data: + """ + return cls((data.x, data.y, data.z)) + + @property + def as_Point(self) -> common_pb.Point: + return common_pb.Point(x=self.x, y=self.y, z=self.z) + + @property + def rounded(self) -> Point3: + return Point3((math.floor(self[0]), math.floor(self[1]), math.floor(self[2]))) + + @property + def z(self) -> float: + return self[2] + + @property + def to3(self) -> Point3: + return Point3(self) + + def __add__(self, other: Union[Point2, Point3]) -> Point3: + if not isinstance(other, Point3) and isinstance(other, Point2): + return Point3((self.x + other.x, self.y + other.y, self.z)) + return Point3((self.x + other.x, self.y + other.y, self.z + other.z)) + + +class Size(Point2): + + @property + def width(self) -> float: + return self[0] + + @property + def height(self) -> float: + return self[1] + + +class Rect(tuple): + + @classmethod + def from_proto(cls, data): + """ + :param data: + """ + assert data.p0.x < data.p1.x and data.p0.y < data.p1.y + return cls((data.p0.x, data.p0.y, data.p1.x - data.p0.x, data.p1.y - data.p0.y)) + + @property + def x(self) -> float: + return self[0] + + @property + def y(self) -> float: + return self[1] + + @property + def width(self) -> float: + return self[2] + + @property + def height(self) -> float: + return self[3] + + @property + def right(self) -> float: + """ Returns the x-coordinate of the rectangle of its right side. """ + return self.x + self.width + + @property + def top(self) -> float: + """ Returns the y-coordinate of the rectangle of its top side. """ + return self.y + self.height + + @property + def size(self) -> Size: + return Size((self[2], self[3])) + + @property + def center(self) -> Point2: + return Point2((self.x + self.width / 2, self.y + self.height / 2)) + + def offset(self, p): + return self.__class__((self[0] + p[0], self[1] + p[1], self[2], self[3])) diff --git a/worlds/_sc2common/bot/power_source.py b/worlds/_sc2common/bot/power_source.py new file mode 100644 index 000000000000..232eccf3cd41 --- /dev/null +++ b/worlds/_sc2common/bot/power_source.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import List + +from .position import Point2 + + +@dataclass +class PowerSource: + position: Point2 + radius: float + unit_tag: int + + def __post_init__(self): + assert self.radius > 0 + + @classmethod + def from_proto(cls, proto): + return PowerSource(Point2.from_proto(proto.pos), proto.radius, proto.tag) + + def covers(self, position: Point2) -> bool: + return self.position.distance_to(position) <= self.radius + + def __repr__(self): + return f"PowerSource({self.position}, {self.radius})" + + +@dataclass +class PsionicMatrix: + sources: List[PowerSource] + + @classmethod + def from_proto(cls, proto): + return PsionicMatrix([PowerSource.from_proto(p) for p in proto]) + + def covers(self, position: Point2) -> bool: + return any(source.covers(position) for source in self.sources) diff --git a/worlds/_sc2common/bot/protocol.py b/worlds/_sc2common/bot/protocol.py new file mode 100644 index 000000000000..d2c48facb57d --- /dev/null +++ b/worlds/_sc2common/bot/protocol.py @@ -0,0 +1,87 @@ +import asyncio +import sys +from contextlib import suppress + +from aiohttp import ClientWebSocketResponse +from worlds._sc2common.bot import logger +from s2clientprotocol import sc2api_pb2 as sc_pb + +from .data import Status + + +class ProtocolError(Exception): + + @property + def is_game_over_error(self) -> bool: + return self.args[0] in ["['Game has already ended']", "['Not supported if game has already ended']"] + + +class ConnectionAlreadyClosed(ProtocolError): + pass + + +class Protocol: + + def __init__(self, ws): + """ + A class for communicating with an SCII application. + :param ws: the websocket (type: aiohttp.ClientWebSocketResponse) used to communicate with a specific SCII app + """ + assert ws + self._ws: ClientWebSocketResponse = ws + self._status: Status = None + + async def __request(self, request): + logger.debug(f"Sending request: {request !r}") + try: + await self._ws.send_bytes(request.SerializeToString()) + except TypeError as exc: + logger.exception("Cannot send: Connection already closed.") + raise ConnectionAlreadyClosed("Connection already closed.") from exc + logger.debug("Request sent") + + response = sc_pb.Response() + try: + response_bytes = await self._ws.receive_bytes() + except TypeError as exc: + if self._status == Status.ended: + logger.info("Cannot receive: Game has already ended.") + raise ConnectionAlreadyClosed("Game has already ended") from exc + logger.error("Cannot receive: Connection already closed.") + raise ConnectionAlreadyClosed("Connection already closed.") from exc + except asyncio.CancelledError: + # If request is sent, the response must be received before reraising cancel + try: + await self._ws.receive_bytes() + except asyncio.CancelledError: + logger.critical("Requests must not be cancelled multiple times") + sys.exit(2) + raise + + response.ParseFromString(response_bytes) + logger.debug("Response received") + return response + + async def _execute(self, **kwargs): + assert len(kwargs) == 1, "Only one request allowed by the API" + + response = await self.__request(sc_pb.Request(**kwargs)) + + new_status = Status(response.status) + if new_status != self._status: + logger.info(f"Client status changed to {new_status} (was {self._status})") + self._status = new_status + + if response.error: + logger.debug(f"Response contained an error: {response.error}") + raise ProtocolError(f"{response.error}") + + return response + + async def ping(self): + result = await self._execute(ping=sc_pb.RequestPing()) + return result + + async def quit(self): + with suppress(ConnectionAlreadyClosed, ConnectionResetError): + await self._execute(quit=sc_pb.RequestQuit()) diff --git a/worlds/_sc2common/bot/proxy.py b/worlds/_sc2common/bot/proxy.py new file mode 100644 index 000000000000..fa9f8537af56 --- /dev/null +++ b/worlds/_sc2common/bot/proxy.py @@ -0,0 +1,233 @@ +# pylint: disable=W0212 +import asyncio +import os +import platform +import subprocess +import time +import traceback + +from aiohttp import WSMsgType, web +from worlds._sc2common.bot import logger +from s2clientprotocol import sc2api_pb2 as sc_pb + +from .controller import Controller +from .data import Result, Status +from .player import BotProcess + + +class Proxy: + """ + Class for handling communication between sc2 and an external bot. + This "middleman" is needed for enforcing time limits, collecting results, and closing things properly. + """ + + def __init__( + self, + controller: Controller, + player: BotProcess, + proxyport: int, + game_time_limit: int = None, + realtime: bool = False, + ): + self.controller = controller + self.player = player + self.port = proxyport + self.timeout_loop = game_time_limit * 22.4 if game_time_limit else None + self.realtime = realtime + logger.debug( + f"Proxy Inited with ctrl {controller}({controller._process._port}), player {player}, proxyport {proxyport}, lim {game_time_limit}" + ) + + self.result = None + self.player_id: int = None + self.done = False + + async def parse_request(self, msg): + request = sc_pb.Request() + request.ParseFromString(msg.data) + if request.HasField("quit"): + request = sc_pb.Request(leave_game=sc_pb.RequestLeaveGame()) + if request.HasField("leave_game"): + if self.controller._status == Status.in_game: + logger.info(f"Proxy: player {self.player.name}({self.player_id}) surrenders") + self.result = {self.player_id: Result.Defeat} + elif self.controller._status == Status.ended: + await self.get_response() + elif request.HasField("join_game") and not request.join_game.HasField("player_name"): + request.join_game.player_name = self.player.name + await self.controller._ws.send_bytes(request.SerializeToString()) + + # TODO Catching too general exception Exception (broad-except) + # pylint: disable=W0703 + async def get_response(self): + response_bytes = None + try: + response_bytes = await self.controller._ws.receive_bytes() + except TypeError as e: + logger.exception("Cannot receive: SC2 Connection already closed.") + tb = traceback.format_exc() + logger.error(f"Exception {e}: {tb}") + except asyncio.CancelledError: + logger.info(f"Proxy({self.player.name}), caught receive from sc2") + try: + x = await self.controller._ws.receive_bytes() + if response_bytes is None: + response_bytes = x + except (asyncio.CancelledError, asyncio.TimeoutError, Exception) as e: + logger.exception(f"Exception {e}") + except Exception as e: + logger.exception(f"Caught unknown exception: {e}") + return response_bytes + + async def parse_response(self, response_bytes): + response = sc_pb.Response() + response.ParseFromString(response_bytes) + + if not response.HasField("status"): + logger.critical("Proxy: RESPONSE HAS NO STATUS {response}") + else: + new_status = Status(response.status) + if new_status != self.controller._status: + logger.info(f"Controller({self.player.name}): {self.controller._status}->{new_status}") + self.controller._status = new_status + + if self.player_id is None: + if response.HasField("join_game"): + self.player_id = response.join_game.player_id + logger.info(f"Proxy({self.player.name}): got join_game for {self.player_id}") + + if self.result is None: + if response.HasField("observation"): + obs: sc_pb.ResponseObservation = response.observation + if obs.player_result: + self.result = {pr.player_id: Result(pr.result) for pr in obs.player_result} + elif ( + self.timeout_loop and obs.HasField("observation") and obs.observation.game_loop > self.timeout_loop + ): + self.result = {i: Result.Tie for i in range(1, 3)} + logger.info(f"Proxy({self.player.name}) timing out") + act = [sc_pb.Action(action_chat=sc_pb.ActionChat(message="Proxy: Timing out"))] + await self.controller._execute(action=sc_pb.RequestAction(actions=act)) + return response + + async def get_result(self): + try: + res = await self.controller.ping() + if res.status in {Status.in_game, Status.in_replay, Status.ended}: + res = await self.controller._execute(observation=sc_pb.RequestObservation()) + if res.HasField("observation") and res.observation.player_result: + self.result = {pr.player_id: Result(pr.result) for pr in res.observation.player_result} + # pylint: disable=W0703 + # TODO Catching too general exception Exception (broad-except) + except Exception as e: + logger.exception(f"Caught unknown exception: {e}") + + async def proxy_handler(self, request): + bot_ws = web.WebSocketResponse(receive_timeout=30) + await bot_ws.prepare(request) + try: + async for msg in bot_ws: + if msg.data is None: + raise TypeError(f"data is None, {msg}") + if msg.data and msg.type == WSMsgType.BINARY: + + await self.parse_request(msg) + + response_bytes = await self.get_response() + if response_bytes is None: + raise ConnectionError("Could not get response_bytes") + + new_response = await self.parse_response(response_bytes) + await bot_ws.send_bytes(new_response.SerializeToString()) + + elif msg.type == WSMsgType.CLOSED: + logger.error("Client shutdown") + else: + logger.error("Incorrect message type") + # pylint: disable=W0703 + # TODO Catching too general exception Exception (broad-except) + except Exception as e: + logger.exception(f"Caught unknown exception: {e}") + ignored_errors = {ConnectionError, asyncio.CancelledError} + if not any(isinstance(e, E) for E in ignored_errors): + tb = traceback.format_exc() + logger.info(f"Proxy({self.player.name}): Caught {e} traceback: {tb}") + finally: + try: + if self.controller._status in {Status.in_game, Status.in_replay}: + await self.controller._execute(leave_game=sc_pb.RequestLeaveGame()) + await bot_ws.close() + # pylint: disable=W0703 + # TODO Catching too general exception Exception (broad-except) + except Exception as e: + logger.exception(f"Caught unknown exception during surrender: {e}") + self.done = True + return bot_ws + + # pylint: disable=R0912 + async def play_with_proxy(self, startport): + logger.info(f"Proxy({self.port}): Starting app") + app = web.Application() + app.router.add_route("GET", "/sc2api", self.proxy_handler) + apprunner = web.AppRunner(app, access_log=None) + await apprunner.setup() + appsite = web.TCPSite(apprunner, self.controller._process._host, self.port) + await appsite.start() + + subproc_args = {"cwd": str(self.player.path), "stderr": subprocess.STDOUT} + if platform.system() == "Linux": + subproc_args["preexec_fn"] = os.setpgrp + elif platform.system() == "Windows": + subproc_args["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP + + player_command_line = self.player.cmd_line(self.port, startport, self.controller._process._host, self.realtime) + logger.info(f"Starting bot with command: {' '.join(player_command_line)}") + if self.player.stdout is None: + bot_process = subprocess.Popen(player_command_line, stdout=subprocess.DEVNULL, **subproc_args) + else: + with open(self.player.stdout, "w+") as out: + bot_process = subprocess.Popen(player_command_line, stdout=out, **subproc_args) + + while self.result is None: + bot_alive = bot_process and bot_process.poll() is None + sc2_alive = self.controller.running + if self.done or not (bot_alive and sc2_alive): + logger.info( + f"Proxy({self.port}): {self.player.name} died, " + f"bot{(not bot_alive) * ' not'} alive, sc2{(not sc2_alive) * ' not'} alive" + ) + # Maybe its still possible to retrieve a result + if sc2_alive and not self.done: + await self.get_response() + logger.info(f"Proxy({self.port}): breaking, result {self.result}") + break + await asyncio.sleep(5) + + # cleanup + logger.info(f"({self.port}): cleaning up {self.player !r}") + for _i in range(3): + if isinstance(bot_process, subprocess.Popen): + if bot_process.stdout and not bot_process.stdout.closed: # should not run anymore + logger.info(f"==================output for player {self.player.name}") + for l in bot_process.stdout.readlines(): + logger.opt(raw=True).info(l.decode("utf-8")) + bot_process.stdout.close() + logger.info("==================") + bot_process.terminate() + bot_process.wait() + time.sleep(0.5) + if not bot_process or bot_process.poll() is not None: + break + else: + bot_process.terminate() + bot_process.wait() + try: + await apprunner.cleanup() + # pylint: disable=W0703 + # TODO Catching too general exception Exception (broad-except) + except Exception as e: + logger.exception(f"Caught unknown exception during cleaning: {e}") + if isinstance(self.result, dict): + self.result[None] = None + return self.result[self.player_id] + return self.result diff --git a/worlds/_sc2common/bot/renderer.py b/worlds/_sc2common/bot/renderer.py new file mode 100644 index 000000000000..2edc75d8441c --- /dev/null +++ b/worlds/_sc2common/bot/renderer.py @@ -0,0 +1,154 @@ +import datetime + +from s2clientprotocol import score_pb2 as score_pb + +from .position import Point2 + + +class Renderer: + + def __init__(self, client, map_size, minimap_size): + self._client = client + + self._window = None + self._map_size = map_size + self._map_image = None + self._minimap_size = minimap_size + self._minimap_image = None + self._mouse_x, self._mouse_y = None, None + self._text_supply = None + self._text_vespene = None + self._text_minerals = None + self._text_score = None + self._text_time = None + + async def render(self, observation): + render_data = observation.observation.render_data + + map_size = render_data.map.size + map_data = render_data.map.data + minimap_size = render_data.minimap.size + minimap_data = render_data.minimap.data + + map_width, map_height = map_size.x, map_size.y + map_pitch = -map_width * 3 + + minimap_width, minimap_height = minimap_size.x, minimap_size.y + minimap_pitch = -minimap_width * 3 + + if not self._window: + # pylint: disable=C0415 + from pyglet.image import ImageData + from pyglet.text import Label + from pyglet.window import Window + + self._window = Window(width=map_width, height=map_height) + self._window.on_mouse_press = self._on_mouse_press + self._window.on_mouse_release = self._on_mouse_release + self._window.on_mouse_drag = self._on_mouse_drag + self._map_image = ImageData(map_width, map_height, "RGB", map_data, map_pitch) + self._minimap_image = ImageData(minimap_width, minimap_height, "RGB", minimap_data, minimap_pitch) + self._text_supply = Label( + "", + font_name="Arial", + font_size=16, + anchor_x="right", + anchor_y="top", + x=self._map_size[0] - 10, + y=self._map_size[1] - 10, + color=(200, 200, 200, 255), + ) + self._text_vespene = Label( + "", + font_name="Arial", + font_size=16, + anchor_x="right", + anchor_y="top", + x=self._map_size[0] - 130, + y=self._map_size[1] - 10, + color=(28, 160, 16, 255), + ) + self._text_minerals = Label( + "", + font_name="Arial", + font_size=16, + anchor_x="right", + anchor_y="top", + x=self._map_size[0] - 200, + y=self._map_size[1] - 10, + color=(68, 140, 255, 255), + ) + self._text_score = Label( + "", + font_name="Arial", + font_size=16, + anchor_x="left", + anchor_y="top", + x=10, + y=self._map_size[1] - 10, + color=(219, 30, 30, 255), + ) + self._text_time = Label( + "", + font_name="Arial", + font_size=16, + anchor_x="right", + anchor_y="bottom", + x=self._minimap_size[0] - 10, + y=self._minimap_size[1] + 10, + color=(255, 255, 255, 255), + ) + else: + self._map_image.set_data("RGB", map_pitch, map_data) + self._minimap_image.set_data("RGB", minimap_pitch, minimap_data) + self._text_time.text = str(datetime.timedelta(seconds=(observation.observation.game_loop * 0.725) // 16)) + if observation.observation.HasField("player_common"): + self._text_supply.text = f"{observation.observation.player_common.food_used} / {observation.observation.player_common.food_cap}" + self._text_vespene.text = str(observation.observation.player_common.vespene) + self._text_minerals.text = str(observation.observation.player_common.minerals) + if observation.observation.HasField("score"): + # pylint: disable=W0212 + self._text_score.text = f"{score_pb._SCORE_SCORETYPE.values_by_number[observation.observation.score.score_type].name} score: {observation.observation.score.score}" + + await self._update_window() + + if self._client.in_game and (not observation.player_result) and self._mouse_x and self._mouse_y: + await self._client.move_camera_spatial(Point2((self._mouse_x, self._minimap_size[0] - self._mouse_y))) + self._mouse_x, self._mouse_y = None, None + + async def _update_window(self): + self._window.switch_to() + self._window.dispatch_events() + + self._window.clear() + + self._map_image.blit(0, 0) + self._minimap_image.blit(0, 0) + self._text_time.draw() + self._text_score.draw() + self._text_minerals.draw() + self._text_vespene.draw() + self._text_supply.draw() + + self._window.flip() + + def _on_mouse_press(self, x, y, button, _modifiers): + if button != 1: # 1: mouse.LEFT + return + if x > self._minimap_size[0] or y > self._minimap_size[1]: + return + self._mouse_x, self._mouse_y = x, y + + def _on_mouse_release(self, x, y, button, _modifiers): + if button != 1: # 1: mouse.LEFT + return + if x > self._minimap_size[0] or y > self._minimap_size[1]: + return + self._mouse_x, self._mouse_y = x, y + + def _on_mouse_drag(self, x, y, _dx, _dy, buttons, _modifiers): + if not buttons & 1: # 1: mouse.LEFT + return + if x > self._minimap_size[0] or y > self._minimap_size[1]: + return + self._mouse_x, self._mouse_y = x, y diff --git a/worlds/_sc2common/bot/sc2process.py b/worlds/_sc2common/bot/sc2process.py new file mode 100644 index 000000000000..e36632165979 --- /dev/null +++ b/worlds/_sc2common/bot/sc2process.py @@ -0,0 +1,275 @@ +import asyncio +import os +import os.path +import shutil +import signal +import subprocess +import sys +import tempfile +import time +from contextlib import suppress +from typing import Any, Dict, List, Optional, Tuple, Union + +import aiohttp +import portpicker +from worlds._sc2common.bot import logger + +from . import paths, wsl +from .controller import Controller +from .paths import Paths +from .versions import VERSIONS + + +class kill_switch: + _to_kill: List[Any] = [] + + @classmethod + def add(cls, value): + logger.debug("kill_switch: Add switch") + cls._to_kill.append(value) + + @classmethod + def kill_all(cls): + logger.info(f"kill_switch: Process cleanup for {len(cls._to_kill)} processes") + for p in cls._to_kill: + # pylint: disable=W0212 + p._clean(verbose=False) + + +class SC2Process: + """ + A class for handling SCII applications. + + :param host: hostname for the url the SCII application will listen to + :param port: the websocket port the SCII application will listen to + :param fullscreen: whether to launch the SCII application in fullscreen or not, defaults to False + :param resolution: (window width, window height) in pixels, defaults to (1024, 768) + :param placement: (x, y) the distances of the SCII app's top left corner from the top left corner of the screen + e.g. (20, 30) is 20 to the right of the screen's left border, and 30 below the top border + :param render: + :param sc2_version: + :param base_build: + :param data_hash: + """ + + def __init__( + self, + host: Optional[str] = None, + port: Optional[int] = None, + fullscreen: bool = False, + resolution: Optional[Union[List[int], Tuple[int, int]]] = None, + placement: Optional[Union[List[int], Tuple[int, int]]] = None, + render: bool = False, + sc2_version: str = None, + base_build: str = None, + data_hash: str = None, + ) -> None: + assert isinstance(host, str) or host is None + assert isinstance(port, int) or port is None + + self._render = render + self._arguments: Dict[str, str] = {"-displayMode": str(int(fullscreen))} + if not fullscreen: + if resolution and len(resolution) == 2: + self._arguments["-windowwidth"] = str(resolution[0]) + self._arguments["-windowheight"] = str(resolution[1]) + if placement and len(placement) == 2: + self._arguments["-windowx"] = str(placement[0]) + self._arguments["-windowy"] = str(placement[1]) + + self._host = host or os.environ.get("SC2CLIENTHOST", "127.0.0.1") + self._serverhost = os.environ.get("SC2SERVERHOST", self._host) + + if port is None: + self._port = portpicker.pick_unused_port() + else: + self._port = port + self._used_portpicker = bool(port is None) + self._tmp_dir = tempfile.mkdtemp(prefix="SC2_") + self._process: subprocess = None + self._session = None + self._ws = None + self._sc2_version = sc2_version + self._base_build = base_build + self._data_hash = data_hash + + async def __aenter__(self) -> Controller: + kill_switch.add(self) + + def signal_handler(*_args): + # unused arguments: signal handling library expects all signal + # callback handlers to accept two positional arguments + kill_switch.kill_all() + + signal.signal(signal.SIGINT, signal_handler) + + try: + self._process = self._launch() + self._ws = await self._connect() + except: + await self._close_connection() + self._clean() + raise + + return Controller(self._ws, self) + + async def __aexit__(self, *args): + logger.exception("async exit") + await self._close_connection() + kill_switch.kill_all() + signal.signal(signal.SIGINT, signal.SIG_DFL) + + @property + def ws_url(self): + return f"ws://{self._host}:{self._port}/sc2api" + + @property + def versions(self): + """Opens the versions.json file which origins from + https://github.com/Blizzard/s2client-proto/blob/master/buildinfo/versions.json""" + return VERSIONS + + def find_data_hash(self, target_sc2_version: str) -> Optional[str]: + """ Returns the data hash from the matching version string. """ + version: dict + for version in self.versions: + if version["label"] == target_sc2_version: + return version["data-hash"] + return None + + def _launch(self): + if self._base_build: + executable = str(paths.latest_executeble(Paths.BASE / "Versions", self._base_build)) + else: + executable = str(Paths.EXECUTABLE) + if self._port is None: + self._port = portpicker.pick_unused_port() + self._used_portpicker = True + args = paths.get_runner_args(Paths.CWD) + [ + executable, + "-listen", + self._serverhost, + "-port", + str(self._port), + "-dataDir", + str(Paths.BASE), + "-tempDir", + self._tmp_dir, + ] + for arg, value in self._arguments.items(): + args.append(arg) + args.append(value) + if self._sc2_version: + + def special_match(strg: str): + """ Tests if the specified version is in the versions.py dict. """ + for version in self.versions: + if version["label"] == strg: + return True + return False + + valid_version_string = special_match(self._sc2_version) + if valid_version_string: + self._data_hash = self.find_data_hash(self._sc2_version) + assert ( + self._data_hash is not None + ), f"StarCraft 2 Client version ({self._sc2_version}) was not found inside sc2/versions.py file. Please check your spelling or check the versions.py file." + + else: + logger.warning( + f'The submitted version string in sc2.rungame() function call (sc2_version="{self._sc2_version}") was not found in versions.py. Running latest version instead.' + ) + + if self._data_hash: + args.extend(["-dataVersion", self._data_hash]) + + if self._render: + args.extend(["-eglpath", "libEGL.so"]) + + # if logger.getEffectiveLevel() <= logging.DEBUG: + args.append("-verbose") + + sc2_cwd = str(Paths.CWD) if Paths.CWD else None + + if paths.PF in {"WSL1", "WSL2"}: + return wsl.run(args, sc2_cwd) + + return subprocess.Popen( + args, + cwd=sc2_cwd, + # Suppress Wine error messages + stderr=subprocess.DEVNULL + # , env=run_config.env + ) + + async def _connect(self): + # How long it waits for SC2 to start (in seconds) + for i in range(180): + if self._process is None: + # The ._clean() was called, clearing the process + logger.debug("Process cleanup complete, exit") + sys.exit() + + await asyncio.sleep(1) + try: + self._session = aiohttp.ClientSession() + ws = await self._session.ws_connect(self.ws_url, timeout=120) + # FIXME fix deprecation warning in for future aiohttp version + # ws = await self._session.ws_connect( + # self.ws_url, timeout=aiohttp.client_ws.ClientWSTimeout(ws_close=120) + # ) + logger.debug("Websocket connection ready") + return ws + except aiohttp.client_exceptions.ClientConnectorError: + await self._session.close() + if i > 15: + logger.debug("Connection refused (startup not complete (yet))") + + logger.debug("Websocket connection to SC2 process timed out") + raise TimeoutError("Websocket") + + async def _close_connection(self): + logger.info(f"Closing connection at {self._port}...") + + if self._ws is not None: + await self._ws.close() + + if self._session is not None: + await self._session.close() + + # pylint: disable=R0912 + def _clean(self, verbose=True): + if verbose: + logger.info("Cleaning up...") + + if self._process is not None: + if paths.PF in {"WSL1", "WSL2"}: + if wsl.kill(self._process): + logger.error("KILLED") + elif self._process.poll() is None: + for _ in range(3): + self._process.terminate() + time.sleep(0.5) + if not self._process or self._process.poll() is not None: + break + else: + self._process.kill() + self._process.wait() + logger.error("KILLED") + # Try to kill wineserver on linux + if paths.PF in {"Linux", "WineLinux"}: + # Command wineserver not detected + with suppress(FileNotFoundError): + with subprocess.Popen(["wineserver", "-k"]) as p: + p.wait() + + if os.path.exists(self._tmp_dir): + shutil.rmtree(self._tmp_dir) + + self._process = None + self._ws = None + if self._used_portpicker and self._port is not None: + portpicker.return_port(self._port) + self._port = None + if verbose: + logger.info("Cleanup complete") diff --git a/worlds/_sc2common/bot/score.py b/worlds/_sc2common/bot/score.py new file mode 100644 index 000000000000..808ee938e878 --- /dev/null +++ b/worlds/_sc2common/bot/score.py @@ -0,0 +1,424 @@ +# pylint: disable=R0904 +class ScoreDetails: + """Accessable in self.state.score during step function + For more information, see https://github.com/Blizzard/s2client-proto/blob/master/s2clientprotocol/score.proto + """ + + def __init__(self, proto): + self._data = proto + self._proto = proto.score_details + + @property + def summary(self): + """ + TODO this is super ugly, how can we improve this summary? + Print summary to file with: + In on_step: + + with open("stats.txt", "w+") as file: + for stat in self.state.score.summary: + file.write(f"{stat[0]:<35} {float(stat[1]):>35.3f}\n") + """ + values = [ + "score_type", + "score", + "idle_production_time", + "idle_worker_time", + "total_value_units", + "total_value_structures", + "killed_value_units", + "killed_value_structures", + "collected_minerals", + "collected_vespene", + "collection_rate_minerals", + "collection_rate_vespene", + "spent_minerals", + "spent_vespene", + "food_used_none", + "food_used_army", + "food_used_economy", + "food_used_technology", + "food_used_upgrade", + "killed_minerals_none", + "killed_minerals_army", + "killed_minerals_economy", + "killed_minerals_technology", + "killed_minerals_upgrade", + "killed_vespene_none", + "killed_vespene_army", + "killed_vespene_economy", + "killed_vespene_technology", + "killed_vespene_upgrade", + "lost_minerals_none", + "lost_minerals_army", + "lost_minerals_economy", + "lost_minerals_technology", + "lost_minerals_upgrade", + "lost_vespene_none", + "lost_vespene_army", + "lost_vespene_economy", + "lost_vespene_technology", + "lost_vespene_upgrade", + "friendly_fire_minerals_none", + "friendly_fire_minerals_army", + "friendly_fire_minerals_economy", + "friendly_fire_minerals_technology", + "friendly_fire_minerals_upgrade", + "friendly_fire_vespene_none", + "friendly_fire_vespene_army", + "friendly_fire_vespene_economy", + "friendly_fire_vespene_technology", + "friendly_fire_vespene_upgrade", + "used_minerals_none", + "used_minerals_army", + "used_minerals_economy", + "used_minerals_technology", + "used_minerals_upgrade", + "used_vespene_none", + "used_vespene_army", + "used_vespene_economy", + "used_vespene_technology", + "used_vespene_upgrade", + "total_used_minerals_none", + "total_used_minerals_army", + "total_used_minerals_economy", + "total_used_minerals_technology", + "total_used_minerals_upgrade", + "total_used_vespene_none", + "total_used_vespene_army", + "total_used_vespene_economy", + "total_used_vespene_technology", + "total_used_vespene_upgrade", + "total_damage_dealt_life", + "total_damage_dealt_shields", + "total_damage_dealt_energy", + "total_damage_taken_life", + "total_damage_taken_shields", + "total_damage_taken_energy", + "total_healed_life", + "total_healed_shields", + "total_healed_energy", + "current_apm", + "current_effective_apm", + ] + return [[value, getattr(self, value)] for value in values] + + @property + def score_type(self): + return self._data.score_type + + @property + def score(self): + return self._data.score + + @property + def idle_production_time(self): + return self._proto.idle_production_time + + @property + def idle_worker_time(self): + return self._proto.idle_worker_time + + @property + def total_value_units(self): + return self._proto.total_value_units + + @property + def total_value_structures(self): + return self._proto.total_value_structures + + @property + def killed_value_units(self): + return self._proto.killed_value_units + + @property + def killed_value_structures(self): + return self._proto.killed_value_structures + + @property + def collected_minerals(self): + return self._proto.collected_minerals + + @property + def collected_vespene(self): + return self._proto.collected_vespene + + @property + def collection_rate_minerals(self): + return self._proto.collection_rate_minerals + + @property + def collection_rate_vespene(self): + return self._proto.collection_rate_vespene + + @property + def spent_minerals(self): + return self._proto.spent_minerals + + @property + def spent_vespene(self): + return self._proto.spent_vespene + + @property + def food_used_none(self): + return self._proto.food_used.none + + @property + def food_used_army(self): + return self._proto.food_used.army + + @property + def food_used_economy(self): + return self._proto.food_used.economy + + @property + def food_used_technology(self): + return self._proto.food_used.technology + + @property + def food_used_upgrade(self): + return self._proto.food_used.upgrade + + @property + def killed_minerals_none(self): + return self._proto.killed_minerals.none + + @property + def killed_minerals_army(self): + return self._proto.killed_minerals.army + + @property + def killed_minerals_economy(self): + return self._proto.killed_minerals.economy + + @property + def killed_minerals_technology(self): + return self._proto.killed_minerals.technology + + @property + def killed_minerals_upgrade(self): + return self._proto.killed_minerals.upgrade + + @property + def killed_vespene_none(self): + return self._proto.killed_vespene.none + + @property + def killed_vespene_army(self): + return self._proto.killed_vespene.army + + @property + def killed_vespene_economy(self): + return self._proto.killed_vespene.economy + + @property + def killed_vespene_technology(self): + return self._proto.killed_vespene.technology + + @property + def killed_vespene_upgrade(self): + return self._proto.killed_vespene.upgrade + + @property + def lost_minerals_none(self): + return self._proto.lost_minerals.none + + @property + def lost_minerals_army(self): + return self._proto.lost_minerals.army + + @property + def lost_minerals_economy(self): + return self._proto.lost_minerals.economy + + @property + def lost_minerals_technology(self): + return self._proto.lost_minerals.technology + + @property + def lost_minerals_upgrade(self): + return self._proto.lost_minerals.upgrade + + @property + def lost_vespene_none(self): + return self._proto.lost_vespene.none + + @property + def lost_vespene_army(self): + return self._proto.lost_vespene.army + + @property + def lost_vespene_economy(self): + return self._proto.lost_vespene.economy + + @property + def lost_vespene_technology(self): + return self._proto.lost_vespene.technology + + @property + def lost_vespene_upgrade(self): + return self._proto.lost_vespene.upgrade + + @property + def friendly_fire_minerals_none(self): + return self._proto.friendly_fire_minerals.none + + @property + def friendly_fire_minerals_army(self): + return self._proto.friendly_fire_minerals.army + + @property + def friendly_fire_minerals_economy(self): + return self._proto.friendly_fire_minerals.economy + + @property + def friendly_fire_minerals_technology(self): + return self._proto.friendly_fire_minerals.technology + + @property + def friendly_fire_minerals_upgrade(self): + return self._proto.friendly_fire_minerals.upgrade + + @property + def friendly_fire_vespene_none(self): + return self._proto.friendly_fire_vespene.none + + @property + def friendly_fire_vespene_army(self): + return self._proto.friendly_fire_vespene.army + + @property + def friendly_fire_vespene_economy(self): + return self._proto.friendly_fire_vespene.economy + + @property + def friendly_fire_vespene_technology(self): + return self._proto.friendly_fire_vespene.technology + + @property + def friendly_fire_vespene_upgrade(self): + return self._proto.friendly_fire_vespene.upgrade + + @property + def used_minerals_none(self): + return self._proto.used_minerals.none + + @property + def used_minerals_army(self): + return self._proto.used_minerals.army + + @property + def used_minerals_economy(self): + return self._proto.used_minerals.economy + + @property + def used_minerals_technology(self): + return self._proto.used_minerals.technology + + @property + def used_minerals_upgrade(self): + return self._proto.used_minerals.upgrade + + @property + def used_vespene_none(self): + return self._proto.used_vespene.none + + @property + def used_vespene_army(self): + return self._proto.used_vespene.army + + @property + def used_vespene_economy(self): + return self._proto.used_vespene.economy + + @property + def used_vespene_technology(self): + return self._proto.used_vespene.technology + + @property + def used_vespene_upgrade(self): + return self._proto.used_vespene.upgrade + + @property + def total_used_minerals_none(self): + return self._proto.total_used_minerals.none + + @property + def total_used_minerals_army(self): + return self._proto.total_used_minerals.army + + @property + def total_used_minerals_economy(self): + return self._proto.total_used_minerals.economy + + @property + def total_used_minerals_technology(self): + return self._proto.total_used_minerals.technology + + @property + def total_used_minerals_upgrade(self): + return self._proto.total_used_minerals.upgrade + + @property + def total_used_vespene_none(self): + return self._proto.total_used_vespene.none + + @property + def total_used_vespene_army(self): + return self._proto.total_used_vespene.army + + @property + def total_used_vespene_economy(self): + return self._proto.total_used_vespene.economy + + @property + def total_used_vespene_technology(self): + return self._proto.total_used_vespene.technology + + @property + def total_used_vespene_upgrade(self): + return self._proto.total_used_vespene.upgrade + + @property + def total_damage_dealt_life(self): + return self._proto.total_damage_dealt.life + + @property + def total_damage_dealt_shields(self): + return self._proto.total_damage_dealt.shields + + @property + def total_damage_dealt_energy(self): + return self._proto.total_damage_dealt.energy + + @property + def total_damage_taken_life(self): + return self._proto.total_damage_taken.life + + @property + def total_damage_taken_shields(self): + return self._proto.total_damage_taken.shields + + @property + def total_damage_taken_energy(self): + return self._proto.total_damage_taken.energy + + @property + def total_healed_life(self): + return self._proto.total_healed.life + + @property + def total_healed_shields(self): + return self._proto.total_healed.shields + + @property + def total_healed_energy(self): + return self._proto.total_healed.energy + + @property + def current_apm(self): + return self._proto.current_apm + + @property + def current_effective_apm(self): + return self._proto.current_effective_apm diff --git a/worlds/_sc2common/bot/unit.py b/worlds/_sc2common/bot/unit.py new file mode 100644 index 000000000000..dd638a7b284e --- /dev/null +++ b/worlds/_sc2common/bot/unit.py @@ -0,0 +1,692 @@ +# pylint: disable=W0212 +from __future__ import annotations + +import math +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union + +from .cache import CacheDict +from .constants import ( + CAN_BE_ATTACKED, + IS_ARMORED, + IS_BIOLOGICAL, + IS_CLOAKED, + IS_ENEMY, + IS_LIGHT, + IS_MASSIVE, + IS_MECHANICAL, + IS_MINE, + IS_PLACEHOLDER, + IS_PSIONIC, + IS_REVEALED, + IS_SNAPSHOT, + IS_STRUCTURE, + IS_VISIBLE, +) +from .data import Alliance, Attribute, CloakState, Race +from .position import Point2, Point3 + +if TYPE_CHECKING: + from .bot_ai import BotAI + from .game_data import AbilityData, UnitTypeData + + +@dataclass +class RallyTarget: + point: Point2 + tag: Optional[int] = None + + @classmethod + def from_proto(cls, proto: Any) -> RallyTarget: + return cls( + Point2.from_proto(proto.point), + proto.tag if proto.HasField("tag") else None, + ) + + +@dataclass +class UnitOrder: + ability: AbilityData # TODO: Should this be AbilityId instead? + target: Optional[Union[int, Point2]] = None + progress: float = 0 + + @classmethod + def from_proto(cls, proto: Any, bot_object: BotAI) -> UnitOrder: + target: Optional[Union[int, Point2]] = proto.target_unit_tag + if proto.HasField("target_world_space_pos"): + target = Point2.from_proto(proto.target_world_space_pos) + elif proto.HasField("target_unit_tag"): + target = proto.target_unit_tag + return cls( + ability=bot_object.game_data.abilities[proto.ability_id], + target=target, + progress=proto.progress, + ) + + def __repr__(self) -> str: + return f"UnitOrder({self.ability}, {self.target}, {self.progress})" + + +# pylint: disable=R0904 +class Unit: + class_cache = CacheDict() + + def __init__( + self, + proto_data, + bot_object: BotAI, + distance_calculation_index: int = -1, + base_build: int = -1, + ): + """ + :param proto_data: + :param bot_object: + :param distance_calculation_index: + :param base_build: + """ + self._proto = proto_data + self._bot_object: BotAI = bot_object + self.game_loop: int = bot_object.state.game_loop + self.base_build = base_build + # Index used in the 2D numpy array to access the 2D distance between two units + self.distance_calculation_index: int = distance_calculation_index + + def __repr__(self) -> str: + """ Returns string of this form: Unit(name='SCV', tag=4396941328). """ + return f"Unit(name={self.name !r}, tag={self.tag})" + + @cached_property + def _type_data(self) -> UnitTypeData: + """ Provides the unit type data. """ + return self._bot_object.game_data.units[self._proto.unit_type] + + @cached_property + def _creation_ability(self) -> AbilityData: + """ Provides the AbilityData of the creation ability of this unit. """ + return self._type_data.creation_ability + + @property + def name(self) -> str: + """ Returns the name of the unit. """ + return self._type_data.name + + @cached_property + def race(self) -> Race: + """ Returns the race of the unit """ + return Race(self._type_data._proto.race) + + @property + def tag(self) -> int: + """ Returns the unique tag of the unit. """ + return self._proto.tag + + @property + def is_structure(self) -> bool: + """ Checks if the unit is a structure. """ + return IS_STRUCTURE in self._type_data.attributes + + @property + def is_light(self) -> bool: + """ Checks if the unit has the 'light' attribute. """ + return IS_LIGHT in self._type_data.attributes + + @property + def is_armored(self) -> bool: + """ Checks if the unit has the 'armored' attribute. """ + return IS_ARMORED in self._type_data.attributes + + @property + def is_biological(self) -> bool: + """ Checks if the unit has the 'biological' attribute. """ + return IS_BIOLOGICAL in self._type_data.attributes + + @property + def is_mechanical(self) -> bool: + """ Checks if the unit has the 'mechanical' attribute. """ + return IS_MECHANICAL in self._type_data.attributes + + @property + def is_massive(self) -> bool: + """ Checks if the unit has the 'massive' attribute. """ + return IS_MASSIVE in self._type_data.attributes + + @property + def is_psionic(self) -> bool: + """ Checks if the unit has the 'psionic' attribute. """ + return IS_PSIONIC in self._type_data.attributes + + @cached_property + def _weapons(self): + """ Returns the weapons of the unit. """ + return self._type_data._proto.weapons + + @cached_property + def bonus_damage(self) -> Optional[Tuple[int, str]]: + """Returns a tuple of form '(bonus damage, armor type)' if unit does 'bonus damage' against 'armor type'. + Possible armor typs are: 'Light', 'Armored', 'Biological', 'Mechanical', 'Psionic', 'Massive', 'Structure'.""" + # TODO: Consider units with ability attacks (Oracle, Baneling) or multiple attacks (Thor). + if self._weapons: + for weapon in self._weapons: + if weapon.damage_bonus: + b = weapon.damage_bonus[0] + return b.bonus, Attribute(b.attribute).name + return None + + @property + def armor(self) -> float: + """ Returns the armor of the unit. Does not include upgrades """ + return self._type_data._proto.armor + + @property + def sight_range(self) -> float: + """ Returns the sight range of the unit. """ + return self._type_data._proto.sight_range + + @property + def movement_speed(self) -> float: + """Returns the movement speed of the unit. + This is the unit movement speed on game speed 'normal'. To convert it to 'faster' movement speed, multiply it by a factor of '1.4'. E.g. reaper movement speed is listed here as 3.75, but should actually be 5.25. + Does not include upgrades or buffs.""" + return self._type_data._proto.movement_speed + + @property + def is_mineral_field(self) -> bool: + """ Checks if the unit is a mineral field. """ + return self._type_data.has_minerals + + @property + def is_vespene_geyser(self) -> bool: + """ Checks if the unit is a non-empty vespene geyser or gas extraction building. """ + return self._type_data.has_vespene + + @property + def health(self) -> float: + """ Returns the health of the unit. Does not include shields. """ + return self._proto.health + + @property + def health_max(self) -> float: + """ Returns the maximum health of the unit. Does not include shields. """ + return self._proto.health_max + + @cached_property + def health_percentage(self) -> float: + """ Returns the percentage of health the unit has. Does not include shields. """ + if not self._proto.health_max: + return 0 + return self._proto.health / self._proto.health_max + + @property + def shield(self) -> float: + """ Returns the shield points the unit has. Returns 0 for non-protoss units. """ + return self._proto.shield + + @property + def shield_max(self) -> float: + """ Returns the maximum shield points the unit can have. Returns 0 for non-protoss units. """ + return self._proto.shield_max + + @cached_property + def shield_percentage(self) -> float: + """ Returns the percentage of shield points the unit has. Returns 0 for non-protoss units. """ + if not self._proto.shield_max: + return 0 + return self._proto.shield / self._proto.shield_max + + @cached_property + def shield_health_percentage(self) -> float: + """Returns the percentage of combined shield + hp points the unit has. + Also takes build progress into account.""" + max_ = (self._proto.shield_max + self._proto.health_max) * self.build_progress + if max_ == 0: + return 0 + return (self._proto.shield + self._proto.health) / max_ + + @property + def energy(self) -> float: + """ Returns the amount of energy the unit has. Returns 0 for units without energy. """ + return self._proto.energy + + @property + def energy_max(self) -> float: + """ Returns the maximum amount of energy the unit can have. Returns 0 for units without energy. """ + return self._proto.energy_max + + @cached_property + def energy_percentage(self) -> float: + """ Returns the percentage of amount of energy the unit has. Returns 0 for units without energy. """ + if not self._proto.energy_max: + return 0 + return self._proto.energy / self._proto.energy_max + + @property + def age_in_frames(self) -> int: + """ Returns how old the unit object data is (in game frames). This age does not reflect the unit was created / trained / morphed! """ + return self._bot_object.state.game_loop - self.game_loop + + @property + def age(self) -> float: + """ Returns how old the unit object data is (in game seconds). This age does not reflect when the unit was created / trained / morphed! """ + return (self._bot_object.state.game_loop - self.game_loop) / 22.4 + + @property + def is_memory(self) -> bool: + """ Returns True if this Unit object is referenced from the future and is outdated. """ + return self.game_loop != self._bot_object.state.game_loop + + @cached_property + def is_snapshot(self) -> bool: + """Checks if the unit is only available as a snapshot for the bot. + Enemy buildings that have been scouted and are in the fog of war or + attacking enemy units on higher, not visible ground appear this way.""" + if self.base_build >= 82457: + return self._proto.display_type == IS_SNAPSHOT + # TODO: Fixed in version 5.0.4, remove if a new linux binary is released: https://github.com/Blizzard/s2client-proto/issues/167 + position = self.position.rounded + return self._bot_object.state.visibility.data_numpy[position[1], position[0]] != 2 + + @cached_property + def is_visible(self) -> bool: + """Checks if the unit is visible for the bot. + NOTE: This means the bot has vision of the position of the unit! + It does not give any information about the cloak status of the unit.""" + if self.base_build >= 82457: + return self._proto.display_type == IS_VISIBLE + # TODO: Remove when a new linux binary (5.0.4 or newer) is released + return self._proto.display_type == IS_VISIBLE and not self.is_snapshot + + @property + def is_placeholder(self) -> bool: + """Checks if the unit is a placerholder for the bot. + Raw information about placeholders: + display_type: Placeholder + alliance: Self + unit_type: 86 + owner: 1 + pos { + x: 29.5 + y: 53.5 + z: 7.98828125 + } + radius: 2.75 + is_on_screen: false + """ + return self._proto.display_type == IS_PLACEHOLDER + + @property + def alliance(self) -> Alliance: + """ Returns the team the unit belongs to. """ + return self._proto.alliance + + @property + def is_mine(self) -> bool: + """ Checks if the unit is controlled by the bot. """ + return self._proto.alliance == IS_MINE + + @property + def is_enemy(self) -> bool: + """ Checks if the unit is hostile. """ + return self._proto.alliance == IS_ENEMY + + @property + def owner_id(self) -> int: + """ Returns the owner of the unit. This is a value of 1 or 2 in a two player game. """ + return self._proto.owner + + @property + def position_tuple(self) -> Tuple[float, float]: + """ Returns the 2d position of the unit as tuple without conversion to Point2. """ + return self._proto.pos.x, self._proto.pos.y + + @cached_property + def position(self) -> Point2: + """ Returns the 2d position of the unit. """ + return Point2.from_proto(self._proto.pos) + + @cached_property + def position3d(self) -> Point3: + """ Returns the 3d position of the unit. """ + return Point3.from_proto(self._proto.pos) + + def distance_to(self, p: Union[Unit, Point2]) -> float: + """Using the 2d distance between self and p. + To calculate the 3d distance, use unit.position3d.distance_to(p) + + :param p: + """ + if isinstance(p, Unit): + return self._bot_object._distance_squared_unit_to_unit(self, p)**0.5 + return self._bot_object.distance_math_hypot(self.position_tuple, p) + + def distance_to_squared(self, p: Union[Unit, Point2]) -> float: + """Using the 2d distance squared between self and p. Slightly faster than distance_to, so when filtering a lot of units, this function is recommended to be used. + To calculate the 3d distance, use unit.position3d.distance_to(p) + + :param p: + """ + if isinstance(p, Unit): + return self._bot_object._distance_squared_unit_to_unit(self, p) + return self._bot_object.distance_math_hypot_squared(self.position_tuple, p) + + @property + def facing(self) -> float: + """Returns direction the unit is facing as a float in range [0,2π). 0 is in direction of x axis.""" + return self._proto.facing + + def is_facing(self, other_unit: Unit, angle_error: float = 0.05) -> bool: + """Check if this unit is facing the target unit. If you make angle_error too small, there might be rounding errors. If you make angle_error too big, this function might return false positives. + + :param other_unit: + :param angle_error: + """ + # TODO perhaps return default True for units that cannot 'face' another unit? e.g. structures (planetary fortress, bunker, missile turret, photon cannon, spine, spore) or sieged tanks + angle = math.atan2( + other_unit.position_tuple[1] - self.position_tuple[1], other_unit.position_tuple[0] - self.position_tuple[0] + ) + if angle < 0: + angle += math.pi * 2 + angle_difference = math.fabs(angle - self.facing) + return angle_difference < angle_error + + @property + def footprint_radius(self) -> Optional[float]: + """For structures only. + For townhalls this returns 2.5 + For barracks, spawning pool, gateway, this returns 1.5 + For supply depot, this returns 1 + For sensor tower, creep tumor, this return 0.5 + + NOTE: This can be None if a building doesn't have a creation ability. + For rich vespene buildings, flying terran buildings, this returns None""" + return self._type_data.footprint_radius + + @property + def radius(self) -> float: + """ Half of unit size. See https://liquipedia.net/starcraft2/Unit_Statistics_(Legacy_of_the_Void) """ + return self._proto.radius + + @property + def build_progress(self) -> float: + """ Returns completion in range [0,1].""" + return self._proto.build_progress + + @property + def is_ready(self) -> bool: + """ Checks if the unit is completed. """ + return self.build_progress == 1 + + @property + def cloak(self) -> CloakState: + """Returns cloak state. + See https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_unit.h#L95 + """ + return CloakState(self._proto.cloak) + + @property + def is_cloaked(self) -> bool: + """ Checks if the unit is cloaked. """ + return self._proto.cloak in IS_CLOAKED + + @property + def is_revealed(self) -> bool: + """ Checks if the unit is revealed. """ + return self._proto.cloak == IS_REVEALED + + @property + def can_be_attacked(self) -> bool: + """ Checks if the unit is revealed or not cloaked and therefore can be attacked. """ + return self._proto.cloak in CAN_BE_ATTACKED + + @property + def detect_range(self) -> float: + """ Returns the detection distance of the unit. """ + return self._proto.detect_range + + @property + def radar_range(self) -> float: + return self._proto.radar_range + + @property + def is_selected(self) -> bool: + """ Checks if the unit is currently selected. """ + return self._proto.is_selected + + @property + def is_on_screen(self) -> bool: + """ Checks if the unit is on the screen. """ + return self._proto.is_on_screen + + @property + def is_blip(self) -> bool: + """ Checks if the unit is detected by a sensor tower. """ + return self._proto.is_blip + + @property + def is_powered(self) -> bool: + """ Checks if the unit is powered by a pylon or warppism. """ + return self._proto.is_powered + + @property + def is_active(self) -> bool: + """ Checks if the unit has an order (e.g. unit is currently moving or attacking, structure is currently training or researching). """ + return self._proto.is_active + + # PROPERTIES BELOW THIS COMMENT ARE NOT POPULATED FOR SNAPSHOTS + + @property + def mineral_contents(self) -> int: + """ Returns the amount of minerals remaining in a mineral field. """ + return self._proto.mineral_contents + + @property + def vespene_contents(self) -> int: + """ Returns the amount of gas remaining in a geyser. """ + return self._proto.vespene_contents + + @property + def has_vespene(self) -> bool: + """Checks if a geyser has any gas remaining. + You can't build extractors on empty geysers.""" + return bool(self._proto.vespene_contents) + + @property + def is_burrowed(self) -> bool: + """ Checks if the unit is burrowed. """ + return self._proto.is_burrowed + + @property + def is_hallucination(self) -> bool: + """ Returns True if the unit is your own hallucination or detected. """ + return self._proto.is_hallucination + + @property + def attack_upgrade_level(self) -> int: + """Returns the upgrade level of the units attack. + # NOTE: Returns 0 for units without a weapon.""" + return self._proto.attack_upgrade_level + + @property + def armor_upgrade_level(self) -> int: + """ Returns the upgrade level of the units armor. """ + return self._proto.armor_upgrade_level + + @property + def shield_upgrade_level(self) -> int: + """Returns the upgrade level of the units shield. + # NOTE: Returns 0 for units without a shield.""" + return self._proto.shield_upgrade_level + + @property + def buff_duration_remain(self) -> int: + """Returns the amount of remaining frames of the visible timer bar. + # NOTE: Returns 0 for units without a timer bar.""" + return self._proto.buff_duration_remain + + @property + def buff_duration_max(self) -> int: + """Returns the maximum amount of frames of the visible timer bar. + # NOTE: Returns 0 for units without a timer bar.""" + return self._proto.buff_duration_max + + # PROPERTIES BELOW THIS COMMENT ARE NOT POPULATED FOR ENEMIES + + @cached_property + def orders(self) -> List[UnitOrder]: + """ Returns the a list of the current orders. """ + # TODO: add examples on how to use unit orders + return [UnitOrder.from_proto(order, self._bot_object) for order in self._proto.orders] + + @cached_property + def order_target(self) -> Optional[Union[int, Point2]]: + """Returns the target tag (if it is a Unit) or Point2 (if it is a Position) + from the first order, returns None if the unit is idle""" + if self.orders: + target = self.orders[0].target + if isinstance(target, int): + return target + return Point2.from_proto(target) + return None + + @property + def is_idle(self) -> bool: + """ Checks if unit is idle. """ + return not self._proto.orders + + @property + def add_on_tag(self) -> int: + """Returns the tag of the addon of unit. If the unit has no addon, returns 0.""" + return self._proto.add_on_tag + + @property + def has_add_on(self) -> bool: + """ Checks if unit has an addon attached. """ + return bool(self._proto.add_on_tag) + + @cached_property + def has_techlab(self) -> bool: + """Check if a structure is connected to a techlab addon. This should only ever return True for BARRACKS, FACTORY, STARPORT. """ + return self.add_on_tag in self._bot_object.techlab_tags + + @cached_property + def has_reactor(self) -> bool: + """Check if a structure is connected to a reactor addon. This should only ever return True for BARRACKS, FACTORY, STARPORT. """ + return self.add_on_tag in self._bot_object.reactor_tags + + @cached_property + def add_on_land_position(self) -> Point2: + """If this unit is an addon (techlab, reactor), returns the position + where a terran building (BARRACKS, FACTORY, STARPORT) has to land to connect to this addon. + + Why offset (-2.5, 0.5)? See description in 'add_on_position' + """ + return self.position.offset(Point2((-2.5, 0.5))) + + @cached_property + def add_on_position(self) -> Point2: + """If this unit is a terran production building (BARRACKS, FACTORY, STARPORT), + this property returns the position of where the addon should be, if it should build one or has one attached. + + Why offset (2.5, -0.5)? + A barracks is of size 3x3. The distance from the center to the edge is 1.5. + An addon is 2x2 and the distance from the edge to center is 1. + The total distance from center to center on the x-axis is 2.5. + The distance from center to center on the y-axis is -0.5. + """ + return self.position.offset(Point2((2.5, -0.5))) + + @cached_property + def passengers(self) -> Set[Unit]: + """ Returns the units inside a Bunker, CommandCenter, PlanetaryFortress, Medivac, Nydus, Overlord or WarpPrism. """ + return {Unit(unit, self._bot_object) for unit in self._proto.passengers} + + @cached_property + def passengers_tags(self) -> Set[int]: + """ Returns the tags of the units inside a Bunker, CommandCenter, PlanetaryFortress, Medivac, Nydus, Overlord or WarpPrism. """ + return {unit.tag for unit in self._proto.passengers} + + @property + def cargo_used(self) -> int: + """Returns how much cargo space is currently used in the unit. + Note that some units take up more than one space.""" + return self._proto.cargo_space_taken + + @property + def has_cargo(self) -> bool: + """ Checks if this unit has any units loaded. """ + return bool(self._proto.cargo_space_taken) + + @property + def cargo_size(self) -> int: + """ Returns the amount of cargo space the unit needs. """ + return self._type_data.cargo_size + + @property + def cargo_max(self) -> int: + """ How much cargo space is available at maximum. """ + return self._proto.cargo_space_max + + @property + def cargo_left(self) -> int: + """ Returns how much cargo space is currently left in the unit. """ + return self._proto.cargo_space_max - self._proto.cargo_space_taken + + @property + def assigned_harvesters(self) -> int: + """ Returns the number of workers currently gathering resources at a geyser or mining base.""" + return self._proto.assigned_harvesters + + @property + def ideal_harvesters(self) -> int: + """Returns the ideal harverster count for unit. + 3 for gas buildings, 2*n for n mineral patches on that base.""" + return self._proto.ideal_harvesters + + @property + def surplus_harvesters(self) -> int: + """Returns a positive int if unit has too many harvesters mining, + a negative int if it has too few mining. + Will only works on townhalls, and gas buildings. + """ + return self._proto.assigned_harvesters - self._proto.ideal_harvesters + + @property + def weapon_cooldown(self) -> float: + """Returns the time until the unit can fire again, + returns -1 for units that can't attack. + Usage: + if unit.weapon_cooldown == 0: + unit.attack(target) + elif unit.weapon_cooldown < 0: + unit.move(closest_allied_unit_because_cant_attack) + else: + unit.move(retreatPosition)""" + if self.can_attack: + return self._proto.weapon_cooldown + return -1 + + @property + def weapon_ready(self) -> bool: + """Checks if the weapon is ready to be fired.""" + return self.weapon_cooldown == 0 + + @property + def engaged_target_tag(self) -> int: + # TODO What does this do? + return self._proto.engaged_target_tag + + @cached_property + def rally_targets(self) -> List[RallyTarget]: + """ Returns the queue of rallytargets of the structure. """ + return [RallyTarget.from_proto(rally_target) for rally_target in self._proto.rally_targets] + + # Unit functions + + def __hash__(self) -> int: + return self.tag + + def __eq__(self, other: Union[Unit, Any]) -> bool: + """ + :param other: + """ + return self.tag == getattr(other, "tag", -1) diff --git a/worlds/_sc2common/bot/units.py b/worlds/_sc2common/bot/units.py new file mode 100644 index 000000000000..3b3c18608bd3 --- /dev/null +++ b/worlds/_sc2common/bot/units.py @@ -0,0 +1,633 @@ +# pylint: disable=W0212 +from __future__ import annotations + +import random +from itertools import chain +from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, List, Optional, Set, Tuple, Union + +from .position import Point2 +from .unit import Unit + +if TYPE_CHECKING: + from .bot_ai import BotAI + + +# pylint: disable=R0904 +class Units(list): + """A collection of Unit objects. Makes it easy to select units by selectors.""" + + @classmethod + def from_proto(cls, units, bot_object: BotAI): + # pylint: disable=E1120 + return cls((Unit(raw_unit, bot_object=bot_object) for raw_unit in units)) + + def __init__(self, units: Iterable[Unit], bot_object: BotAI): + """ + :param units: + :param bot_object: + """ + super().__init__(units) + self._bot_object = bot_object + + def __call__(self) -> Units: + """Creates a new mutable Units object from Units or list object. + + :param unit_types: + """ + return self + + def __iter__(self) -> Generator[Unit, None, None]: + return (item for item in super().__iter__()) + + def copy(self) -> Units: + """Creates a new mutable Units object from Units or list object. + + :param units: + """ + return Units(self, self._bot_object) + + def __or__(self, other: Units) -> Units: + """ + :param other: + """ + return Units( + chain( + iter(self), + (other_unit for other_unit in other if other_unit.tag not in (self_unit.tag for self_unit in self)), + ), + self._bot_object, + ) + + def __add__(self, other: Units) -> Units: + """ + :param other: + """ + return Units( + chain( + iter(self), + (other_unit for other_unit in other if other_unit.tag not in (self_unit.tag for self_unit in self)), + ), + self._bot_object, + ) + + def __and__(self, other: Units) -> Units: + """ + :param other: + """ + return Units( + (other_unit for other_unit in other if other_unit.tag in (self_unit.tag for self_unit in self)), + self._bot_object, + ) + + def __sub__(self, other: Units) -> Units: + """ + :param other: + """ + return Units( + (self_unit for self_unit in self if self_unit.tag not in (other_unit.tag for other_unit in other)), + self._bot_object, + ) + + def __hash__(self) -> int: + return hash(unit.tag for unit in self) + + @property + def amount(self) -> int: + return len(self) + + @property + def empty(self) -> bool: + return not bool(self) + + @property + def exists(self) -> bool: + return bool(self) + + def find_by_tag(self, tag: int) -> Optional[Unit]: + """ + :param tag: + """ + for unit in self: + if unit.tag == tag: + return unit + return None + + def by_tag(self, tag: int) -> Unit: + """ + :param tag: + """ + unit = self.find_by_tag(tag) + if unit is None: + raise KeyError("Unit not found") + return unit + + @property + def first(self) -> Unit: + assert self, "Units object is empty" + return self[0] + + def take(self, n: int) -> Units: + """ + :param n: + """ + if n >= self.amount: + return self + return self.subgroup(self[:n]) + + @property + def random(self) -> Unit: + assert self, "Units object is empty" + return random.choice(self) + + def random_or(self, other: any) -> Unit: + return random.choice(self) if self else other + + def random_group_of(self, n: int) -> Units: + """ Returns self if n >= self.amount. """ + if n < 1: + return Units([], self._bot_object) + if n >= self.amount: + return self + return self.subgroup(random.sample(self, n)) + + def in_attack_range_of(self, unit: Unit, bonus_distance: float = 0) -> Units: + """Filters units that are in attack range of the given unit. + This uses the unit and target unit.radius when calculating the distance, so it should be accurate. + Caution: This may not work well for static structures (bunker, sieged tank, planetary fortress, photon cannon, spine and spore crawler) because it seems attack ranges differ for static / immovable units. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + all_zerglings_my_marine_can_attack = enemy_zerglings.in_attack_range_of(my_marine) + + Example:: + + enemy_mutalisks = self.enemy_units(UnitTypeId.MUTALISK) + my_marauder = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARAUDER), None) + if my_marauder: + all_mutalisks_my_marauder_can_attack = enemy_mutaliskss.in_attack_range_of(my_marauder) + # Is empty because mutalisk are flying and marauder cannot attack air + + :param unit: + :param bonus_distance: + """ + return self.filter(lambda x: unit.target_in_range(x, bonus_distance=bonus_distance)) + + def closest_distance_to(self, position: Union[Unit, Point2]) -> float: + """Returns the distance between the closest unit from this group to the target unit. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + closest_zergling_distance = enemy_zerglings.closest_distance_to(my_marine) + # Contains the distance between the marine and the closest zergling + + :param position: + """ + assert self, "Units object is empty" + if isinstance(position, Unit): + return min(self._bot_object._distance_squared_unit_to_unit(unit, position) for unit in self)**0.5 + return min(self._bot_object._distance_units_to_pos(self, position)) + + def furthest_distance_to(self, position: Union[Unit, Point2]) -> float: + """Returns the distance between the furthest unit from this group to the target unit + + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + furthest_zergling_distance = enemy_zerglings.furthest_distance_to(my_marine) + # Contains the distance between the marine and the furthest away zergling + + :param position: + """ + assert self, "Units object is empty" + if isinstance(position, Unit): + return max(self._bot_object._distance_squared_unit_to_unit(unit, position) for unit in self)**0.5 + return max(self._bot_object._distance_units_to_pos(self, position)) + + def closest_to(self, position: Union[Unit, Point2]) -> Unit: + """Returns the closest unit (from this Units object) to the target unit or position. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + closest_zergling = enemy_zerglings.closest_to(my_marine) + # Contains the zergling that is closest to the target marine + + :param position: + """ + assert self, "Units object is empty" + if isinstance(position, Unit): + return min( + (unit1 for unit1 in self), + key=lambda unit2: self._bot_object._distance_squared_unit_to_unit(unit2, position), + ) + + distances = self._bot_object._distance_units_to_pos(self, position) + return min(((unit, dist) for unit, dist in zip(self, distances)), key=lambda my_tuple: my_tuple[1])[0] + + def furthest_to(self, position: Union[Unit, Point2]) -> Unit: + """Returns the furhest unit (from this Units object) to the target unit or position. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + furthest_zergling = enemy_zerglings.furthest_to(my_marine) + # Contains the zergling that is furthest away to the target marine + + :param position: + """ + assert self, "Units object is empty" + if isinstance(position, Unit): + return max( + (unit1 for unit1 in self), + key=lambda unit2: self._bot_object._distance_squared_unit_to_unit(unit2, position), + ) + distances = self._bot_object._distance_units_to_pos(self, position) + return max(((unit, dist) for unit, dist in zip(self, distances)), key=lambda my_tuple: my_tuple[1])[0] + + def closer_than(self, distance: float, position: Union[Unit, Point2]) -> Units: + """Returns all units (from this Units object) that are closer than 'distance' away from target unit or position. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + close_zerglings = enemy_zerglings.closer_than(3, my_marine) + # Contains all zerglings that are distance 3 or less away from the marine (does not include unit radius in calculation) + + :param distance: + :param position: + """ + if not self: + return self + if isinstance(position, Unit): + distance_squared = distance**2 + return self.subgroup( + unit for unit in self + if self._bot_object._distance_squared_unit_to_unit(unit, position) < distance_squared + ) + distances = self._bot_object._distance_units_to_pos(self, position) + return self.subgroup(unit for unit, dist in zip(self, distances) if dist < distance) + + def further_than(self, distance: float, position: Union[Unit, Point2]) -> Units: + """Returns all units (from this Units object) that are further than 'distance' away from target unit or position. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + far_zerglings = enemy_zerglings.further_than(3, my_marine) + # Contains all zerglings that are distance 3 or more away from the marine (does not include unit radius in calculation) + + :param distance: + :param position: + """ + if not self: + return self + if isinstance(position, Unit): + distance_squared = distance**2 + return self.subgroup( + unit for unit in self + if distance_squared < self._bot_object._distance_squared_unit_to_unit(unit, position) + ) + distances = self._bot_object._distance_units_to_pos(self, position) + return self.subgroup(unit for unit, dist in zip(self, distances) if distance < dist) + + def in_distance_between( + self, position: Union[Unit, Point2, Tuple[float, float]], distance1: float, distance2: float + ) -> Units: + """Returns units that are further than distance1 and closer than distance2 to unit or position. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + zerglings_filtered = enemy_zerglings.in_distance_between(my_marine, 3, 5) + # Contains all zerglings that are between distance 3 and 5 away from the marine (does not include unit radius in calculation) + + :param position: + :param distance1: + :param distance2: + """ + if not self: + return self + if isinstance(position, Unit): + distance1_squared = distance1**2 + distance2_squared = distance2**2 + return self.subgroup( + unit for unit in self if + distance1_squared < self._bot_object._distance_squared_unit_to_unit(unit, position) < distance2_squared + ) + distances = self._bot_object._distance_units_to_pos(self, position) + return self.subgroup(unit for unit, dist in zip(self, distances) if distance1 < dist < distance2) + + def closest_n_units(self, position: Union[Unit, Point2], n: int) -> Units: + """Returns the n closest units in distance to position. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + zerglings_filtered = enemy_zerglings.closest_n_units(my_marine, 5) + # Contains 5 zerglings that are the closest to the marine + + :param position: + :param n: + """ + if not self: + return self + return self.subgroup(self._list_sorted_by_distance_to(position)[:n]) + + def furthest_n_units(self, position: Union[Unit, Point2], n: int) -> Units: + """Returns the n furhest units in distance to position. + + Example:: + + enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING) + my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None) + if my_marine: + zerglings_filtered = enemy_zerglings.furthest_n_units(my_marine, 5) + # Contains 5 zerglings that are the furthest to the marine + + :param position: + :param n: + """ + if not self: + return self + return self.subgroup(self._list_sorted_by_distance_to(position)[-n:]) + + def in_distance_of_group(self, other_units: Units, distance: float) -> Units: + """Returns units that are closer than distance from any unit in the other units object. + + :param other_units: + :param distance: + """ + assert other_units, "Other units object is empty" + # Return self because there are no enemies + if not self: + return self + distance_squared = distance**2 + if len(self) == 1: + if any( + self._bot_object._distance_squared_unit_to_unit(self[0], target) < distance_squared + for target in other_units + ): + return self + return self.subgroup([]) + + return self.subgroup( + self_unit for self_unit in self if any( + self._bot_object._distance_squared_unit_to_unit(self_unit, other_unit) < distance_squared + for other_unit in other_units + ) + ) + + def in_closest_distance_to_group(self, other_units: Units) -> Unit: + """Returns unit in shortest distance from any unit in self to any unit in group. + + Loops over all units in self, then loops over all units in other_units and calculates the shortest distance. Returns the units that is closest to any unit of 'other_units'. + + :param other_units: + """ + assert self, "Units object is empty" + assert other_units, "Given units object is empty" + return min( + self, + key=lambda self_unit: + min(self._bot_object._distance_squared_unit_to_unit(self_unit, other_unit) for other_unit in other_units), + ) + + def _list_sorted_closest_to_distance(self, position: Union[Unit, Point2], distance: float) -> List[Unit]: + """This function should be a bit faster than using units.sorted(key=lambda u: u.distance_to(position)) + + :param position: + :param distance: + """ + if isinstance(position, Unit): + return sorted( + self, + key=lambda unit: abs(self._bot_object._distance_squared_unit_to_unit(unit, position) - distance), + reverse=True, + ) + distances = self._bot_object._distance_units_to_pos(self, position) + unit_dist_dict = {unit.tag: dist for unit, dist in zip(self, distances)} + return sorted(self, key=lambda unit2: abs(unit_dist_dict[unit2.tag] - distance), reverse=True) + + def n_closest_to_distance(self, position: Point2, distance: float, n: int) -> Units: + """Returns n units that are the closest to distance away. + For example if the distance is set to 5 and you want 3 units, from units with distance [3, 4, 5, 6, 7] to position, + the units with distance [4, 5, 6] will be returned + + :param position: + :param distance: + """ + return self.subgroup(self._list_sorted_closest_to_distance(position=position, distance=distance)[:n]) + + def n_furthest_to_distance(self, position: Point2, distance: float, n: int) -> Units: + """Inverse of the function 'n_closest_to_distance', returns the furthest units instead + + :param position: + :param distance: + """ + return self.subgroup(self._list_sorted_closest_to_distance(position=position, distance=distance)[-n:]) + + def subgroup(self, units: Iterable[Unit]) -> Units: + """Creates a new mutable Units object from Units or list object. + + :param units: + """ + return Units(units, self._bot_object) + + def filter(self, pred: Callable[[Unit], Any]) -> Units: + """Filters the current Units object and returns a new Units object. + + Example:: + + from sc2.ids.unit_typeid import UnitTypeId + my_marines = self.units.filter(lambda unit: unit.type_id == UnitTypeId.MARINE) + + completed_structures = self.structures.filter(lambda structure: structure.is_ready) + + queens_with_energy_to_inject = self.units.filter(lambda unit: unit.type_id == UnitTypeId.QUEEN and unit.energy >= 25) + + orbitals_with_energy_to_mule = self.structures.filter(lambda structure: structure.type_id == UnitTypeId.ORBITALCOMMAND and structure.energy >= 50) + + my_units_that_can_shoot_up = self.units.filter(lambda unit: unit.can_attack_air) + + See more unit properties in unit.py + + :param pred: + """ + assert callable(pred), "Function is not callable" + return self.subgroup(filter(pred, self)) + + def sorted(self, key: Callable[[Unit], Any], reverse: bool = False) -> Units: + return self.subgroup(sorted(self, key=key, reverse=reverse)) + + def _list_sorted_by_distance_to(self, position: Union[Unit, Point2], reverse: bool = False) -> List[Unit]: + """This function should be a bit faster than using units.sorted(key=lambda u: u.distance_to(position)) + + :param position: + :param reverse: + """ + if isinstance(position, Unit): + return sorted( + self, key=lambda unit: self._bot_object._distance_squared_unit_to_unit(unit, position), reverse=reverse + ) + distances = self._bot_object._distance_units_to_pos(self, position) + unit_dist_dict = {unit.tag: dist for unit, dist in zip(self, distances)} + return sorted(self, key=lambda unit2: unit_dist_dict[unit2.tag], reverse=reverse) + + def sorted_by_distance_to(self, position: Union[Unit, Point2], reverse: bool = False) -> Units: + """This function should be a bit faster than using units.sorted(key=lambda u: u.distance_to(position)) + + :param position: + :param reverse: + """ + return self.subgroup(self._list_sorted_by_distance_to(position, reverse=reverse)) + + def tags_in(self, other: Iterable[int]) -> Units: + """Filters all units that have their tags in the 'other' set/list/dict + + Example:: + + my_inject_queens = self.units.tags_in(self.queen_tags_assigned_to_do_injects) + + # Do not use the following as it is slower because it first loops over all units to filter out if they are queens and loops over those again to check if their tags are in the list/set + my_inject_queens_slow = self.units(QUEEN).tags_in(self.queen_tags_assigned_to_do_injects) + + :param other: + """ + return self.filter(lambda unit: unit.tag in other) + + def tags_not_in(self, other: Iterable[int]) -> Units: + """Filters all units that have their tags not in the 'other' set/list/dict + + Example:: + + my_non_inject_queens = self.units.tags_not_in(self.queen_tags_assigned_to_do_injects) + + # Do not use the following as it is slower because it first loops over all units to filter out if they are queens and loops over those again to check if their tags are in the list/set + my_non_inject_queens_slow = self.units(QUEEN).tags_not_in(self.queen_tags_assigned_to_do_injects) + + :param other: + """ + return self.filter(lambda unit: unit.tag not in other) + + @property + def center(self) -> Point2: + """ Returns the central position of all units. """ + assert self, "Units object is empty" + return Point2( + ( + sum(unit._proto.pos.x for unit in self) / self.amount, + sum(unit._proto.pos.y for unit in self) / self.amount, + ) + ) + + @property + def selected(self) -> Units: + """ Returns all units that are selected by the human player. """ + return self.filter(lambda unit: unit.is_selected) + + @property + def tags(self) -> Set[int]: + """ Returns all unit tags as a set. """ + return {unit.tag for unit in self} + + @property + def ready(self) -> Units: + """ Returns all structures that are ready (construction complete). """ + return self.filter(lambda unit: unit.is_ready) + + @property + def not_ready(self) -> Units: + """ Returns all structures that are not ready (construction not complete). """ + return self.filter(lambda unit: not unit.is_ready) + + @property + def idle(self) -> Units: + """ Returns all units or structures that are doing nothing (unit is standing still, structure is doing nothing). """ + return self.filter(lambda unit: unit.is_idle) + + @property + def owned(self) -> Units: + """ Deprecated: All your units. """ + return self.filter(lambda unit: unit.is_mine) + + @property + def enemy(self) -> Units: + """ Deprecated: All enemy units.""" + return self.filter(lambda unit: unit.is_enemy) + + @property + def flying(self) -> Units: + """ Returns all units that are flying. """ + return self.filter(lambda unit: unit.is_flying) + + @property + def not_flying(self) -> Units: + """ Returns all units that not are flying. """ + return self.filter(lambda unit: not unit.is_flying) + + @property + def structure(self) -> Units: + """ Deprecated: All structures. """ + return self.filter(lambda unit: unit.is_structure) + + @property + def not_structure(self) -> Units: + """ Deprecated: All units that are not structures. """ + return self.filter(lambda unit: not unit.is_structure) + + @property + def gathering(self) -> Units: + """ Returns all workers that are mining minerals or vespene (gather command). """ + return self.filter(lambda unit: unit.is_gathering) + + @property + def returning(self) -> Units: + """ Returns all workers that are carrying minerals or vespene and are returning to a townhall. """ + return self.filter(lambda unit: unit.is_returning) + + @property + def collecting(self) -> Units: + """ Returns all workers that are mining or returning resources. """ + return self.filter(lambda unit: unit.is_collecting) + + @property + def visible(self) -> Units: + """Returns all units or structures that are visible. + TODO: add proper description on which units are exactly visible (not snapshots?)""" + return self.filter(lambda unit: unit.is_visible) + + @property + def mineral_field(self) -> Units: + """ Returns all units that are mineral fields. """ + return self.filter(lambda unit: unit.is_mineral_field) + + @property + def vespene_geyser(self) -> Units: + """ Returns all units that are vespene geysers. """ + return self.filter(lambda unit: unit.is_vespene_geyser) + + @property + def prefer_idle(self) -> Units: + """ Sorts units based on if they are idle. Idle units come first. """ + return self.sorted(lambda unit: unit.is_idle, reverse=True) diff --git a/worlds/_sc2common/bot/versions.py b/worlds/_sc2common/bot/versions.py new file mode 100644 index 000000000000..0ce923295c54 --- /dev/null +++ b/worlds/_sc2common/bot/versions.py @@ -0,0 +1,472 @@ +VERSIONS = [ + { + "base-version": 52910, + "data-hash": "8D9FEF2E1CF7C6C9CBE4FBCA830DDE1C", + "fixed-hash": "009BC85EF547B51EBF461C83A9CBAB30", + "label": "3.13", + "replay-hash": "47BFE9D10F26B0A8B74C637D6327BF3C", + "version": 52910 + }, { + "base-version": 53644, + "data-hash": "CA275C4D6E213ED30F80BACCDFEDB1F5", + "fixed-hash": "29198786619C9011735BCFD378E49CB6", + "label": "3.14", + "replay-hash": "5AF236FC012ADB7289DB493E63F73FD5", + "version": 53644 + }, { + "base-version": 54518, + "data-hash": "BBF619CCDCC80905350F34C2AF0AB4F6", + "fixed-hash": "D5963F25A17D9E1EA406FF6BBAA9B736", + "label": "3.15", + "replay-hash": "43530321CF29FD11482AB9CBA3EB553D", + "version": 54518 + }, { + "base-version": 54518, + "data-hash": "6EB25E687F8637457538F4B005950A5E", + "fixed-hash": "D5963F25A17D9E1EA406FF6BBAA9B736", + "label": "3.15.1", + "replay-hash": "43530321CF29FD11482AB9CBA3EB553D", + "version": 54724 + }, { + "base-version": 55505, + "data-hash": "60718A7CA50D0DF42987A30CF87BCB80", + "fixed-hash": "0189B2804E2F6BA4C4591222089E63B2", + "label": "3.16", + "replay-hash": "B11811B13F0C85C29C5D4597BD4BA5A4", + "version": 55505 + }, { + "base-version": 55958, + "data-hash": "5BD7C31B44525DAB46E64C4602A81DC2", + "fixed-hash": "717B05ACD26C108D18A219B03710D06D", + "label": "3.16.1", + "replay-hash": "21C8FA403BB1194E2B6EB7520016B958", + "version": 55958 + }, { + "base-version": 56787, + "data-hash": "DFD1F6607F2CF19CB4E1C996B2563D9B", + "fixed-hash": "4E1C17AB6A79185A0D87F68D1C673CD9", + "label": "3.17", + "replay-hash": "D0296961C9EA1356F727A2468967A1E2", + "version": 56787 + }, { + "base-version": 56787, + "data-hash": "3F2FCED08798D83B873B5543BEFA6C4B", + "fixed-hash": "4474B6B7B0D1423DAA76B9623EF2E9A9", + "label": "3.17.1", + "replay-hash": "D0296961C9EA1356F727A2468967A1E2", + "version": 57218 + }, { + "base-version": 56787, + "data-hash": "C690FC543082D35EA0AAA876B8362BEA", + "fixed-hash": "4474B6B7B0D1423DAA76B9623EF2E9A9", + "label": "3.17.2", + "replay-hash": "D0296961C9EA1356F727A2468967A1E2", + "version": 57490 + }, { + "base-version": 57507, + "data-hash": "1659EF34997DA3470FF84A14431E3A86", + "fixed-hash": "95666060F129FD267C5A8135A8920AA2", + "label": "3.18", + "replay-hash": "06D650F850FDB2A09E4B01D2DF8C433A", + "version": 57507 + }, { + "base-version": 58400, + "data-hash": "2B06AEE58017A7DF2A3D452D733F1019", + "fixed-hash": "2CFE1B8757DA80086DD6FD6ECFF21AC6", + "label": "3.19", + "replay-hash": "227B6048D55535E0FF5607746EBCC45E", + "version": 58400 + }, { + "base-version": 58400, + "data-hash": "D9B568472880CC4719D1B698C0D86984", + "fixed-hash": "CE1005E9B145BDFC8E5E40CDEB5E33BB", + "label": "3.19.1", + "replay-hash": "227B6048D55535E0FF5607746EBCC45E", + "version": 58600 + }, { + "base-version": 59587, + "data-hash": "9B4FD995C61664831192B7DA46F8C1A1", + "fixed-hash": "D5D5798A9CCD099932C8F855C8129A7C", + "label": "4.0", + "replay-hash": "BB4DA41B57D490BD13C13A594E314BA4", + "version": 59587 + }, { + "base-version": 60196, + "data-hash": "1B8ACAB0C663D5510941A9871B3E9FBE", + "fixed-hash": "9327F9AF76CF11FC43D20E3E038B1B7A", + "label": "4.1", + "replay-hash": "AEA0C2A9D56E02C6B7D21E889D6B9B2F", + "version": 60196 + }, { + "base-version": 60321, + "data-hash": "5C021D8A549F4A776EE9E9C1748FFBBC", + "fixed-hash": "C53FA3A7336EDF320DCEB0BC078AEB0A", + "label": "4.1.1", + "replay-hash": "8EE054A8D98C7B0207E709190A6F3953", + "version": 60321 + }, { + "base-version": 60321, + "data-hash": "33D9FE28909573253B7FC352CE7AEA40", + "fixed-hash": "FEE6F86A211380DF509F3BBA58A76B87", + "label": "4.1.2", + "replay-hash": "8EE054A8D98C7B0207E709190A6F3953", + "version": 60604 + }, { + "base-version": 60321, + "data-hash": "F486693E00B2CD305B39E0AB254623EB", + "fixed-hash": "AF7F5499862F497C7154CB59167FEFB3", + "label": "4.1.3", + "replay-hash": "8EE054A8D98C7B0207E709190A6F3953", + "version": 61021 + }, { + "base-version": 60321, + "data-hash": "2E2A3F6E0BAFE5AC659C4D39F13A938C", + "fixed-hash": "F9A68CF1FBBF867216FFECD9EAB72F4A", + "label": "4.1.4", + "replay-hash": "8EE054A8D98C7B0207E709190A6F3953", + "version": 61545 + }, { + "base-version": 62347, + "data-hash": "C0C0E9D37FCDBC437CE386C6BE2D1F93", + "fixed-hash": "A5C4BE991F37F1565097AAD2A707FC4C", + "label": "4.2", + "replay-hash": "2167A7733637F3AFC49B210D165219A7", + "version": 62347 + }, { + "base-version": 62848, + "data-hash": "29BBAC5AFF364B6101B661DB468E3A37", + "fixed-hash": "ABAF9318FE79E84485BEC5D79C31262C", + "label": "4.2.1", + "replay-hash": "A7ACEC5759ADB459A5CEC30A575830EC", + "version": 62848 + }, { + "base-version": 63454, + "data-hash": "3CB54C86777E78557C984AB1CF3494A0", + "fixed-hash": "A9DCDAA97F7DA07F6EF29C0BF4DFC50D", + "label": "4.2.2", + "replay-hash": "A7ACEC5759ADB459A5CEC30A575830EC", + "version": 63454 + }, { + "base-version": 64469, + "data-hash": "C92B3E9683D5A59E08FC011F4BE167FF", + "fixed-hash": "DDF3E0A6C00DC667F59BF90F793C71B8", + "label": "4.3", + "replay-hash": "6E80072968515101AF08D3953FE3EEBA", + "version": 64469 + }, { + "base-version": 65094, + "data-hash": "E5A21037AA7A25C03AC441515F4E0644", + "fixed-hash": "09EF8E9B96F14C5126F1DB5378D15F3A", + "label": "4.3.1", + "replay-hash": "DD9B57C516023B58F5B588377880D93A", + "version": 65094 + }, { + "base-version": 65384, + "data-hash": "B6D73C85DFB70F5D01DEABB2517BF11C", + "fixed-hash": "615C1705E4C7A5FD8690B3FD376C1AFE", + "label": "4.3.2", + "replay-hash": "DD9B57C516023B58F5B588377880D93A", + "version": 65384 + }, { + "base-version": 65895, + "data-hash": "BF41339C22AE2EDEBEEADC8C75028F7D", + "fixed-hash": "C622989A4C0AF7ED5715D472C953830B", + "label": "4.4", + "replay-hash": "441BBF1A222D5C0117E85B118706037F", + "version": 65895 + }, { + "base-version": 66668, + "data-hash": "C094081D274A39219061182DBFD7840F", + "fixed-hash": "1C236A42171AAC6DD1D5E50D779C522D", + "label": "4.4.1", + "replay-hash": "21D5B4B4D5175C562CF4C4A803C995C6", + "version": 66668 + }, { + "base-version": 67188, + "data-hash": "2ACF84A7ECBB536F51FC3F734EC3019F", + "fixed-hash": "2F0094C990E0D4E505570195F96C2A0C", + "label": "4.5", + "replay-hash": "E9873B3A3846F5878CEE0D1E2ADD204A", + "version": 67188 + }, { + "base-version": 67188, + "data-hash": "6D239173B8712461E6A7C644A5539369", + "fixed-hash": "A1BC35751ACC34CF887321A357B40158", + "label": "4.5.1", + "replay-hash": "E9873B3A3846F5878CEE0D1E2ADD204A", + "version": 67344 + }, { + "base-version": 67926, + "data-hash": "7DE59231CBF06F1ECE9A25A27964D4AE", + "fixed-hash": "570BEB69151F40D010E89DE1825AE680", + "label": "4.6", + "replay-hash": "DA662F9091DF6590A5E323C21127BA5A", + "version": 67926 + }, { + "base-version": 67926, + "data-hash": "BEA99B4A8E7B41E62ADC06D194801BAB", + "fixed-hash": "309E45F53690F8D1108F073ABB4D4734", + "label": "4.6.1", + "replay-hash": "DA662F9091DF6590A5E323C21127BA5A", + "version": 68195 + }, { + "base-version": 69232, + "data-hash": "B3E14058F1083913B80C20993AC965DB", + "fixed-hash": "21935E776237EF12B6CC73E387E76D6E", + "label": "4.6.2", + "replay-hash": "A230717B315D83ACC3697B6EC28C3FF6", + "version": 69232 + }, { + "base-version": 70154, + "data-hash": "8E216E34BC61ABDE16A59A672ACB0F3B", + "fixed-hash": "09CD819C667C67399F5131185334243E", + "label": "4.7", + "replay-hash": "9692B04D6E695EF08A2FB920979E776C", + "version": 70154 + }, { + "base-version": 70154, + "data-hash": "94596A85191583AD2EBFAE28C5D532DB", + "fixed-hash": "0AE50F82AC1A7C0DCB6A290D7FBA45DB", + "label": "4.7.1", + "replay-hash": "D74FBB3CB0897A3EE8F44E78119C4658", + "version": 70326 + }, { + "base-version": 71061, + "data-hash": "760581629FC458A1937A05ED8388725B", + "fixed-hash": "815C099DF1A17577FDC186FDB1381B16", + "label": "4.8", + "replay-hash": "BD692311442926E1F0B7C17E9ABDA34B", + "version": 71061 + }, { + "base-version": 71523, + "data-hash": "FCAF3F050B7C0CC7ADCF551B61B9B91E", + "fixed-hash": "4593CC331691620509983E92180A309A", + "label": "4.8.1", + "replay-hash": "BD692311442926E1F0B7C17E9ABDA34B", + "version": 71523 + }, { + "base-version": 71663, + "data-hash": "FE90C92716FC6F8F04B74268EC369FA5", + "fixed-hash": "1DBF3819F3A7367592648632CC0D5BFD", + "label": "4.8.2", + "replay-hash": "E43A9885B3EFAE3D623091485ECCCB6C", + "version": 71663 + }, { + "base-version": 72282, + "data-hash": "0F14399BBD0BA528355FF4A8211F845B", + "fixed-hash": "E9958B2CB666DCFE101D23AF87DB8140", + "label": "4.8.3", + "replay-hash": "3AF3657F55AB961477CE268F5CA33361", + "version": 72282 + }, { + "base-version": 73286, + "data-hash": "CD040C0675FD986ED37A4CA3C88C8EB5", + "fixed-hash": "62A146F7A0D19A8DD05BF011631B31B8", + "label": "4.8.4", + "replay-hash": "EE3A89F443BE868EBDA33A17C002B609", + "version": 73286 + }, { + "base-version": 73559, + "data-hash": "B2465E73AED597C74D0844112D582595", + "fixed-hash": "EF0A43C33413613BC7343B86C0A7CC92", + "label": "4.8.5", + "replay-hash": "147388D35E76861BD4F590F8CC5B7B0B", + "version": 73559 + }, { + "base-version": 73620, + "data-hash": "AA18FEAD6573C79EF707DF44ABF1BE61", + "fixed-hash": "4D76491CCAE756F0498D1C5B2973FF9C", + "label": "4.8.6", + "replay-hash": "147388D35E76861BD4F590F8CC5B7B0B", + "version": 73620 + }, { + "base-version": 74071, + "data-hash": "70C74A2DCA8A0D8E7AE8647CAC68ACCA", + "fixed-hash": "C4A3F01B4753245296DC94BC1B5E9B36", + "label": "4.9", + "replay-hash": "19D15E5391FACB379BFCA262CA8FD208", + "version": 74071 + }, { + "base-version": 74456, + "data-hash": "218CB2271D4E2FA083470D30B1A05F02", + "fixed-hash": "E82051387C591CAB1212B64073759826", + "label": "4.9.1", + "replay-hash": "1586ADF060C26219FF3404673D70245B", + "version": 74456 + }, { + "base-version": 74741, + "data-hash": "614480EF79264B5BD084E57F912172FF", + "fixed-hash": "500CC375B7031C8272546B78E9BE439F", + "label": "4.9.2", + "replay-hash": "A7FAC56F940382E05157EAB19C932E3A", + "version": 74741 + }, { + "base-version": 75025, + "data-hash": "C305368C63621480462F8F516FB64374", + "fixed-hash": "DEE7842C8BCB6874EC254AA3D45365F7", + "label": "4.9.3", + "replay-hash": "A7FAC56F940382E05157EAB19C932E3A", + "version": 75025 + }, { + "base-version": 75689, + "data-hash": "B89B5D6FA7CBF6452E721311BFBC6CB2", + "fixed-hash": "2B2097DC4AD60A2D1E1F38691A1FF111", + "label": "4.10", + "replay-hash": "6A60E59031A7DB1B272EE87E51E4C7CD", + "version": 75689 + }, { + "base-version": 75800, + "data-hash": "DDFFF9EC4A171459A4F371C6CC189554", + "fixed-hash": "1FB8FAF4A87940621B34F0B8F6FDDEA6", + "label": "4.10.1", + "replay-hash": "6A60E59031A7DB1B272EE87E51E4C7CD", + "version": 75800 + }, { + "base-version": 76052, + "data-hash": "D0F1A68AA88BA90369A84CD1439AA1C3", + "fixed-hash": "", + "label": "4.10.2", + "replay-hash": "", + "version": 76052 + }, { + "base-version": 76114, + "data-hash": "CDB276D311F707C29BA664B7754A7293", + "fixed-hash": "", + "label": "4.10.3", + "replay-hash": "", + "version": 76114 + }, { + "base-version": 76811, + "data-hash": "FF9FA4EACEC5F06DEB27BD297D73ED67", + "fixed-hash": "", + "label": "4.10.4", + "replay-hash": "", + "version": 76811 + }, { + "base-version": 77379, + "data-hash": "70E774E722A58287EF37D487605CD384", + "fixed-hash": "", + "label": "4.11.0", + "replay-hash": "", + "version": 77379 + }, { + "base-version": 77379, + "data-hash": "F92D1127A291722120AC816F09B2E583", + "fixed-hash": "", + "label": "4.11.1", + "replay-hash": "", + "version": 77474 + }, { + "base-version": 77535, + "data-hash": "FC43E0897FCC93E4632AC57CBC5A2137", + "fixed-hash": "", + "label": "4.11.2", + "replay-hash": "", + "version": 77535 + }, { + "base-version": 77661, + "data-hash": "A15B8E4247434B020086354F39856C51", + "fixed-hash": "", + "label": "4.11.3", + "replay-hash": "", + "version": 77661 + }, { + "base-version": 78285, + "data-hash": "69493AFAB5C7B45DDB2F3442FD60F0CF", + "fixed-hash": "21D2EBD5C79DECB3642214BAD4A7EF56", + "label": "4.11.4", + "replay-hash": "CAB5C056EDBDA415C552074BF363CC85", + "version": 78285 + }, { + "base-version": 79998, + "data-hash": "B47567DEE5DC23373BFF57194538DFD3", + "fixed-hash": "0A698A1B072BC4B087F44DDEF0BE361E", + "label": "4.12.0", + "replay-hash": "9E15AA09E15FE3AF3655126CEEC7FF42", + "version": 79998 + }, { + "base-version": 80188, + "data-hash": "44DED5AED024D23177C742FC227C615A", + "fixed-hash": "0A698A1B072BC4B087F44DDEF0BE361E", + "label": "4.12.1", + "replay-hash": "9E15AA09E15FE3AF3655126CEEC7FF42", + "version": 80188 + }, { + "base-version": 80949, + "data-hash": "9AE39C332883B8BF6AA190286183ED72", + "fixed-hash": "DACEAFAB8B983C08ACD31ABC085A0052", + "label": "5.0.0", + "replay-hash": "28C41277C5837AABF9838B64ACC6BDCF", + "version": 80949 + }, { + "base-version": 81009, + "data-hash": "0D28678BC32E7F67A238F19CD3E0A2CE", + "fixed-hash": "DACEAFAB8B983C08ACD31ABC085A0052", + "label": "5.0.1", + "replay-hash": "28C41277C5837AABF9838B64ACC6BDCF", + "version": 81009 + }, { + "base-version": 81102, + "data-hash": "DC0A1182FB4ABBE8E29E3EC13CF46F68", + "fixed-hash": "0C193BD5F63BBAB79D798278F8B2548E", + "label": "5.0.2", + "replay-hash": "08BB9D4CAE25B57160A6E4AD7B8E1A5A", + "version": 81102 + }, { + "base-version": 81433, + "data-hash": "5FD8D4B6B52723B44862DF29F232CF31", + "fixed-hash": "4FC35CEA63509AB06AA80AACC1B3B700", + "label": "5.0.3", + "replay-hash": "0920F1BD722655B41DA096B98CC0912D", + "version": 81433 + }, { + "base-version": 82457, + "data-hash": "D2707E265785612D12B381AF6ED9DBF4", + "fixed-hash": "ED05F0DB335D003FBC3C7DEF69911114", + "label": "5.0.4", + "replay-hash": "7D9EE968AAD81761334BD9076BFD9EFF", + "version": 82457 + }, { + "base-version": 82893, + "data-hash": "D795328C01B8A711947CC62AA9750445", + "fixed-hash": "ED05F0DB335D003FBC3C7DEF69911114", + "label": "5.0.5", + "replay-hash": "7D9EE968AAD81761334BD9076BFD9EFF", + "version": 82893 + }, { + "base-version": 83830, + "data-hash": "B4745D6A4F982A3143C183D8ACB6C3E3", + "fixed-hash": "ed05f0db335d003fbc3c7def69911114", + "label": "5.0.6", + "replay-hash": "7D9EE968AAD81761334BD9076BFD9EFF", + "version": 83830 + }, { + "base-version": 84643, + "data-hash": "A389D1F7DF9DD792FBE980533B7119FF", + "fixed-hash": "368DE29820A74F5BE747543AC02DB3F8", + "label": "5.0.7", + "replay-hash": "7D9EE968AAD81761334BD9076BFD9EFF", + "version": 84643 + }, { + "base-version": 86383, + "data-hash": "22EAC562CD0C6A31FB2C2C21E3AA3680", + "fixed-hash": "B19F4D8B87A2835F9447CA17EDD40C1E", + "label": "5.0.8", + "replay-hash": "7D9EE968AAD81761334BD9076BFD9EFF", + "version": 86383 + }, { + "base-version": 87702, + "data-hash": "F799E093428D419FD634CCE9B925218C", + "fixed-hash": "B19F4D8B87A2835F9447CA17EDD40C1E", + "label": "5.0.9", + "replay-hash": "7D9EE968AAD81761334BD9076BFD9EFF", + "version": 87702 + }, { + "base-version": 88500, + "data-hash": "F38043A301B034A78AD13F558257DCF8", + "fixed-hash": "F3853B6E3B6013415CAC30EF3B27564B", + "label": "5.0.10", + "replay-hash": "A79CD3B6C6DADB0ECAEFA06E6D18E47B", + "version": 88500 + } +] diff --git a/worlds/_sc2common/bot/wsl.py b/worlds/_sc2common/bot/wsl.py new file mode 100644 index 000000000000..08c453a02b62 --- /dev/null +++ b/worlds/_sc2common/bot/wsl.py @@ -0,0 +1,117 @@ +# pylint: disable=R0911,W1510 +import os +import re +import subprocess +from pathlib import Path, PureWindowsPath + +from worlds._sc2common.bot import logger + +## This file is used for compatibility with WSL and shouldn't need to be +## accessed directly by any bot clients + + +def win_path_to_wsl_path(path): + """Convert a path like C:\\foo to /mnt/c/foo""" + return Path("/mnt") / PureWindowsPath(re.sub("^([A-Z]):", lambda m: m.group(1).lower(), path)) + + +def wsl_path_to_win_path(path): + """Convert a path like /mnt/c/foo to C:\\foo""" + return PureWindowsPath(re.sub("^/mnt/([a-z])", lambda m: m.group(1).upper() + ":", path)) + + +def get_wsl_home(): + """Get home directory of from Windows, even if run in WSL""" + proc = subprocess.run(["powershell.exe", "-Command", "Write-Host -NoNewLine $HOME"], capture_output=True) + + if proc.returncode != 0: + return None + + return win_path_to_wsl_path(proc.stdout.decode("utf-8")) + + +RUN_SCRIPT = """$proc = Start-Process -NoNewWindow -PassThru "%s" "%s" +if ($proc) { + Write-Host $proc.id + exit $proc.ExitCode +} else { + exit 1 +}""" + + +def run(popen_args, sc2_cwd): + """Run SC2 in Windows and get the pid so that it can be killed later.""" + path = wsl_path_to_win_path(popen_args[0]) + args = " ".join(popen_args[1:]) + + return subprocess.Popen( + ["powershell.exe", "-Command", RUN_SCRIPT % (path, args)], + cwd=sc2_cwd, + stdout=subprocess.PIPE, + universal_newlines=True, + bufsize=1, + ) + + +def kill(wsl_process): + """Needed to kill a process started with WSL. Returns true if killed successfully.""" + # HACK: subprocess and WSL1 appear to have a nasty interaction where + # any streams are never closed and the process is never considered killed, + # despite having an exit code (this works on WSL2 as well, but isn't + # necessary). As a result, + # 1: We need to read using readline (to make sure we block long enough to + # get the exit code in the rare case where the user immediately hits ^C) + out = wsl_process.stdout.readline().rstrip() + # 2: We need to use __exit__, since kill() calls send_signal(), which thinks + # the process has already exited! + wsl_process.__exit__(None, None, None) + proc = subprocess.run(["taskkill.exe", "-f", "-pid", out], capture_output=True) + return proc.returncode == 0 # Returns 128 on failure + + +def detect(): + """Detect the current running version of WSL, and bail out if it doesn't exist""" + # Allow disabling WSL detection with an environment variable + if os.getenv("SC2_WSL_DETECT", "1") == "0": + return None + + wsl_name = os.environ.get("WSL_DISTRO_NAME") + if not wsl_name: + return None + + try: + wsl_proc = subprocess.run(["wsl.exe", "--list", "--running", "--verbose"], capture_output=True) + except (OSError, ValueError): + return None + if wsl_proc.returncode != 0: + return None + + # WSL.exe returns a bunch of null characters for some reason, as well as + # windows-style linebreaks. It's inconsistent about how many \rs it uses + # and this could change in the future, so strip out all junk and split by + # Unix-style newlines for safety's sake. + lines = re.sub(r"\000|\r", "", wsl_proc.stdout.decode("utf-8")).split("\n") + + def line_has_proc(ln): + return re.search("^\\s*[*]?\\s+" + wsl_name, ln) + + def line_version(ln): + return re.sub("^.*\\s+(\\d+)\\s*$", "\\1", ln) + + versions = [line_version(ln) for ln in lines if line_has_proc(ln)] + + try: + version = versions[0] + if int(version) not in [1, 2]: + return None + except (ValueError, IndexError): + return None + + logger.info(f"WSL version {version} detected") + + if version == "2" and not (os.environ.get("SC2CLIENTHOST") and os.environ.get("SC2SERVERHOST")): + logger.warning("You appear to be running WSL2 without your hosts configured correctly.") + logger.warning("This may result in SC2 staying on a black screen and not connecting to your bot.") + logger.warning("Please see the python-sc2 README for WSL2 configuration instructions.") + + return "WSL" + version diff --git a/worlds/_sc2common/requirements.txt b/worlds/_sc2common/requirements.txt new file mode 100644 index 000000000000..2910b68c625c --- /dev/null +++ b/worlds/_sc2common/requirements.txt @@ -0,0 +1,6 @@ +s2clientprotocol>=5.0.11.90136.0 +mpyq>=0.2.5 +portpicker>=1.5.2 +aiohttp>=3.8.4 +loguru>=0.7.0 +protobuf==3.20.3 \ No newline at end of file diff --git a/worlds/sc2wol/requirements.txt b/worlds/sc2wol/requirements.txt index 11e836d302f4..9b84863c4590 100644 --- a/worlds/sc2wol/requirements.txt +++ b/worlds/sc2wol/requirements.txt @@ -1,3 +1,2 @@ nest-asyncio >= 1.5.5 -six >= 1.16.0 -apsc2 >= 5.6 \ No newline at end of file +six >= 1.16.0 \ No newline at end of file