Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Fit API] update estimator #14849

Merged
merged 5 commits into from
May 2, 2019
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,17 @@ def _check_metrics(self, metrics):
return metrics

def _check_context(self, context):
# handle context
if isinstance(context, Context):
context = [context]
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
context = context
elif not context:
if context:
# check context values, only accept Context or a list of Context
if isinstance(context, Context):
context = [context]
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
context = context
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check the GPU device index too? Also, try querying num_gpus() only once.

Copy link
Member Author

@roywei roywei Apr 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@szha thanks! I'm now asserting that context must be in available_context, which is [cpu(), gpu(0), ..., gpu(num_gpus-1)]. Added a unit test.

Copy link
Member

@szha szha May 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cpu(65536) is actually valid in mxnet regardless of the number of physical CPUs, whereas cpu() refers to cpu(0). For GPU, the check is fine, but for CPU, we might need to do a more general check.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about:

available_context = [mx.gpu(i) for i in range(num_gpus)]
assert ctx  in available_context or str(ctx).startswith('cpu')

else:
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))
else:
# provide default context
if num_gpus() > 0:
# only use 1 GPU by default
if num_gpus() > 1:
Expand All @@ -103,9 +108,6 @@ def _check_context(self, context):
context = [gpu(0)]
else:
context = [cpu()]
else:
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))
return context

def _initialize(self, initializer):
Expand Down Expand Up @@ -167,7 +169,8 @@ def prepare_loss_and_metrics(self):
self.train_metrics = [Accuracy()]
self.val_metrics = []
for loss in self.loss:
self.train_metrics.append(Loss(''.join([i for i in loss.name if not i.isdigit()])))
# remove trailing numbers from loss name to avoid confusion
self.train_metrics.append(Loss(loss.name.rstrip('1234567890')))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "Train " + metric.name
Expand Down