Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix profiler check #14677

Merged
merged 10 commits into the base branch on
Apr 12, 2019
6 changes: 6 additions & 0 deletions src/profiler/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,12 @@ struct ProfileCounter : public ProfileObject {
return IncrementValue(static_cast<uint64_t>(v));
}
}

/*!
 * \brief Compare the counter's unsigned value against a signed threshold.
 * \param v threshold to compare against; must be non-negative
 * \return true if the current counter value is >= \a v
 * \note CHECK_GE rejects negative \a v up front: casting a negative
 *       int64_t to uint64_t would wrap to a huge value and make the
 *       comparison silently wrong.
 */
inline bool operator >=(int64_t v) {
CHECK_GE(v, 0);
return value_ >= static_cast<uint64_t>(v);
}

/*! \brief operator: object = v */
inline ProfileCounter& operator = (uint64_t v) {
SetValue(v);
Expand Down
6 changes: 5 additions & 1 deletion src/profiler/storage_profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ class DeviceStorageProfiler {
Init(); // In case of bug which tries to free first
const size_t idx = prof->DeviceIndex(handle.ctx.dev_type, handle.ctx.dev_id);
CHECK_LT(idx, mem_counters_.size()) << "Invalid device index: " << idx;
*mem_counters_[idx] -= handle.size;
if (*mem_counters_[idx] >= handle.size) {
*mem_counters_[idx] -= handle.size;
} else {
*mem_counters_[idx] = 0;
}
}
}
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_exc_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from mxnet.gluon import nn
from mxnet.base import MXNetError
from mxnet.test_utils import assert_exception, default_context, set_default_context
from mxnet.gluon.data.vision import transforms
from nose.tools import assert_raises

@with_seed()
Expand Down Expand Up @@ -165,6 +166,46 @@ def test_multiple_waitalls():
assert caught, "No exception thrown"
mx.nd.waitall()

@with_seed()
def test_exc_profiler():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems a bit heavy for a unittest, I think passing some nd.ones() through a simple dense layer would do the same right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have simplified it. Please take a look

def run_training_iteration(data, label):
    # Run one forward/backward/update step on a single batch.
    # Relies on the enclosing scope for `ctx`, `net`,
    # `softmax_cross_entropy`, and `trainer` (closure variables).
    data = data.as_in_context(ctx)
    label = label.as_in_context(ctx)
    # Record the forward pass so autograd can compute gradients.
    with mx.autograd.record():
        output = net(data)
        loss = softmax_cross_entropy(output, label)
    # Backward pass must happen outside the record() scope here,
    # before the optimizer step.
    loss.backward()

    # Normalize the gradient update by the batch size.
    trainer.step(data.shape[0])

# Small LeNet-style conv net; the exact architecture is incidental —
# the test only needs real allocations/frees while the profiler runs.
net = gluon.nn.HybridSequential()
with net.name_scope():
net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(512, activation="relu"))
net.add(gluon.nn.Dense(10))

train_data = gluon.data.DataLoader(gluon.data.vision.MNIST(train=True).transform_first(transforms.ToTensor()),
batch_size=64, shuffle=True)

ctx = default_context()

net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

# Warm-up iteration with the profiler off, so some memory is already
# allocated before profiling starts (the scenario the C++ fix guards).
itr = iter(train_data)
run_training_iteration(*next(itr))

# NOTE(review): `data` and `label` fetched here are never used — the call
# below pulls a fresh batch via next(itr). Likely intended
# run_training_iteration(data, label); confirm before cleaning up.
data, label = next(itr)

# Run one iteration with the profiler active; must not crash even though
# frees may exceed the counters accumulated since set_state("run").
mx.profiler.set_state("run")
run_training_iteration(*next(itr))
mx.nd.waitall()
mx.profiler.set_state("stop")


if __name__ == '__main__':
Expand Down