Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix agent-mission combination #1868

Merged
merged 10 commits into from
Feb 21, 2023
53 changes: 53 additions & 0 deletions smarts/core/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# THE SOFTWARE.
import math
from dataclasses import dataclass
from itertools import chain, permutations, product, repeat
from math import factorial
from typing import Callable, List, Sequence, Tuple, Union

Expand Down Expand Up @@ -584,3 +585,55 @@ def running_mean(prev_mean: float, prev_step: int, new_val: float) -> Tuple[floa
new_step = prev_step + 1
new_mean = prev_mean + (new_val - prev_mean) / new_step
return new_mean, new_step


def _unique_element_combination(first, second):
"""Generates a combination of set of values result groupings only contain values from unique
indices. Only works if len(first_group) <= len(second_group).
_unique_perm('ab', '123') -> a1 b2 a1 b3 a2 b1 a2 b3 a3 b1 a3 b2
_unique_perm('abc', '1') -> []

Args:
first (Sequence): A sequence of values.
second (Sequence): Another sequence of values.

Yields:
Tuple[Tuple[*Any]]: A set of values similar to permutation.
"""
result = [[a] for a in first]
sl = len(second)
rl = len(first)
perms = list(permutations(range(sl), r=rl))
for i, p in enumerate(perms):
yield tuple(
tuple(result[(i * rl + j) % rl] + [second[idx]]) for j, idx in enumerate(p)
)


def ordered_combinations(first_group, second_group, default=None):
"""Generate a product that generates that sets the values. If
len(first_group) <= len(second_group) the value is padded.

padded_product('ab', '123') -> a1 b2 a1 b3 a2 b1 a2 b3 a3 b1 a3 b2
padded_product('ab', '1', default="k") -> a1 bk ak b1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment might be a bit misleading

  • Should it be called ordered_combinations instead of padded_product?
  • From the comment it looks like it returns a linear sequence rather than a sequence of products

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have clarified the utility and return values.

It should return groups of pairs, where the pairs use only unique indices from the contributor sequences, within each resulting group.


Args:
first_group (Sequence): A sequence of values.
second_group (Sequence): Another sequence of values.
default (Any, optional): The default values. Defaults to None.

Returns:
Generator[Tuple[Tuple[Any, ...]], None, None]: Some permutation values.
"""
len_first = len(first_group)
len_second = len(second_group)
if len_second == 0:
return product(first_group, [default])
if len_first > len_second:
return _unique_element_combination(
first_group,
list(chain(second_group, repeat(default, len_first - len_second))),
)
if len_first <= len_second:
return _unique_element_combination(first_group, second_group)
return []
28 changes: 27 additions & 1 deletion smarts/core/utils/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
# THE SOFTWARE.
import numpy as np

from smarts.core.utils.math import position_to_ego_frame, world_position_from_ego_frame
from smarts.core.utils.math import (
ordered_combinations,
position_to_ego_frame,
world_position_from_ego_frame,
)


def test_egocentric_conversion():
Expand All @@ -36,3 +40,25 @@ def test_egocentric_conversion():
p_end = world_position_from_ego_frame(pec, pe, he)

assert np.allclose(p_end, p_start)


def test_ordered_combinations():

assert not tuple(ordered_combinations("", ""))
assert tuple(ordered_combinations("a", "")) == (("a", None),)
assert tuple(ordered_combinations("abc", "12", default=1)) == (
(("a", "1"), ("b", "2"), ("c", 1)),
(("a", "1"), ("b", 1), ("c", "2")),
(("a", "2"), ("b", "1"), ("c", 1)),
(("a", "2"), ("b", 1), ("c", "1")),
(("a", 1), ("b", "1"), ("c", "2")),
(("a", 1), ("b", "2"), ("c", "1")),
)
assert tuple(ordered_combinations("ab", "123")) == (
(("a", "1"), ("b", "2")),
(("a", "1"), ("b", "3")),
(("a", "2"), ("b", "1")),
(("a", "2"), ("b", "3")),
(("a", "3"), ("b", "1")),
(("a", "3"), ("b", "2")),
)