Skip to content

Commit

Permalink
Support default_variations in combined variations
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunsuresh committed Jan 28, 2024
1 parent 2dbf70a commit 83dd52b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 20 deletions.
75 changes: 57 additions & 18 deletions cm-mlops/automation/script/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,6 +1704,14 @@ def _update_variation_tags_from_variations(self, variation_tags, variations, var
v_static = self._get_name_for_dynamic_variation_tag(v)
tmp_variation_tags_static[v_i] = v_static

combined_variations = [ t for t in variations if ',' in t ]
# We support default_variations in the meta of cmbined_variations
combined_variations.sort(key=lambda x: x.count(','))
''' By sorting based on the number of variations users can safely override
env and state in a larger combined variation
'''
tmp_combined_variations = {k: False for k in combined_variations}

# Recursively add any base variations specified
if len(variation_tags) > 0:
tmp_variations = {k: False for k in variation_tags}
Expand Down Expand Up @@ -1748,28 +1756,34 @@ def _update_variation_tags_from_variations(self, variation_tags, variations, var
tag_to_append = None

# default_variations dictionary specifies the default_variation for each variation group. A default variation in a group is turned on if no other variation from that group is turned on and it is not excluded using the '-' prefix
if "default_variations" in variations[variation_name]:
default_base_variations = variations[variation_name]["default_variations"]
for default_base_variation in default_base_variations:
tag_to_append = None
r = self._get_variation_tags_from_default_variations(variations[variation_name], variations, variation_groups, tmp_variation_tags_static, excluded_variation_tags)
if r['return'] > 0:
return r

if default_base_variation not in variation_groups:
return {'return': 1, 'error': 'Default variation "{}" is not a valid group. Valid groups are "{}" '.format(default_base_variation, variation_groups)}
variations_to_add = r['variations_to_add']
for t in variations_to_add:
tmp_variations[t] = False
variation_tags.append(t)

unique_allowed_variations = variation_groups[default_base_variation]['variations']
# add the default only if none of the variations from the current group is selected and it is not being excluded with - prefix
if len(set(unique_allowed_variations) & set(tmp_variation_tags_static)) == 0 and default_base_variations[default_base_variation] not in excluded_variation_tags and default_base_variations[default_base_variation] not in tmp_variation_tags_static:
tag_to_append = default_base_variations[default_base_variation]
tmp_variations[variation_name] = True

if tag_to_append:
if tag_to_append not in variations:
variation_tag_static = self._get_name_for_dynamic_variation_tag(tag_to_append)
if not variation_tag_static or variation_tag_static not in variations:
return {'return': 1, 'error': 'Invalid variation "{}" specified in default variations for the variation "{}" '.format(tag_to_append, variation_name)}
variation_tags.append(tag_to_append)
tmp_variations[tag_to_append] = False
for combined_variation in combined_variations:
if tmp_combined_variations[combined_variation]:
continue
v = combined_variation.split(",")
all_present = set(v).issubset(set(variation_tags))
if all_present:
combined_variation_meta = variations[combined_variation]
tmp_combined_variations[combined_variation] = True

tmp_variations[variation_name] = True
r = self._get_variation_tags_from_default_variations(combined_variation_meta, variations, variation_groups, tmp_variation_tags_static, excluded_variation_tags)
if r['return'] > 0:
return r

variations_to_add = r['variations_to_add']
for t in variations_to_add:
tmp_variations[t] = False
variation_tags.append(t)

all_base_processed = True
for variation_name in variation_tags:
Expand All @@ -1785,6 +1799,31 @@ def _update_variation_tags_from_variations(self, variation_tags, variations, var
return {'return': 0}


def _get_variation_tags_from_default_variations(self, variation_meta, variations, variation_groups, tmp_variation_tags_static, excluded_variation_tags):
# default_variations dictionary specifies the default_variation for each variation group. A default variation in a group is turned on if no other variation from that group is turned on and it is not excluded using the '-' prefix

tmp_variation_tags = []
if "default_variations" in variation_meta:
default_base_variations = variation_meta["default_variations"]
for default_base_variation in default_base_variations:
tag_to_append = None

if default_base_variation not in variation_groups:
return {'return': 1, 'error': 'Default variation "{}" is not a valid group. Valid groups are "{}" '.format(default_base_variation, variation_groups)}

unique_allowed_variations = variation_groups[default_base_variation]['variations']
# add the default only if none of the variations from the current group is selected and it is not being excluded with - prefix
if len(set(unique_allowed_variations) & set(tmp_variation_tags_static)) == 0 and default_base_variations[default_base_variation] not in excluded_variation_tags and default_base_variations[default_base_variation] not in tmp_variation_tags_static:
tag_to_append = default_base_variations[default_base_variation]

if tag_to_append:
if tag_to_append not in variations:
variation_tag_static = self._get_name_for_dynamic_variation_tag(tag_to_append)
if not variation_tag_static or variation_tag_static not in variations:
return {'return': 1, 'error': 'Invalid variation "{}" specified in default variations for the variation "{}" '.format(tag_to_append, variation_name)}
tmp_variation_tags.append(tag_to_append)

return {'return': 0, 'variations_to_add': tmp_variation_tags}

############################################################
def version(self, i):
Expand Down
12 changes: 10 additions & 2 deletions cm-mlops/script/app-mlperf-inference-cpp/_cm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -245,5 +245,13 @@ variations:
CM_MLPERF_LOADGEN_SCENARIO: Server

multistream,resnet50:
env:
CM_MLPERF_LOADGEN_MAX_BATCHSIZE: 8
default_variations:
batch-size: batch-size.8

offline,resnet50:
default_variations:
batch-size: batch-size.8

multistream,retinanet:
default_variations:
batch-size: batch-size.1

0 comments on commit 83dd52b

Please sign in to comment.