diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index d54267fc02ff..a83227a0261c 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -73,17 +73,14 @@ def rand_zipfian(true_classes, num_sampled, range_max): Examples -------- - >>> true_cls = mx.nd.array([3]) - >>> samples, exp_count_true, exp_count_sample = mx.nd.contrib.rand_zipfian(true_cls, 4, 5) - >>> samples - [1 3 3 3] - - >>> exp_count_true - [ 0.12453879] - - >>> exp_count_sample - [ 0.22629439 0.12453879 0.12453879 0.12453879] - + >>> true_cls = mx.sym.Variable('true_cls') + >>> samples, exp_count_true, exp_count_sample = mx.sym.contrib.rand_zipfian(true_cls, 4, 5) + >>> samples.eval(true_cls=mx.nd.array([3]))[0].asnumpy() + array([1, 3, 3, 3]) + >>> exp_count_true.eval(true_cls=mx.nd.array([3]))[0].asnumpy() + array([0.12453879]) + >>> exp_count_sample.eval(true_cls=mx.nd.array([3]))[0].asnumpy() + array([0.22629439, 0.12453879, 0.12453879, 0.12453879]) """ assert(isinstance(true_classes, Symbol)), "unexpected type %s" % type(true_classes) log_range = math.log(range_max + 1)