This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix non-determinism of dot(csr.T, dns) = dns with tests #11825

Merged: eric-haibin-lin merged 2 commits into apache:master from fix_dot_determinism on Jul 26, 2018

Contributor

@haojin2 haojin2 commented Jul 19, 2018

Description

Fix for #10709

Checklist

Essentials

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Change dot(csr.T, dns) = dns to a deterministic algorithm
  • Test for determinism of dot(csr.T, dns) = dns
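
To illustrate the idea behind the change, here is a minimal pure-Python sketch (not the CUDA code in this PR) of computing dot(csr.T, dns) with a fixed accumulation order. It assumes the approach suggested by the `SortByKey(csc_cols, original_idx, ...)` call in the diff: reorder the nonzeros by column index with a stable sort, then accumulate each output row sequentially. The function name `dot_csr_t_dns` and the list-based CSR layout are hypothetical.

```python
def dot_csr_t_dns(data, indices, indptr, num_cols, dns):
    """Compute csr.T @ dns with a fixed, reproducible accumulation order.

    data/indices/indptr describe a CSR matrix with num_cols columns;
    dns is a dense matrix given as a list of rows.
    """
    num_rows = len(indptr) - 1
    width = len(dns[0]) if dns else 0

    # Recover the row index of every stored nonzero.
    row_of = []
    for r in range(num_rows):
        row_of.extend([r] * (indptr[r + 1] - indptr[r]))

    # Stable sort of nonzero positions by column index: this is the
    # CSR -> CSC reordering. Ties keep their original (row) order.
    order = sorted(range(len(data)), key=lambda i: indices[i])

    out = [[0.0] * width for _ in range(num_cols)]
    for i in order:
        c, r, v = indices[i], row_of[i], data[i]
        # Each output row c is accumulated sequentially, in ascending
        # row order, so the rounding is identical on every run.
        for j in range(width):
            out[c][j] += v * dns[r][j]
    return out


# 3x2 CSR matrix [[1, 0], [0, 2], [3, 4]] times a 3x2 dense matrix.
csr_data, csr_indices, csr_indptr = [1.0, 2.0, 3.0, 4.0], [0, 1, 0, 1], [0, 1, 2, 4]
dense = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
result = dot_csr_t_dns(csr_data, csr_indices, csr_indptr, 2, dense)
```

Because the order in which contributions reach each output element no longer depends on thread scheduling, repeated runs produce the same bits.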

Comments

The new determinism test passed 10000 times on my local machine.
The correctness check, which takes longer to run (~5s/trial), has already passed more than 10000 times so far.
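
For context on what such a determinism test guards against: floating-point addition is not associative, so a kernel whose accumulation order varies from run to run can return different bits each time. A minimal sketch follows, with hypothetical helper names (`accumulate`, `is_deterministic`); it is not the test added in this PR.

```python
def accumulate(values, order):
    """Sum values in the given order, as one thread of execution would."""
    total = 0.0
    for i in order:
        total += values[i]
    return total


def is_deterministic(f, trials):
    """Call f repeatedly and require bit-identical results every time."""
    first = f()
    return all(f() == first for _ in range(trials))


values = [1e16, 1.0, -1e16]

# The same three numbers, added in two different orders.
in_order = accumulate(values, [0, 1, 2])   # (1e16 + 1.0) - 1e16
reordered = accumulate(values, [0, 2, 1])  # (1e16 - 1e16) + 1.0

# With a fixed order, repeated runs agree exactly.
fixed_ok = is_deterministic(lambda: accumulate(values, [0, 1, 2]), 100)
```

The two orders give different sums because the 1.0 is lost to rounding when it is added to 1e16 first, which is why a test of this kind compares repeated runs for exact equality, not approximate equality.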

Contributor Author

haojin2 commented Jul 19, 2018

@eric-haibin-lin Please give a review when you have time, thanks!

@@ -573,12 +581,113 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
Member

Since DotCsrTransDnsDnsWarpKernel, DotCsrTransDnsDnsThreadBlockKernel, and DotCsrTransDnsDnsWarpBlockKernel are nondeterministic and cannot be used, we might as well just remove them.

Contributor Author

Sure I can do that.

Contributor Author

Done.

@haojin2 haojin2 force-pushed the fix_dot_determinism branch 2 times, most recently from f63c106 to 498b150 on July 19, 2018 19:14
@haojin2 haojin2 changed the title from "Fix undeterminism of dot(csr.T, dns) = dns with tests" to "Fix non-determinism of dot(csr.T, dns) = dns with tests" on Jul 19, 2018
Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);

int num_bits = log2i(num_csr_cols - 1);
SortByKey(csc_cols, original_idx, true, &temp_storage, 0, num_bits);
Contributor

Q: Should the argument be num_bits - 1, based on the SortByKey function signature?

Contributor Author

Pay attention to the num_csr_cols - 1 above.
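
The exchange above turns on how many key bits the sort has to examine. A short sketch, assuming the body of `log2i` shown in the diff (`int k = 1; while (a >>= 1) k++;`) is followed by returning `k`, and assuming the last two `SortByKey` arguments form a begin/end bit range. With that body the helper returns the number of bits needed to represent its argument, so passing the largest column index, `num_csr_cols - 1`, covers every key. The `key_bits` wrapper is hypothetical.

```python
def log2i(a):
    """Python port of the helper in the diff; the final `return k` is assumed."""
    k = 1
    a >>= 1
    while a:
        k += 1
        a >>= 1
    return k


def key_bits(num_csr_cols):
    """Bits needed to tell apart column indices 0 .. num_csr_cols - 1."""
    return log2i(num_csr_cols - 1)


# 128 columns -> largest index 127 -> 7 bits.
bits_128 = key_bits(128)
# 129 columns -> largest index 128 -> 8 bits.
bits_129 = key_bits(129)
# Passing the column count itself would ask for one extra bit at powers of two.
bits_if_count = log2i(128)
```

Under these assumptions the subtraction is applied to the helper's input, so no further adjustment of `num_bits` is needed at the call site.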

@@ -35,6 +35,13 @@
namespace mxnet {
namespace op {

// Returns integer log2(a) rounded up
inline int log2i(size_t a) {
Contributor

I feel we should use a common utility function with a more efficient implementation for this type of computation.

Contributor Author

This is only used in this file. Also, note the mismatch between the input and output types.

Contributor

Although this is only used in this file for now, the function is generic and very likely to be used by other developers in the future. Making it a generic utility will allow us to change its implementation later without breaking existing code. Its output can be cast as needed, so I don't see the type mismatch as a reason to have a separate function.

Contributor Author

IMO a generic utility should be a templated one to support all data types instead of something like this.

Contributor

I did a simple search and found that a log2i function (sometimes under a slightly different name) appears at least three times in this module. See: indexing_op.h, pooled_storage_manager.h and intrinsics.cuh. It would be nice to consolidate them.

Contributor Author

For normal usage, where the inputs and outputs have the same type, there are std implementations of this. On the other hand, this code is already well optimized for the job.

Contributor Author

That one operates on unsigned int and I'm operating on size_t here. You can see that the code is otherwise the same. Since there are no other usages, keeping each version in the file that uses it makes more sense.

Member

Taking the integer log2 does sound like a common utility. We always need to calculate it before calling SortByKey. Can we have log2i(size_t) and log2i(unsigned int) in common/utils.h and replace the other occurrences?

// Returns integer log2(a) rounded up
inline int log2i(size_t a) {
int k = 1;
while (a >>= 1) k++;
Contributor

nit: use ++k instead

Contributor Author

Done.

Member

@eric-haibin-lin eric-haibin-lin left a comment

What's the performance impact of this change?

Contributor Author

haojin2 commented Jul 22, 2018

@eric-haibin-lin

@@ -663,6 +663,18 @@ constexpr size_t MaxIntegerValue<mshadow::half::half_t>() {
return size_t(2) << 10;
}

MSHADOW_XINLINE int ilog2ul(size_t a) {
Contributor

Looks good. Please also refactor other places using the same utility as we discussed.

Contributor Author

Done.

@@ -510,6 +398,7 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
return;
}

using namespace mshadow;
using mshadow::cuda::kBaseThreadNum;
Contributor

You may not need the mshadow directive since you added the line above.

Contributor Author

not used now, deleted.

@apeforest
Contributor

LGTM and thanks for the refactoring.

Contributor Author

haojin2 commented Jul 24, 2018

Benchmark script:

import time

import numpy as np
import scipy.sparse  # a bare `import scipy` does not guarantee scipy.sparse is loaded
import mxnet as mx
from mxnet.test_utils import assert_almost_equal

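def measure_cost(repeat, f, *args, **kwargs):
    """Return the average wall-clock time of one call to f over `repeat` calls."""
    start = time.time()
    results = []
    for i in range(repeat):
        results.append(f(*args, **kwargs))
    # MXNet executes asynchronously, so block until every result is ready
    # before stopping the clock.
    for result in results:
        result.wait_to_read()
    end = time.time()
    diff = end - start
    return diff / repeat

def measure_fallback(repeat, a):
    """Return the average time of converting `a` to dense ('default') storage.

    Not called by main(); kept for measuring the storage-fallback cost separately.
    """
    start = time.time()
    results = []
    for i in range(repeat):
        results.append(a.tostype('default'))
    for result in results:
        result.wait_to_read()
    end = time.time()
    diff = end - start
    return diff / repeat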

def main():
    shape_dns = (1000000, 512)
    shape_csr = (1000000, 128)
    dns = np.random.uniform(size=shape_dns)
    mx_dns = mx.nd.array(dns, ctx=mx.gpu())
    mx_dns_cpy = mx_dns.copy()
    for density in [0.01, 0.005, 0.001, 0.0005, 0.0001]:
        csr = scipy.sparse.random(shape_csr[0], shape_csr[1], density=density, format='csr', dtype=np.float32)
        mx_csr = mx.nd.sparse.csr_matrix((csr.data, csr.indices, csr.indptr), shape=shape_csr, ctx=mx.gpu())
        mx_csr_dns = mx_csr.tostype('default')
        sparse_cost = 0.0
        dns_cost = 0.0
        mx.nd.waitall()
        #warmup
        check = mx.nd.dot(mx_csr, mx_dns, transpose_a=True, forward_stype='default')
        check_np = np.dot(mx_csr_dns.asnumpy().T, dns)
        assert_almost_equal(check.asnumpy(), check_np, atol=1e-5, rtol=1e-4)
        print(check.shape)
        mx.nd.waitall()
        for i in range(50):
            sparse_cost += measure_cost(1, mx.nd.dot, mx_csr, mx_dns, transpose_a=True, forward_stype='default')
            dns_cost += measure_cost(1, mx.nd.dot, mx_csr_dns, mx_dns, transpose_a=True)
        print("%.2f %%" % (density*100), dns_cost / sparse_cost)


if __name__ == "__main__":
    main()

Benchmark results on single K80 GPU (all speedup values are based on comparison to the same dense dot computation dot( (1M, 128).T, (1M, 512) ) ):
(density of CSR matrix - speedup)
1.00% - 4.718009661244505
0.50% - 9.146928215069966
0.10% - 32.65680889940526
0.05% - 47.22710560215928
0.01% - 73.19199774617466
@eric-haibin-lin
Note: the previous implementation could not handle inputs of this size, so there is no before/after comparison here. I will post a reference comparison with the previous implementation on smaller inputs below.

Contributor Author

haojin2 commented Jul 24, 2018

Comparison of performance between before and after on smaller workloads (dot( (100k, 128).T , (100k, 512) )):
(density of csr - speedup compared to dense version before the change / speedup compared to dense version after the change)
1.00 % - 0.11684326423581441 / 3.7605456307819147
0.50 % - 0.13864958837387978 / 4.902107295777013
0.10 % - 0.18179564999182662 / 7.328729395201502
0.05 % - 0.19055669842279166 / 7.820138355111453
0.01 % - 0.20276938703209807 / 8.232172552569692
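
Both columns above are speedups relative to dense dot, so dividing the "after" value by the "before" value gives a rough estimate of how much faster the new sparse kernel is than the old one at each density. The sketch below does that arithmetic on the posted numbers; it assumes the dense timing was comparable in both runs, and the variable names are illustrative. Values below 1.0 in the "before" column mean the old sparse path was slower than dense.

```python
# (density in percent, speedup vs. dense before, speedup vs. dense after)
rows = [
    (1.00, 0.11684326423581441, 3.7605456307819147),
    (0.50, 0.13864958837387978, 4.902107295777013),
    (0.10, 0.18179564999182662, 7.328729395201502),
    (0.05, 0.19055669842279166, 7.820138355111453),
    (0.01, 0.20276938703209807, 8.232172552569692),
]

# Ratio of the two speedups = old sparse time / new sparse time,
# provided the dense baseline took about the same time in both runs.
improvement = {density: after / before for density, before, after in rows}

for density, factor in improvement.items():
    print("%.2f %% - %.1fx faster than before" % (density, factor))
```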

Member

@eric-haibin-lin eric-haibin-lin left a comment

Nice speed up

Contributor Author

haojin2 commented Jul 25, 2018

@eric-haibin-lin File conflicts resolved, should be ready for merge once the build passes.

@eric-haibin-lin
Member

Great fix!

@eric-haibin-lin eric-haibin-lin merged commit 302aae3 into apache:master Jul 26, 2018
@haojin2 haojin2 deleted the fix_dot_determinism branch July 26, 2018 01:43
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* fix undeterminism of dot(csr.T, dns) = dns with tests

* address code reviews
@haojin2 haojin2 added the Sparse label Aug 12, 2019
3 participants