Conversation

@jaketae (Member) commented Aug 4, 2021

This PR addresses #44. Specifically, it implements the following activation functions:

  • LiGLU
  • GEGLU
  • ReGLU
  • SwiGLU

Nomenclature-wise, Bilinear seems to be synonymous with LiGLU.
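
For reference, a minimal eager-mode sketch of what each variant computes (the actual implementations in this PR are JIT-scripted, so details may differ): the input is split in half along the feature dimension, and one half gates the other.

import torch
import torch.nn.functional as F

def liglu(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)  # split along the feature (last) dimension
    return x1 * x2  # bilinear: no nonlinearity on the gate

def geglu(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return F.gelu(x1) * x2

def reglu(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return F.relu(x1) * x2

def swiglu(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return F.silu(x1) * x2  # SiLU is Swish with beta=1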

@jaketae (Member, Author) commented Aug 4, 2021

@stas00 I see that the original Megatron codebase has fused activation functions that use JIT. Should we also do this for experimental activation functions?
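
For context, "fused" here means the element-wise function is decorated with torch.jit.script so TorchScript can fuse the pointwise math into a single kernel. A sketch modeled on Megatron's fused bias-GELU (tanh approximation; the exact upstream code may differ):

import torch

@torch.jit.script
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Scripting lets the JIT fuse the add and the GELU math into one kernel
    # instead of launching a separate kernel per pointwise op.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))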

@stas00 (Contributor) commented Aug 4, 2021

Yes, of course! If it works, that is.

@jaketae (Member, Author) commented Aug 5, 2021

I think JIT doesn't like None as the default argument for bias. My understanding is that bias terms are added in the preceding layers, before the activation function, so I don't see why we would need one here (I was using GPT-Neo's codebase as a reference). I'll get rid of the bias argument and see if I can get JIT working.
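
As an aside, TorchScript can accept a None default when the parameter is explicitly annotated as Optional; a hypothetical sketch (gated_act is an illustrative name, not code from this PR):

from typing import Optional

import torch

@torch.jit.script
def gated_act(x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # TorchScript needs the explicit Optional annotation to allow a None
    # default; an unannotated bias=None is inferred as Tensor and rejected.
    if bias is not None:
        x = x + bias
    x1, x2 = x.chunk(2, dim=2)
    return x1 * x2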

@jaketae self-assigned this Aug 5, 2021
@jaketae changed the title [WIP] Add GLU variants → Add GLU variants Aug 5, 2021
@jaketae changed the title Add GLU variants → [WIP] Add GLU variants Aug 5, 2021
@jaketae (Member, Author) commented Aug 5, 2021

@stas00 Something funky is happening on my local machine, and I'm getting an error with reglu (the rest work fine). Trace:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: vector

I ran the code on Colab and verified that they all work fine. Should we write unit tests of some sort to ensure the modules work as intended?

@stas00 (Contributor) commented Aug 5, 2021

This looks like a truncated traceback, no? From googling, this error typically appears as vector::something.

How do I reproduce it?

And yes, we absolutely need to start writing tests, as Meg didn't seem to have any. So if you're inspired, start adding tests under tests/, and over time we will expand the suite; we also need to set up CI.

@jaketae (Member, Author) commented Aug 5, 2021

Surprisingly, that's all the trace shows me. Below is what I did to trigger the error on my local machine (macOS, Python 3.7.3).

>>> from megatron.model.activations import liglu, geglu, reglu, swiglu
>>> import torch
>>> x = torch.randn(8, 100, 768)
>>> reglu(x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/jaketae/Documents/Dev/GitHub/Megatron-DeepSpeed/venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: vector

>>> torch.__version__
'1.9.0'

Two more observations:

  • I was not able to reproduce the error on Colab
  • When I moved activations.py to an entirely separate directory and ran the same sequence in the same virtual environment, the error did not occur

We will need testing to iron out messy details like this, and I could certainly try writing some very basic tests over the next few days. But unless you can reproduce the error on your machine, I wouldn't be too concerned... or should we be?

@stas00 (Contributor) commented Aug 5, 2021

I get the full trace on my machine:

python -c "from megatron.model.activations import liglu, geglu, reglu, swiglu; import torch; print(torch.__version__); \
reglu(torch.randn(8, 100, 768))"
1.9.0
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: vector::_M_range_check: __n (which is 18446744073709551615) >= this->size() (which is 3)

The other three produce no error.

@jaketae (Member, Author) commented Aug 5, 2021

Apparently JIT doesn't like negative indexing (source). I replaced dim=-1 with dim=2, and it works fine now (it also passes a mini unit test). Can you confirm? Thanks!

@stas00 (Contributor) commented Aug 5, 2021

Yeah, that fixed it. Interesting that dim=-1 shows up as 18446744073709551615, i.e. -1 reinterpreted as an unsigned 64-bit integer (2^64 - 1).
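
A quick check of that interpretation:

>>> 2**64 - 1
18446744073709551615
>>> (-1) & 0xFFFFFFFFFFFFFFFF
18446744073709551615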

@stas00 (Contributor) commented Aug 5, 2021

Is the input always a 3-dimensional tensor? It will fail for a 4-dimensional tensor, for example.

@jaketae (Member, Author) commented Aug 5, 2021

It's odd that the error is still there; my source is from 2019. Is it safe to assume that all tensors will be three-dimensional?

EDIT: Haha, we had the same question. A remedy would be to use x.ndim. If we can safely assume that the last dimension is always the feature dimension, I think we can do x.ndim - 1 or something of that sort. Does that sound good?

@stas00 (Contributor) commented Aug 5, 2021

Yes, I just tested x1, x2 = x.chunk(2, dim=(x.ndim-1)) as proposed in the thread you linked to, and it works. It sounds like the most generic approach.

@stas00 (Contributor) commented Aug 5, 2021

Supposedly this bug was fixed recently in pytorch/pytorch#25135 (comment), so probably in pt-1.10.0. But we need to support older PyTorch versions, so let's leave a comment there:

# dim=-1 breaks in pt<1.10
x1, x2 = x.chunk(2, dim=(x.ndim-1))

@jaketae (Member, Author) commented Aug 5, 2021

Sounds good! Thanks for testing the code on your end.

@jaketae changed the title [WIP] Add GLU variants → Add GLU variants Aug 5, 2021
@stas00 (Contributor) left a comment

Let's add basic tests, and looking great otherwise. Thank you, @jaketae

@jaketae mentioned this pull request Aug 6, 2021
@jaketae (Member, Author) commented Aug 6, 2021

@stas00 I wrote some simple tests. Admittedly, I don't have a lot of experience writing test code, and I wasn't sure if the way I tested the operations makes sense (i.e., the tests somewhat regurgitate the function definitions). I also took a monkey-testing approach and used random inputs, with seeds set at the beginning of the run for reproducibility. Let me know what you think. Thanks!
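
For reference, a sketch of the shape these tests take (hypothetical; the actual test file may differ): seed the RNG, feed random inputs, and compare each scripted activation against its eager-mode definition.

import torch
import torch.nn.functional as F

from megatron.model.activations import geglu, reglu

def test_glu_variants():
    # Seed at the start of the run so the random inputs are reproducible.
    torch.manual_seed(42)
    x = torch.randn(8, 100, 768)
    x1, x2 = x.chunk(2, dim=2)
    # The expected values restate the function definitions in eager mode;
    # assumes the scripted geglu uses the same GELU formulation as F.gelu.
    assert torch.allclose(reglu(x), F.relu(x1) * x2)
    assert torch.allclose(geglu(x), F.gelu(x1) * x2)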

@stas00 (Contributor) commented Aug 8, 2021

Looks great, @jaketae

We will gradually improve the test suite, so yours is a good start.

Let's merge this one.

@stas00 mentioned this pull request Aug 8, 2021
@jaketae merged commit effb2fb into main Aug 8, 2021
@jaketae (Member, Author) commented Aug 8, 2021

Thanks for the feedback and review @stas00!

@stas00 deleted the activations branch August 8, 2021 16:21