Skip to content

Commit

Permalink
normal, uniform ops
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Jul 15, 2019
1 parent 6acf7e6 commit f8d6f95
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions benchmark/opperf/nd_operations/random_sampling_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
1. Operators are automatically fetched from MXNet operator registry.
2. Default Inputs are generated. See rules/default_params.py. You can override the default values.
Below 16 random sampling Operators are covered:
Below 18 random sampling Operators are covered:
['random_exponential', 'random_gamma', 'random_generalized_negative_binomial', 'random_negative_binomial',
'random_normal', 'random_poisson', 'random_randint', 'random_uniform', 'sample_exponential', 'sample_gamma',
'sample_generalized_negative_binomial', 'sample_multinomial', 'sample_negative_binomial', 'sample_normal',
'sample_poisson', 'sample_uniform']
'sample_poisson', 'sample_uniform', 'uniform', 'normal']
"""

Expand Down
5 changes: 4 additions & 1 deletion benchmark/opperf/utils/op_registry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,13 @@ def get_all_random_sampling_operators():
# Get all mxnet operators
mx_operators = _get_all_mxnet_operators()

# Deprecated random sampling operators
deprecate_ops = ['uniform', 'normal']

# Filter for Random Sampling operators
random_sampling_mx_operators = {}
for op_name, op_params in mx_operators.items():
if op_name.startswith(("random_", "sample_")) and op_name not in unique_ops:
if op_name.startswith(("random_", "sample_")) or op_name in deprecate_ops and op_name not in unique_ops:
random_sampling_mx_operators[op_name] = mx_operators[op_name]
return random_sampling_mx_operators

Expand Down

0 comments on commit f8d6f95

Please sign in to comment.