Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[OpPerf] Implement remaining GEMM ops #17501

Merged
merged 3 commits into from
Feb 7, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions benchmark/opperf/nd_operations/gemm_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

1. dot
2. batch_dot
3. khatri_rao

TODO
3. As part of default tests, the following needs to be added:
Expand All @@ -36,7 +37,7 @@

def run_gemm_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
"""Runs benchmarks with the given context and precision (dtype) for all the GEMM
operators (dot, batch_dot) in MXNet.
operators (dot, batch_dot, khatri_rao) in MXNet.

Parameters
----------
Expand All @@ -54,7 +55,7 @@ def run_gemm_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='nativ
Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.

"""
# Benchmark tests for dot and batch_dot operators
# Benchmark tests for dot operator
dot_benchmark_res = run_performance_test(
[getattr(MX_OP_MODULE, "dot")], run_backward=True,
dtype=dtype, ctx=ctx,
Expand All @@ -68,7 +69,7 @@ def run_gemm_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='nativ
"transpose_a": True,
"transpose_b": True}],
warmup=warmup, runs=runs, profiler=profiler)

# Benchmark tests for batch_dot operator
batch_dot_benchmark_res = run_performance_test(
[getattr(MX_OP_MODULE, "batch_dot")], run_backward=True,
dtype=dtype, ctx=ctx,
Expand All @@ -82,7 +83,17 @@ def run_gemm_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='nativ
"transpose_a": True,
"transpose_b": True}],
warmup=warmup, runs=runs, profiler=profiler)
# Operator khatri_rao is not yet implemented for GPU
khatri_rao_benchmark_res = []
if ctx != mx.gpu():
# Benchmark tests for khatri_rao operator
khatri_rao_benchmark_res = run_performance_test(
[getattr(MX_OP_MODULE, "khatri_rao")], run_backward=False,
dtype=dtype, ctx=ctx,
inputs=[{"args": [(32, 32), (32, 32)]},
{"args": [(64, 64), (64, 64)]}],
warmup=warmup, runs=runs, profiler=profiler)

# Prepare combined results for GEMM operators
mx_gemm_op_results = merge_map_list(dot_benchmark_res + batch_dot_benchmark_res)
mx_gemm_op_results = merge_map_list(dot_benchmark_res + batch_dot_benchmark_res + khatri_rao_benchmark_res)
return mx_gemm_op_results