-
Notifications
You must be signed in to change notification settings - Fork 6.8k
randn operator for symbol and NDarray API #12775
Conversation
@ChaiBapchya Can you check the lint failures on CI ? |
@ChaiBapchya - Thanks for your contributions. Please add more detailed description from the template. |
@anirudhacharya - Can you please take a look at this PR? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I agree with @sandeep-krishnamurthy . the method definition is inconsistent with other similar sampling functions of the symbol API( line 74 in the same file, for example). Also please add tests for this.
Also if you are adding this operator in the symbol API, can you implement the operator |
@ChaiBapchya ping! |
@@ -152,7 +152,7 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg | |||
[loc, scale], shape, dtype, ctx, out, kwargs) | |||
|
|||
|
|||
def randn(*shape, **kwargs): | |||
def randn(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be a breaking change: For example mx.nd.random.randn(2, 3, loc=5, scale=1) will fail now. We should add this to APIs good to break for 2.0 #9686 and postpone this.
Closing, will reopen for 2.0 |
Description
Added the Missing randn operator for symbol API.
Fixes inconsistent function signature of ndarray.randn
Fixes #12755 and #12801