Skip to content

Commit

Permalink
Fix profiler check (apache#14677)
Browse files Browse the repository at this point in the history
* Relax constexpr restriction

* Image classifcation mkldnn

* Check mem profiler greater than 0

* Revert "Relax constexpr restriction"

This reverts commit 5016170.

* Revert "Image classifcation mkldnn"

This reverts commit 30bfab2.

* Add test for profiler

* Simplify test
  • Loading branch information
anirudh2290 authored and larroy committed Apr 15, 2019
1 parent b54423c commit 2c0f259
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/profiler/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,12 @@ struct ProfileCounter : public ProfileObject {
return IncrementValue(static_cast<uint64_t>(v));
}
}

inline bool operator >=(int64_t v) {
CHECK_GE(v, 0);
return value_ >= static_cast<uint64_t>(v);
}

/*! \brief operator: object = v */
inline ProfileCounter& operator = (uint64_t v) {
SetValue(v);
Expand Down
6 changes: 5 additions & 1 deletion src/profiler/storage_profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ class DeviceStorageProfiler {
Init(); // In case of bug which tries to free first
const size_t idx = prof->DeviceIndex(handle.ctx.dev_type, handle.ctx.dev_id);
CHECK_LT(idx, mem_counters_.size()) << "Invalid device index: " << idx;
*mem_counters_[idx] -= handle.size;
if (*mem_counters_[idx] >= handle.size) {
*mem_counters_[idx] -= handle.size;
} else {
*mem_counters_[idx] = 0;
}
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_exc_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,22 @@ def test_multiple_waitalls():
assert caught, "No exception thrown"
mx.nd.waitall()

@with_seed()
def test_exc_profiler():
def run_training_iteration(data):
output = net(data)

net = gluon.nn.HybridSequential()
with net.name_scope():
net.add(gluon.nn.Dense(10))

ctx = default_context()
net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
data = mx.nd.ones((3, 4))
mx.profiler.set_state("run")
run_training_iteration(data)
mx.nd.waitall()
mx.profiler.set_state("stop")


if __name__ == '__main__':
Expand Down

0 comments on commit 2c0f259

Please sign in to comment.