This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Infrastructure for implementing constraint check #16868

Open
xidulu opened this issue Nov 20, 2019 · 2 comments


@xidulu
Contributor

xidulu commented Nov 20, 2019

Description

Checking the validity of parameters is crucial for many operators, especially distribution-related ops (see References for the PyTorch and TensorFlow implementations).

I implemented an operator called npx.constraint_check, which takes a boolean tensor and an error message as input and checks whether all elements are true. If any element is false, it raises an exception with the given message; otherwise it returns a scalar tensor array(True).
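The intended semantics can be sketched in plain Python. This is only a stand-in for illustration, not the MXNet operator: it works on nested lists instead of tensors, and the use of ValueError is an assumption (the real op may raise a different exception type).

```python
def constraint_check(condition, msg):
    """Pure-Python stand-in for the semantics described above (not the MXNet op)."""

    def _flatten(x):
        # Walk nested lists/tuples and yield every scalar element.
        if isinstance(x, (list, tuple)):
            for item in x:
                yield from _flatten(item)
        else:
            yield x

    # Raise with the given message if any element is false.
    if not all(bool(v) for v in _flatten(condition)):
        raise ValueError(msg)  # assumption: exception type chosen for illustration
    # Otherwise return True, mirroring the scalar array(True) result.
    return True


low = [1.0, 1.0]
high = [2.0, 3.0]
ok = constraint_check([h > l for l, h in zip(low, high)], "high must exceed low")
```

Here `ok` is True because every element of the comparison holds; swapping `low` and `high` would raise with the given message.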

However, this op fails in symbolic mode: because its output is neither returned to the user nor used as input to another op, the engine ignores the check op entirely. In short, the exception is never raised.
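The effect can be reproduced with a toy deferred-execution graph. This is a simplified model written for illustration, not MXNet's engine; `Node` and `check` are hypothetical names. Only the ancestors of a requested output are evaluated, so a check node that nothing depends on never runs.

```python
class Node:
    """Toy graph node: runs its function only when evaluate() is requested."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs
        self.evaluated = False

    def evaluate(self):
        # Evaluate the ancestors first, then this node.
        args = [n.evaluate() for n in self.inputs]
        self.evaluated = True
        return self.fn(*args)


def check(flag):
    # Stand-in for the constraint check: raise if the condition is false.
    if not flag:
        raise ValueError("error!")
    return True


low = Node(lambda: 1.0)
high = Node(lambda: -1.0)
compare = Node(lambda l, h: h > l, low, high)
flag = Node(check, compare)                       # nothing consumes this node
output = Node(lambda l, h: (l + h) / 2, low, high)

# Requesting only `output` walks its ancestors; `flag` is never touched,
# so the violated constraint goes unnoticed.
result = output.evaluate()

# Requesting `flag` explicitly is what finally triggers the exception.
try:
    flag.evaluate()
    raised = False
except ValueError:
    raised = True
```

In this model, returning the flag as an extra output and reading it is the only way to make the check run, which is the idea behind the workaround that follows.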

@leezu provides a workaround like this:

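from mxnet import np, npx
from mxnet.gluon import HybridBlock

npx.set_np()  # enable NumPy-compatible array semantics

class WrapperHybridBlock(HybridBlock):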
  def __init__(self):
    super(WrapperHybridBlock, self).__init__()

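  def __call__(self, *args):
    # Forward the inputs as one tuple; hybrid_forward indexes into it.
    tmp = super(WrapperHybridBlock, self).__call__(args)
    # The flag is the last output; block on it so a failed check raises here.
    tmp[-1].wait_to_read()
    # Strip the flag before returning the actual output(s).
    return tmp[:-1] if len(tmp) > 2 else tmp[0]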

class foo(WrapperHybridBlock):
  def __init__(self):
    super(foo, self).__init__()

  def hybrid_forward(self, F, arg):
    low = arg[0]
    high = arg[1]
    flag = F.npx.constraint_check(high > low, "error!")
    actual_output = F.np.random.uniform(low, high)
    return (actual_output, flag)

test_foo = foo()
low = np.ones((4,4))
high = np.ones((4,4)) - 2
test_foo.hybridize()
print(test_foo(low, high))

This approach works well: the exception is thrown as expected.

However, this method is inconvenient when the constraint_check call is buried in a function called from inside hybrid_forward, e.g.:

def func1(F, args):
    F.npx.constraint_check(args)
    return F.npx.other_func(args)

def func2(F, args):
    t = func1(F, args)
    return F.npx.other_func(t)

def hybrid_forward(self, F, args):
    return func2(F, args)

In such cases, it becomes quite difficult to manually turn the flag tensor into a return value.


Another solution could be the cond op in control flow, which unfortunately seems to be unmaintained.

I believe the simplest way is to have some kind of mechanism that forces MXNet to evaluate a particular op.

I'm open to other solutions. (It would be great to implement this feature using only the existing infrastructure.)

References

@xidulu
Contributor Author

xidulu commented Nov 20, 2019

@sxjscience Any suggestions?

@leezu
Contributor

leezu commented Nov 20, 2019
