
[MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU #10371

Merged

7 commits merged into apache:master from the dot_dns_csr_dns branch on Apr 26, 2018

Conversation

haojin2 (Contributor) commented Apr 2, 2018

Description

Support dot(dns, csr) = dns and dot(dns, csr.T) = dns.

Checklist

Essentials

  • The PR title starts with [MXNET-263]
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Support dot(dns, csr.T) = dns
  • Support dot(dns, csr) = dns
  • Add forward_stype_hint to dot interface
  • Tests for dot(dns, csr.T)
  • Tests for dot(dns, csr)
  • Test for new forward storage type inference with forward_stype_hint
  • Fix overly strict atol in test_gluon.test_lambda, mentioned in #10376 (Flaky test_gluon.test_lambda)

Comments

The storage type hint gives the user a way to specify the desired output storage type, so there are no surprises when an unsupported case is hit. A test with full coverage of all possible input cases is added.
The algorithm for dot(dense, csr) = dense will be changed to a deterministic version, and a determinism test will be added soon.
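For illustration, here is a minimal usage sketch of what the hint enables on GPU (the parameter appears as forward_stype_hint in the change list above and as forward_stype in the benchmark script at the end of this thread; shapes, density and names are illustrative, not taken from the PR's tests):

import numpy as np
import scipy.sparse as spsp
import mxnet as mx

ctx = mx.gpu()
lhs = mx.nd.random.uniform(shape=(64, 512), ctx=ctx)              # dense lhs
sp = spsp.random(512, 256, density=0.01, format='csr', dtype=np.float32)
rhs = mx.nd.sparse.csr_matrix((sp.data, sp.indices, sp.indptr),
                              shape=sp.shape, ctx=ctx)            # CSR rhs

# Ask for a dense output explicitly; unsupported combinations fall back to the
# dense operator rather than failing unexpectedly.
out = mx.nd.dot(lhs, rhs, forward_stype='default')
print(out.stype, out.shape)                                       # 'default' (64, 256)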
Profiling results for dot(dense, csr.T), time in ns:

CSR density   Initialization   Transpose    Multiplication-Addition
1.00%         2877416          93137180     4862214884
0.10%         24891            7864861      560496940
0.01%         15276            652855       58704271
Most of the time is spent in the Multiplication-Addition phase; we'll be working on improving the computation kernel soon. @eric-haibin-lin
Tested with the new warp kernel but saw no fundamental improvement; the kernel code:

struct DotDnsCsrTransDnsWarpKernel {
  /*!
   * \brief GPU warp kernel for dot(dns, csr.T) = dns; one warp accumulates one output element
   * \param tid          global thread id
   * \param lhs_data     lhs dense matrix data
   * \param rhs_data     csr matrix data
   * \param rhs_indices  csr matrix column indices
   * \param rhs_indptr   csr matrix row pointer
   * \param out          output matrix data
   * \param lhs_num_cols lhs dns matrix number of columns
   * \param out_num_rows output dns matrix number of rows
   * \param out_num_cols output dns matrix number of columns
   */
  template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
                                             const DType* lhs_data,
                                             const DType* rhs_data,
                                             const IType* rhs_indices,
                                             const CType* rhs_indptr,
                                             DType* out,
                                             const nnvm::dim_t lhs_num_cols,
                                             const nnvm::dim_t out_num_rows,
                                             const nnvm::dim_t out_num_cols) {
    using nnvm::dim_t;
    __shared__ volatile DType vals[mshadow::cuda::kBaseThreadNum];
    int warp_id = tid / 32;
    int lane = tid & (32 - 1);
    if (warp_id < out_num_rows*out_num_cols) {
      const dim_t i = static_cast<dim_t>(warp_id) % out_num_rows;  // i = row this warp computes
      const dim_t k = static_cast<dim_t>(warp_id) / out_num_rows;  // k = col this warp computes
      // Compute inner product of i-th row and k-th col
      DType sum = 0;
      for (CType col_id = rhs_indptr[k] + lane; col_id < rhs_indptr[k + 1]; col_id += 32) {
        sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id];
      }
      vals[threadIdx.x] = sum; __syncwarp();

      // Parallel reduction in shared memory
      if (lane < 16) {vals[threadIdx.x] += vals[threadIdx.x+16];} __syncwarp();
      if (lane <  8) {vals[threadIdx.x] += vals[threadIdx.x+ 8];} __syncwarp();
      if (lane <  4) {vals[threadIdx.x] += vals[threadIdx.x+ 4];} __syncwarp();
      if (lane <  2) {vals[threadIdx.x] += vals[threadIdx.x+ 2];} __syncwarp();
      if (lane <  1) {vals[threadIdx.x] += vals[threadIdx.x+ 1];} __syncwarp();

      if (lane == 0) {
        out[i * out_num_cols + k] = vals[threadIdx.x];
      }
    }
  }
};
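For context, a small NumPy/SciPy sketch of the computation this kernel targets (dot(dns, csr.T) = dns, where each output element out[i, k] is the inner product of dense row i with CSR row k); it can double as a reference when checking the GPU result. Shapes and density are illustrative:

import numpy as np
import scipy.sparse as spsp

lhs = np.random.uniform(size=(8, 100)).astype(np.float32)           # dense lhs
rhs = spsp.random(50, 100, density=0.05, format='csr', dtype=np.float32)

# out[i, k] = sum over nonzeros j of CSR row k: lhs[i, j] * rhs[k, j]
dense_ref = lhs @ rhs.toarray().T                                    # dense reference
sparse_out = rhs.dot(lhs.T).T                                        # CSR-aware equivalent
assert np.allclose(dense_ref, sparse_out, atol=1e-5)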

anirudh2290 (Member) commented Apr 2, 2018

So dot(dns, csr) outputs a csr ndarray when the CPU context is used, and a dns ndarray when the GPU context is used? This should at least be well documented somewhere.

haojin2 (Contributor, Author) commented Apr 2, 2018

@anirudh2290 Corresponding documentation will be added once the implementation and tests are complete.

@@ -235,13 +235,21 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode::kFComputeEx);
}
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
!param.transpose_a && !param.transpose_b) {

Member:

I thought we would add an stype hint argument which defaults to None?

Contributor Author:

Working on that now.

eric-haibin-lin self-assigned this on Apr 3, 2018

namespace mxnet {
namespace op {

template<typename gpu>

Member:

It looks like for dot, gpu kernels are in dot-inl.cuh. Let's move the implementation here

Contributor Author:

Done

MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {  // data type
MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // column index type
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // indptr type
/* Allocate workspace */

Member:

Remove this comment?

Contributor Author:

Done

for (CType col_id = rhs_indptr[k]; col_id < rhs_indptr[k + 1]; ++col_id) {
sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id];
}
out[i*out_num_cols+k] = sum;

Member:

use tid directly?

Contributor Author:

Good catch, done.

haojin2 force-pushed the dot_dns_csr_dns branch 2 times, most recently from f168837 to 387c66b, on April 7, 2018 at 00:47
haojin2 changed the title from "[MXNET-263] [WIP] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" to "[MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" on Apr 7, 2018
haojin2 changed the title back to "[MXNET-263] [WIP] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" on Apr 7, 2018
haojin2 changed the title back to "[MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" on Apr 7, 2018

eric-haibin-lin (Member) commented:

Please fix test

test_operator_gpu.test_sparse_dot ... /work/runtime_functions.sh: line 389:     7 Segmentation fault      (core dumped) nosetests-2.7 --verbose tests/python/gpu

- dot(csr, default) = default
- dot(csr.T, default) = row_sparse
- dot(csr, row_sparse) = default
- dot(default, csr) = csr
- dot(default, csr) = csr on CPU only
- dot(default, csr) = dense on GPU only

eric-haibin-lin (Member), Apr 10, 2018:

Could you show the output storage with specific values of forward_stype_hint? e.g.

dot(csr, dense, transpose_a=True) = row_sparse
dot(csr, dense, forward_stype_hint='default', transpose_a=True) = default
dot(default, csr, forward_stype_hint='default') = default (GPU only)

What happens if someone uses dot(dense, dense, forward_stype_hint='csr')?

DMLC_DECLARE_PARAMETER(DotParam) {
DMLC_DECLARE_FIELD(transpose_a)
.describe("If true then transpose the first input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(transpose_b)
.describe("If true then transpose the second input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(forward_stype_hint)
.describe("Desired storage type of the forward output.")

Member:

Need to be more detailed than this. Explicitly state that if no such combination is implemented, dense op is used.

if (!dispatched && lhs_stype == kDefaultStorage &&
rhs_stype == kDefaultStorage) {
// dns, dns -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
target_stype = (param.forward_stype_hint.has_value())?

Member:

nit: This line is quite long. Maybe cache the value of param.forward_stype_hint.has_value()?

Contributor Author:

Done

DType* csc_data,
unsigned long long* csc_indices,
unsigned long long* csc_indptr,
unsigned long long* workspace,

Member:

Rename "workspace" to "col_counters"? Better document what this space is used for:
e.g. used to count the offset of column indices atomically

Contributor Author:

Done

if (tid < num_rows) {
for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) {
// target column
IType target_col = csr_indices[i];

Member:

nit: use const wherever it applies

Contributor Author:

Done
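
As an aside, a plain-Python sketch of the counter-based CSR-to-CSC scatter this kernel performs (the serial loop stands in for the per-column atomicAdd on the GPU; names are illustrative, not the PR's actual code):

import numpy as np
import scipy.sparse as spsp

csr = spsp.random(6, 5, density=0.4, format='csr', dtype=np.float32)
num_csr_cols = csr.shape[1]

# Column pointers come from a histogram of column indices; col_counters tracks
# the next free slot within each column while data is scattered.
col_counts = np.bincount(csr.indices, minlength=num_csr_cols)
csc_indptr = np.concatenate(([0], np.cumsum(col_counts)))
csc_data = np.empty_like(csr.data)
csc_indices = np.empty_like(csr.indices)
col_counters = np.zeros(num_csr_cols, dtype=np.int64)

for row in range(csr.shape[0]):
    for pos in range(csr.indptr[row], csr.indptr[row + 1]):
        col = csr.indices[pos]
        offset = csc_indptr[col] + col_counters[col]   # atomicAdd on the GPU
        col_counters[col] += 1
        csc_data[offset] = csr.data[pos]
        csc_indices[offset] = row                      # row index of this value

# Rows are visited in order here, so the result already matches SciPy's CSC;
# on the GPU the atomic order is arbitrary, hence the determinism discussion below.
ref = csr.tocsc()
assert np.array_equal(csc_indptr, ref.indptr)
assert np.array_equal(csc_indices, ref.indices) and np.allclose(csc_data, ref.data)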


// if dot(dense, csr) = dns, transform to csc first
if (!transpose_b) {
// LOG(FATAL) << "dot(dns, csr) = dns not implemented yet";

Member:

Remove unused code please

Contributor Author:

Done

MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
DType* csc_data_ptr = NULL;
unsigned long long* csc_indices_ptr = NULL;

Member:

maybe typedef to "AtomicIType"?

Contributor Author:

Okay, done

unsigned long long* csc_indices_ptr = NULL;
unsigned long long* csc_indptr_ptr = NULL;
unsigned long long* col_counters = NULL;
size_t ull_mem_size = sizeof(unsigned long long);

Member:

What does ull mean?

Member:

Oh, never mind, it's unsigned long long. Not obvious though.

temp_storage_bytes += (ull_mem_size - (temp_storage_bytes % ull_mem_size));
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(
Shape1(nnz*sizeof(DType) + nnz*ull_mem_size +

Member:

Can we calculate nnz*ull_mem_size once and assign it to a meaningful local variable? That improves readability

Contributor Author:

Done

Stream<gpu>::GetStream(s));
// Reset values for col_counter, ready for the final transform
mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(
s, csr_cols+1, col_counters);

Member:

Can we rename "csr_cols" to "num_csr_cols"? It's not obvious from the name whether it's a count or a pointer to data without reading back through the code.

Contributor Author:

Done

MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {  // data type
MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // column index type
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // indptr type
CType out_num_rows = ret->shape()[0];

Member:

Looks like these two lines are duplicates of line 1044 and can be moved outside of the if-else?

Contributor Author:

Sure, done

haojin2 force-pushed the dot_dns_csr_dns branch 2 times, most recently from 52ebdbc to cb2671b, on April 13, 2018 at 02:13
for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) {
// target column
const IType target_col = csr_indices[i];
const int target_offset = atomicAdd(&col_counters[target_col], 1);

Member:

I don't think this provides a deterministic result... the order of accumulation could differ across runs.

Contributor Author:

No, this only affects the order of the data within each column; when we do the final multiplication-accumulation the result should still be the same.

Contributor Author:

Well... you may be right about this, but I guess we'll then have to add a sort after this step to ensure the order.

test_infer_forward_stype(lhs_shape, (lhs_shape[0], rnd.randint(10, 20)),
lhs_d, rhs_d, True, False)
test_infer_forward_stype(lhs_shape, (rnd.randint(10, 20), lhs_shape[0]),
lhs_d, rhs_d, True, True)

Member:

We should think of a test to check whether the result is deterministic.

Contributor Author:

Sure
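
A rough sketch of what such a determinism check could look like (the commit log says a test_sparse_dot_determinism test was added; the helper and parameter names here are illustrative): run the same dot twice and require bit-identical outputs, which fails whenever the floating-point accumulation order varies between runs.

import numpy as np
import scipy.sparse as spsp
import mxnet as mx

def check_dot_determinism(lhs, rhs, **kwargs):
    # Two runs of the same operation must agree bit for bit.
    out1 = mx.nd.dot(lhs, rhs, **kwargs).asnumpy()
    out2 = mx.nd.dot(lhs, rhs, **kwargs).asnumpy()
    assert np.array_equal(out1, out2)

ctx = mx.gpu()
lhs = mx.nd.random.uniform(shape=(50, 100), ctx=ctx)
sp = spsp.random(100, 80, density=0.05, format='csr', dtype=np.float32)
rhs = mx.nd.sparse.csr_matrix((sp.data, sp.indices, sp.indptr), shape=sp.shape, ctx=ctx)
check_dot_determinism(lhs, rhs, forward_stype='default')

sp_t = spsp.random(80, 100, density=0.05, format='csr', dtype=np.float32)
rhs_t = mx.nd.sparse.csr_matrix((sp_t.data, sp_t.indices, sp_t.indptr), shape=sp_t.shape, ctx=ctx)
check_dot_determinism(lhs, rhs_t, transpose_b=True, forward_stype='default')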

haojin2 changed the title from "[MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" to "[MXNET-263] [WIP] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" on Apr 13, 2018
haojin2 force-pushed the dot_dns_csr_dns branch 11 times, most recently from a46bb74 to 66b197d, on April 18, 2018 at 06:15

/*!
* \brief GPU Kernel of generation of transposed csr matrix
* \param tid global thread id

Member:

nit: remove line 472, 473

Contributor Author:

Done

- dot(default, csr) = csr
- otherwise, ``dot`` generates output with default storage
- dot(default, csr) = csr on CPU only
- dot(default, csr) = dense on GPU only

Member:

Is it better to write "dot(default, csr, forward_stype_hint='default')"?

Member:

Also dot(default, csr, transpose_b=True, forward_stype_hint='default') ?

Member:

Also, just choose one of "default" or "dense". Mixing them in the same line will be confusing.

Contributor Author:

Done

rhs = rhs_nd.tostype(rhs_stype)
out = mx.nd.dot(lhs, rhs, forward_stype_hint=forward_stype,
transpose_a=trans_a, transpose_b=trans_b)
assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-4, atol=1e-5)

Member:

out.asnumpy() should be fine

forward_stype_hint=forward_stype,
transpose_a=trans_a, transpose_b=trans_b)
location = {'lhs': lhs, 'rhs': rhs}
check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)

Member:

Also check_symbolic_backward?

@with_seed()
def test_sparse_dot_determinism():
def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b):
lhs_row = rnd.randint(200, 400)

Member:

An int between 200 and 400 seems too large for a unit test. Maybe just choose between 50 and 100?

@@ -23,6 +23,7 @@
*/

#include "./dot-inl.h"
#include <cub/cub.cuh>

Member:

Is this needed here?

MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {  // data type
MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // column index type
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // indptr type
const CType out_num_rows = ret->shape()[0];

Member:

Shape is always dim_t

Contributor Author:

Done

Contributor Author:

Done


int num_bits = 1;
unsigned int a = num_csr_cols - 1;
while (a >>= 1) num_bits++;

eric-haibin-lin (Member), Apr 19, 2018:

nit: can we reuse the ilog2 function? The variable "a" is quite unreadable.
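
For reference, a small Python equivalent of the bit-width computation above (the loop counts how many bits are needed to represent the largest column index, num_csr_cols - 1, which is then used as the sort-key width; the reviewer's suggestion is to express the same thing with the existing ilog2 helper instead):

def num_key_bits(num_csr_cols):
    # Bits needed to represent column indices 0 .. num_csr_cols - 1,
    # mirroring: num_bits = 1; a = num_csr_cols - 1; while (a >>= 1) num_bits++;
    return max(1, int(num_csr_cols - 1).bit_length())

assert num_key_bits(1) == 1   # a = 0, the C++ loop never runs, num_bits stays 1
assert num_key_bits(8) == 3   # indices 0..7 fit in 3 bits
assert num_key_bits(9) == 4   # indices 0..8 need 4 bits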

haojin2 force-pushed the dot_dns_csr_dns branch 2 times, most recently from c3e9ba8 to 4f025d6, on April 20, 2018 at 05:48
haojin2 changed the title from "[MXNET-263] [WIP] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" to "[MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU" on Apr 20, 2018
haojin2 force-pushed the dot_dns_csr_dns branch 2 times, most recently from 9f77910 to 4a4ac67, on April 22, 2018 at 07:28

eric-haibin-lin (Member) left a comment:

Could you also paste the benchmark script & result later? Thx

forward_stype=forward_stype,
transpose_a=trans_a, transpose_b=trans_b)
location = {'lhs': lhs, 'rhs': rhs}
check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)

Member:

can we test check_symbolic_backward for dot(dns, csr)?

Contributor Author:

Done
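
A rough, self-contained sketch of what such a backward check could look like (assuming the GPU path from this PR and the forward_stype parameter; names, shapes, density and tolerances are illustrative, not the test actually added): for out = dot(lhs, rhs) with output gradient G, the expected gradients are d(lhs) = dot(G, rhs.T) and d(rhs) = dot(lhs.T, G).

import numpy as np
import scipy.sparse as spsp
import mxnet as mx
from mxnet.test_utils import check_symbolic_backward

ctx = mx.gpu()
lhs_np = np.random.uniform(size=(4, 6)).astype(np.float32)
rhs_np = spsp.random(6, 5, density=0.5, format='csr', dtype=np.float32).toarray()

lhs_sym = mx.sym.Variable('lhs')
rhs_sym = mx.sym.Variable('rhs', stype='csr')
out = mx.sym.dot(lhs_sym, rhs_sym, forward_stype='default')

location = {'lhs': mx.nd.array(lhs_np, ctx=ctx),
            'rhs': mx.nd.array(rhs_np, ctx=ctx).tostype('csr')}
ograd = np.ones((4, 5), dtype=np.float32)
expected = {'lhs': ograd.dot(rhs_np.T),       # d(lhs) = dot(G, rhs.T)
            'rhs': lhs_np.T.dot(ograd)}       # d(rhs) = dot(lhs.T, G)
check_symbolic_backward(out, location, [ograd], expected, rtol=1e-3, atol=1e-4)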

dispatched = dispatch_fallback(out_attrs, dispatch_mode);
target_stype = (target_stype == kUndefinedStorage)? kDefaultStorage : target_stype;
dispatched = storage_type_assign(&out_stype, target_stype, dispatch_mode,
DispatchMode::kFComputeFallback);
}

Member:

I think you should also update InferStorageType for backward dot

Contributor Author:

Done

Contributor Author:

Done

* \brief GPU Impl of dot(dns, csr) = csr
*/
template<typename gpu>
inline void DotDnsCsrCsrImpl(const OpContext& ctx,

Member:

template<>
inline void DotDnsCsrCsrImpl<gpu>(const OpContext& ctx,...)

Contributor Author:

Done

haojin2 force-pushed the dot_dns_csr_dns branch 3 times, most recently from ccb9f5a to 880ae43, on April 24, 2018 at 00:37
- dot(default, csr) = csr
- otherwise, ``dot`` generates output with default storage
- dot(default, csr) = csr on CPU only
- dot(default, csr) = default on GPU only

Member:

Please update the doc.

Contributor Author:

Done

size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType);
size_t csc_data_bytes = nnz*sizeof(DType);
size_t scan_temp_storage_bytes = 0;
size_t temp_storage_bytes = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(nnz);

Member:

nit: shouldn't it be SortByKeyWorkspaceSize<IType, IType, gpu>(nnz)? We might add 32-bit int for the idx dtype in the future to speed up / reduce memory.

Contributor Author:

Done

haojin2 (Contributor, Author) commented Apr 24, 2018

Benchmark results: [density of rhs csr (%)] [speedup of the sparse path over the dense path]

  • dot(dense, csr):
    • without fallback time:
      1.00 % 0.11839042163642717
      0.50 % 0.2333432217062633
      0.10 % 1.1903050457360003
      0.05 % 2.358481865354682
      0.01 % 11.254730287959426
    • with fallback time:
      1.00 % 0.15756154617574628
      0.50 % 0.2846213797972692
      0.10 % 1.355322535347357
      0.05 % 2.649692301508364
      0.01 % 12.53094788814324
  • dot(dense, csr.T):
    • without fallback time:
      1.00 % 0.13831609996057787
      0.50 % 0.2716895153249566
      0.10 % 1.355685095559762
      0.05 % 2.704108996445601
      0.01 % 12.82104397524722
    • with fallback time:
      1.00 % 0.17843282886647807
      0.50 % 0.3281353223927983
      0.10 % 1.5357395689258746
      0.05 % 3.01124368707238
      0.01 % 14.151675772120354

haojin2 (Contributor, Author) commented Apr 24, 2018

benchmark script:

import mxnet as mx
import numpy as np
import scipy.sparse
from mxnet.test_utils import assert_almost_equal
import time

def measure_cost(repeat, f, *args, **kwargs):
    # start bench
    start = time.time()
    results = []
    for i in range(repeat):
        results.append(f(*args, **kwargs))
    for result in results:
        result.wait_to_read()
    end = time.time()
    diff = end - start
    return diff / repeat

def measure_fallback(repeat, a):
    # start bench
    start = time.time()
    results = []
    for i in range(repeat):
        results.append(a.tostype('default'))
    for result in results:
        result.wait_to_read()
    end = time.time()
    diff = end - start
    return diff / repeat

def main():
    shape_lhs = (256, 30000)
    shape_rhs = (30000, 30000)
    dns = np.random.uniform(size=shape_lhs)
    mx_dns = mx.nd.array(dns, ctx=mx.gpu())
    mx_dns_cpy = mx_dns.copy()
    for density in [0.01, 0.005, 0.001, 0.0005, 0.0001]:
        csr = scipy.sparse.random(shape_rhs[0], shape_rhs[1], density=density, format = 'csr', dtype=np.float32)
        mx_csr = mx.nd.sparse.csr_matrix((csr.data, csr.indices, csr.indptr), shape=shape_rhs, ctx=mx.gpu())
        mx_csr_dns = mx_csr.tostype('default')
        sparse_cost = 0.0
        dns_cost = 0.0
        mx.nd.waitall()
        #warmup
        check = mx.nd.dot(mx_dns, mx_csr, forward_stype='default')
        check_np = np.dot(dns, mx_csr_dns.asnumpy())
        assert_almost_equal(check.asnumpy(), check_np, atol=1e-5, rtol=1e-4)
        print(check.shape)
        mx.nd.waitall()
        for i in range(50):
            sparse_cost += measure_cost(1, mx.nd.dot, mx_dns, mx_csr, forward_stype='default')
            dns_cost += measure_fallback(1, mx_csr)
            dns_cost += measure_cost(1, mx.nd.dot, mx_dns, mx_csr_dns)
        print("%.2f %%" % (density*100), dns_cost / sparse_cost)
        sparse_cost = 0.0
        dns_cost = 0.0
        check = mx.nd.dot(mx_dns, mx_csr, transpose_b=True, forward_stype='default')
        check_np = np.dot(dns, mx_csr_dns.asnumpy().T)
        assert_almost_equal(check.asnumpy(), check_np, atol=1e-5, rtol=1e-4)
        print(check.shape)
        mx.nd.waitall()
        for i in range(50):
            sparse_cost += measure_cost(1, mx.nd.dot, mx_dns, mx_csr, transpose_b=True, forward_stype='default')
            dns_cost += measure_fallback(1, mx_csr)
            dns_cost += measure_cost(1, mx.nd.dot, mx_dns, mx_csr_dns, transpose_b=True)
        print("%.2f %%" % (density*100), dns_cost / sparse_cost)
if __name__ == "__main__":
    main()

eric-haibin-lin merged commit 8727cae into apache:master on Apr 26, 2018
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
… on GPU (apache#10371)

* add support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU

* add unit test for new op and forward_stype_hint parameter to dot

* update documentation for dot

* address code reviews

* fix flaky test_gluon:test_lambda through loosening the atol

* switch dot(dns, csr) case to a deterministic algorithm with unit test for determinism

* address code reviews and add backward
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
… on GPU (apache#10371)

* add support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU

* add unit test for new op and forward_stype_hint parameter to dot

* update documentation for dot

* address code reviews

* fix flaky test_gluon:test_lambda through loosening the atol

* switch dot(dns, csr) case to a deterministic algorithm with unit test for determinism

* address code reviews and add backward
haojin2 deleted the dot_dns_csr_dns branch on April 15, 2019
haojin2 added the Sparse label on Aug 12, 2019