Skip to content

Commit

Permalink
Fix and add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
BurnySc2 committed Jun 8, 2019
1 parent fa5aee5 commit 9caa515
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions sc2/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import math
import numpy as np

from typing import List, Dict, Tuple, Iterable, Generator
from typing import Dict, Tuple, Iterable, Generator


class DistanceCalculation:
Expand All @@ -20,7 +20,7 @@ def __init__(self):
self._generated_frame = -100
self._generated_frame2 = -100
# A Dictionary with a dict positions: index of the pdist condensed matrix
self._cached_unit_index_dict: Dict[Tuple[float, float], int] = None
self._cached_unit_index_dict: Dict[int, int] = None
# Pdist condensed vector generated by scipy pdist, half the size of the cdist matrix as 1d array
self._cached_pdist: np.ndarray = None

Expand All @@ -29,30 +29,30 @@ def _units_count(self) -> int:
return len(self.all_units)

@property
def _unit_index_dict(self) -> Dict[Tuple[float, float], int]:
def _unit_index_dict(self) -> Dict[int, int]:
""" As property, so it will be recalculated each time it is called, or return from cache if it is called multiple times in teh same game_loop. """
if self._generated_frame != self.state.game_loop:
return self.generate_unit_indices()
return self._cached_unit_index_dict

@property
def _pdist(self):
def _pdist(self) -> np.ndarray:
""" As property, so it will be recalculated each time it is called, or return from cache if it is called multiple times in teh same game_loop. """
if self._generated_frame2 != self.state.game_loop:
return self.calculate_distances()
return self._cached_pdist

def generate_unit_indices(self):
def generate_unit_indices(self) -> Dict[int, int]:
if self._generated_frame != self.state.game_loop:
self._cached_unit_index_dict = {unit.tag: index for index, unit in enumerate(self.all_units)}
self._generated_frame = self.state.game_loop
return self._cached_unit_index_dict

def calculate_distances(self):
def calculate_distances(self) -> np.ndarray:
if self._generated_frame2 != self.state.game_loop:
# Converts tuple [(1, 2), (3, 4)] to flat list like [1, 2, 3, 4]
flat_positions = (coord for unit in self.all_units for coord in unit.position_tuple)
# Converts to numpy array, then converts the flat array back to [[1, 2], [3, 4]]
# Converts to numpy array, then converts the flat array back to shape (n, 2): [[1, 2], [3, 4]]
positions_array: np.ndarray = np.fromiter(
flat_positions, dtype=np.float, count=2 * self._units_count
).reshape((self._units_count, 2))
Expand Down Expand Up @@ -84,7 +84,7 @@ def calculate_distances(self):

return self._cached_pdist

def _get_index_of_two_units(self, unit1: Unit, unit2: Unit):
def _get_index_of_two_units(self, unit1: Unit, unit2: Unit) -> int:
assert unit1.tag in self._unit_index_dict, f"Unit1 {unit1} is not in index dict"
assert unit2.tag in self._unit_index_dict, f"Unit2 {unit2} is not in index dict"
index1 = self._unit_index_dict[unit1.tag]
Expand All @@ -94,15 +94,15 @@ def _get_index_of_two_units(self, unit1: Unit, unit2: Unit):

# Helper functions

def square_to_condensed(self, i, j):
def square_to_condensed(self, i, j) -> int:
# Converts indices of a square matrix to condensed matrix
# https://stackoverflow.com/a/36867493/10882657
assert i != j, "No diagonal elements in condensed matrix! Diagonal elements are zero"
if i < j:
i, j = j, i
return self._units_count * j - j * (j + 1) // 2 + i - 1 - j

def convert_tuple_to_numpy_array(self, pos: Tuple[float, float]):
def convert_tuple_to_numpy_array(self, pos: Tuple[float, float]) -> np.ndarray:
""" Converts a single position to a 2d numpy array with 1 row and 2 columns. """
return np.fromiter(pos, dtype=float, count=2).reshape((1, 2))

Expand All @@ -113,10 +113,10 @@ def distance_math_hypot(self, p1: Tuple[float, float], p2: Tuple[float, float]):

# Distance calculation using the pre-calculated matrix above

def _distance_squared_unit_to_unit(self, unit1: Unit, unit2: Unit):
def _distance_squared_unit_to_unit(self, unit1: Unit, unit2: Unit) -> float:
if unit1.tag == unit2.tag:
return 0
# Calculate dict and distances and cache them
# Calculate index dict and distances and cache them
self._unit_index_dict
self._pdist
# Calculate index, needs to be after pdist has been calculated and cached
Expand All @@ -129,7 +129,7 @@ def _distance_squared_unit_to_unit(self, unit1: Unit, unit2: Unit):

# Distance calculation using the fastest distance calculation functions

def _distance_pos_to_pos(self, pos1: Tuple[float, float], pos2: Tuple[float, float]):
def _distance_pos_to_pos(self, pos1: Tuple[float, float], pos2: Tuple[float, float]) -> float:
return self.distance_math_hypot(pos1, pos2)

def _distance_units_to_pos(self, units: Units, pos: Tuple[float, float]) -> Generator[float, None, None]:
Expand Down

0 comments on commit 9caa515

Please sign in to comment.