-
Notifications
You must be signed in to change notification settings - Fork 209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ignore_index and label to jsd and fl-jsd #306
Conversation
src/liger_kernel/ops/jsd.py
Outdated
tl.store(dX_ptr + offsets, dX, mask=mask) | ||
|
||
|
||
MAX_FUSED_SIZE = 65536 | ||
|
||
|
||
def jsd_forward(_input, target, beta): | ||
def jsd_forward(_input, target, label, beta, ignore_index, has_label): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might be wrong -- if i'm understanding it correctly, we currently have an intrinsic assumption that the label is shifted already. It would be helpful to specify this requirement and provide an example of what kind of input we'll expect in this case 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. I added some examples in transformers files, and renamed it to shift_labels
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall LGTM just a very minor suggestion!
src/liger_kernel/ops/jsd.py
Outdated
beta, | ||
n_rows, | ||
n_non_ignore, | ||
ignore_index, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this could be a constexpr
Summary
Resolve #277.
Testing Done
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence