[BUG] Mamba-Codestral-7B-v0.1: Internal Triton PTX codegen error: Ptx assembly aborted due to errors #213

Open
andretisch opened this issue Aug 21, 2024 · 0 comments
Labels
bug Something isn't working

Python -VV

Python 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]

Pip Freeze

absl-py==2.1.0
accelerate==0.29.3
aiohttp==3.9.5
aiosignal==1.3.1
albucore==0.0.5
albumentations==1.4.0
ale-py==0.7.5
annotated-types==0.6.0
anyio==4.3.0
aqlm==1.1.6
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
astroid==3.2.2
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
async-timeout==4.0.3
asyncio==3.4.3
attrs==23.2.0
Automat==20.2.0
autopep8==2.0.4
AutoROM==0.4.2
AutoROM.accept-rom-license==0.6.1
Babel==2.14.0
bcrypt==3.2.0
beautifulsoup4==4.12.3
bitsandbytes==0.43.1
bleach==6.1.0
blinker==1.4
cachetools==5.3.3
causal-conv1d==1.4.0
certifi==2024.2.2
cffi==1.16.0
chardet==4.0.0
charset-normalizer==3.3.2
click==8.0.3
cloud-init==24.1.3
cloudpickle==3.0.0
cmake==3.30.2
colorama==0.4.4
coloredlogs==15.0.1
comm==0.2.2
command-not-found==0.3
configobj==5.0.6
constantly==15.1.0
contourpy==1.2.1
craft-text-detector @ git+https://github.com/ria-com/craft-text-detector.git@0734fc81fbe9705cffa6d13c71d2ba240b20b422
cryptography==3.4.8
cycler==0.12.1
Cython==3.0.10
dataclasses-json==0.6.4
datasets==2.19.1
dbus-python==1.2.18
debugpy==1.8.1
decorator==4.4.2
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distlib==0.3.8
distro==1.7.0
distro-info==1.1+ubuntu0.2
dnspython==2.1.0
docstring-to-markdown==0.15
docstring_parser==0.16
einops==0.8.0
exceptiongroup==1.2.1
executing==2.0.1
faiss-cpu==1.8.0
fastapi==0.112.0
fastjsonschema==2.19.1
filelock==3.14.0
fire==0.6.0
flake8==7.0.0
flatbuffers==24.3.25
fonttools==4.51.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
gast==0.5.4
gdown==5.2.0
gevent==24.2.1
gigachain==0.1.16
gigachain-community==0.0.33.1
gigachain-core==0.1.43.1
gigachain-text-splitters==0.0.1
gigachat==0.1.23
gitdb==4.0.11
GitPython==3.1.18
google-auth==2.29.0
google-auth-oauthlib==1.2.0
google-pasta==0.2.0
gpg==1.16.0
greenlet==3.0.3
grpcio==1.62.0
gym==0.25.2
gym-notices==0.0.8
h11==0.14.0
h5py==3.10.0
httpcore==1.0.5
httplib2==0.20.2
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.24.5
humanfriendly==10.0
hyperlink==21.0.0
idna==3.6
imageio==2.34.1
imageio-ffmpeg==0.5.1
imgaug==0.4.0
importlib_metadata==8.0.0
importlib_resources==6.4.0
incremental==21.3.0
interegular==0.3.3
ipykernel==6.29.4
ipython==8.23.0
ipywidgets==8.1.2
isoduration==20.11.0
isort==5.13.2
jedi==0.19.1
jeepney==0.7.1
Jinja2==3.1.4
joblib==1.4.0
json5==0.9.25
jsonpatch==1.33
jsonpickle==3.0.4
jsonpointer==2.0
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.6
jupyterlab-language-pack-ru-RU==4.1.post2
jupyterlab-lsp==5.1.0
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.0
jupyterlab_widgets==3.0.10
jupyterthemes==0.20.0
keras==2.15.0
keyring==23.5.0
kiwisolver==1.4.5
langsmith==0.1.50
lapx==0.5.9
lark==1.2.1
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
lesscpy==0.15.1
libclang==18.1.1
lightning-utilities==0.11.2
llvmlite==0.43.0
lm-format-enforcer==0.10.3
lxml==5.2.1
mamba-ssm==2.2.2
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.21.1
matplotlib==3.9.0
matplotlib-inline==0.1.7
mccabe==0.7.0
mdurl==0.1.2
mistral_common==1.3.4
mistral_inference==1.3.1
mistune==3.0.2
ml-dtypes==0.2.0
modelhub-client @ git+https://github.com/ria-com/modelhub-client.git@d155a12e1c32a0a65b98c911f1e9e8786d02d93f
more-itertools==8.10.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
namex==0.0.8
nbclient==0.10.0
nbconvert==7.16.3
nbformat==5.10.4
nest-asyncio==1.6.0
netifaces==0.11.0
networkx==3.3
ninja==1.11.1.1
notebook==7.1.3
notebook_shim==0.2.4
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.555.43
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.0
onnx==1.15.0
onnx2pytorch==0.4.1
onnxruntime-gpu==1.18.1
openai==1.23.3
opencv-python==4.9.0.80
opencv-python-headless==4.9.0.80
opt-einsum==3.3.0
optree==0.11.0
orjson==3.10.1
outlines==0.0.46
overrides==7.7.0
packaging==24.0
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.8.0
pillow==10.3.0
platformdirs==4.2.0
pluggy==1.5.0
ply==3.11
pretty-errors==1.2.25
proglog==0.1.10
prometheus-fastapi-instrumentator==7.0.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
protobuf==3.20.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyairports==2.1.1
pyarrow==16.1.0
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycocotools==2.0.7
pycodestyle==2.11.1
pycountry==24.6.1
pycparser==2.22
pydantic==2.7.1
pydantic_core==2.18.2
pydocstyle==6.3.0
pydot==2.0.0
pyflakes==3.2.0
Pygments==2.17.2
PyGObject==3.42.1
PyHamcrest==2.0.2
PyJWT==2.3.0
pylint==3.2.5
pyOpenSSL==21.0.0
pyparsing==3.1.2
pyrsistent==0.18.1
pyserial==3.5
PySocks==1.7.1
pyTelegramBotAPI==4.17.0
python-apt==2.4.0+ubuntu3
python-dateutil==2.9.0.post0
python-debian==0.1.43+ubuntu1.1
python-dotenv==1.0.1
python-json-logger==2.0.7
python-lsp-jsonrpc==1.1.2
python-lsp-server==1.11.0
python-magic==0.4.24
pytoolconfig==1.3.1
pytorch-lightning==1.8.6
PyTurboJPEG @ git+https://github.com/lilohuang/PyTurboJPEG.git@c9a4973ab48e1e3d421881f76d8b1b2a22669fd2
pytz==2024.1
PyVirtualDisplay==3.0
PyYAML==6.0.1
pyzmq==26.0.2
qtconsole==5.5.1
QtPy==2.4.1
qudida==0.0.4
ray==2.34.0
referencing==0.34.0
regex==2024.4.16
requests==2.32.2
requests-oauthlib==2.0.0
requests-toolbelt==0.9.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rope==1.13.0
rpds-py==0.18.0
rsa==4.9
safetensors==0.4.3
scikit-image==0.23.2
scikit-learn==1.4.2
scipy==1.13.0
screen-resolution-extra==0.0.0
seaborn==0.13.2
SecretStorage==3.3.1
Send2Trash==1.8.3
sentence-transformers==2.7.0
sentencepiece==0.2.0
sentry-sdk==2.1.1
service-identity==18.1.0
shapely==2.0.4
simple_parsing==0.1.5
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
snowballstemmer==2.2.0
sos==4.5.6
soupsieve==2.5
SQLAlchemy==2.0.29
ssh-import-id==5.11
stack-data==0.6.3
starlette==0.37.2
sympy==1.12
systemd-python==234
tenacity==8.2.3
tensorboard==2.15.2
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.15.0.post1
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.36.0
tensorrt==10.0.1
tensorrt-cu12==10.0.1
tensorrt-cu12-bindings==10.0.1
tensorrt-cu12-libs==10.0.1
termcolor==2.4.0
terminado==0.18.1
tf2onnx==1.16.1
thop==0.1.1.post2209072238
threadpoolctl==3.4.0
tifffile==2024.5.10
tiktoken==0.7.0
tinycss2==1.2.1
tokenizers==0.19.1
tomli==2.0.1
tomlkit==0.12.5
torch==2.4.0
torchaudio==2.4.0+cu124
torchmetrics==1.4.0
torchvision==0.19.0
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
transformers==4.44.0
triton==3.0.0
Twisted==22.1.0
typeguard==4.2.1
types-python-dateutil==2.9.0.20240316
typing-inspect==0.9.0
typing_extensions==4.11.0
tzdata==2024.1
ubuntu-drivers-common==0.0.0
ubuntu-pro-client==8001
ufw==0.36.1
ujson==5.10.0
ultralytics==8.2.19
unattended-upgrades==0.1
uri-template==1.3.0
urllib3==2.2.1
uvicorn==0.30.6
uvloop==0.19.0
virtualenv==20.25.3
vllm==0.5.4
vllm-flash-attn==2.6.1
wadllib==1.3.6
watchfiles==0.23.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.2
whatthepatch==1.0.5
widgetsnbextension==4.0.10
wrapt==1.14.1
xformers==0.0.27.post2
xkit==0.0.0
xvfbwrapper==0.2.9
xxhash==3.4.1
yapf==0.40.2
yarl==1.9.4
zipp==1.0.0
zope.event==5.0
zope.interface==5.4.0

Reproduction Steps

from mistral_inference.main import *

demo('models/Mamba-Codestral-7B-v0.1')

Expected Behavior

The demo should simply generate an answer as output.

Additional Context

Hi. I tried to run the demo function with the model mistralai/Mamba-Codestral-7B-v0.1. I am using torch 2.4.0 and CUDA 12.4 on an Nvidia Tesla P40.
As I understand it, Triton has a problem with bf16 on the P40. However, I was previously able to load Llama in bf16.
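
For reference, here is a minimal sketch of how the GPU's compute capability could be checked against the bf16 requirement before running the demo. The P40 reports compute capability 6.1 (the `sm_61` in the ptxas command in the traceback below), and the ptxas log says the `.bf16` feature needs `sm_80` or higher. The names `MIN_BF16_CAPABILITY` and `supports_native_bf16` are hypothetical helpers made up for this illustration, not part of torch, Triton, or mistral_inference; the (8, 0) threshold is taken from the ptxas message only.

```python
from typing import Tuple

# Assumption: threshold taken from the ptxas message
# "Feature '.bf16' requires .target sm_80 or higher".
MIN_BF16_CAPABILITY: Tuple[int, int] = (8, 0)


def supports_native_bf16(capability: Tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability is at least sm_80."""
    return tuple(capability) >= MIN_BF16_CAPABILITY


if __name__ == "__main__":
    try:
        import torch  # optional: only used to query the real device
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        # torch.cuda.get_device_capability() returns a (major, minor) tuple.
        capability = torch.cuda.get_device_capability()
    else:
        # Fallback for illustration: the Tesla P40 from this report (sm_61).
        capability = (6, 1)

    print(f"compute capability: {capability}")
    print(f"native bf16 PTX expected to assemble: {supports_native_bf16(capability)}")
```

This only tells whether bf16 PTX instructions can be assembled for the device. It does not explain why other code paths (for example, loading Llama in bf16 through PyTorch) may still work on the same card.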

---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
File ~/.local/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py:292, in CUDABackend.make_cubin(src, metadata, opt, capability)
    291 try:
--> 292     subprocess.run(cmd, shell=True, check=True)
    293 except subprocess.CalledProcessError as e:

File /usr/lib/python3.10/subprocess.py:526, in run(input, capture_output, timeout, check, *popenargs, **kwargs)
    525     if check and retcode:
--> 526         raise CalledProcessError(retcode, process.args,
    527                                  output=stdout, stderr=stderr)
    528 return CompletedProcess(process.args, retcode, stdout, stderr)

CalledProcessError: Command '/home/andrew/.local/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas -lineinfo -v --gpu-name=sm_61 /tmp/tmpxqf2tngx.ptx -o /tmp/tmpxqf2tngx.ptx.o 2> /tmp/tmpbh62k9ys.log' returned non-zero exit status 255.

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 4
      1 import torch
      2 from mistral_inference.main import *
----> 4 demo('models/Mamba-Codestral-7B-v0.1')

File ~/.local/lib/python3.10/site-packages/mistral_inference/main.py:185, in demo(model_path, max_tokens, temperature, lora_path)
    178     warnings.warn(
    179         "Batched generation is not correctly supported at the moment and therefore might lead to worse results "
    180         "as compared to non-batched generation. "
    181         "See https://github.com/state-spaces/mamba/issues/66#issuecomment-1862349718 for more information."
    182     )
    183     encoded_prompts = pad_and_convert_to_tensor(encoded_prompts, mistral_tokenizer.instruct_tokenizer.BOS)  # type: ignore[attr-defined]
--> 185 generated_tokens, _logprobs = generate_fn(
    186     encoded_prompts,
    187     model,  # type: ignore[arg-type]
    188     max_tokens=max_tokens,
    189     temperature=temperature,
    190     eos_id=tokenizer.eos_id,
    191 )
    193 generated_words = []
    194 for i, x in enumerate(generated_tokens):

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/mistral_inference/generate.py:21, in generate_mamba(encoded_prompts, model, max_tokens, temperature, chunk_size, eos_id)
     10 @torch.inference_mode()
     11 def generate_mamba(
     12     encoded_prompts: List[List[int]],
   (...)
     18     eos_id: Optional[int] = None,
     19 ) -> Tuple[List[List[int]], List[List[float]]]:
     20     input_ids = torch.tensor(encoded_prompts, device=model.device)
---> 21     output = model.model.generate(
     22         input_ids=input_ids,
     23         max_length=input_ids.shape[-1] + max_tokens,
     24         cg=True,
     25         return_dict_in_generate=True,
     26         output_scores=True,
     27         enable_timing=False,
     28         eos_token_id=eos_id,
     29         temperature=temperature,
     30         top_p=0.8,
     31     )
     32     generated_tokens = output.sequences[:, input_ids.shape[-1] :].tolist()
     34     _logprobs: List[List[float]] = [[] for _ in range(len(generated_tokens))]

File ~/.local/lib/python3.10/site-packages/mamba_ssm/utils/generation.py:260, in GenerationMixin.generate(self, input_ids, max_length, top_k, top_p, min_p, temperature, return_dict_in_generate, output_scores, **kwargs)
    248 def generate(
    249     self,
    250     input_ids,
   (...)
    258     **kwargs,
    259 ):
--> 260     output = decode(
    261         input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
    262     )
    263     if not output_scores:
    264         output.scores = None

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/mamba_ssm/utils/generation.py:160, in decode(input_ids, model, max_length, top_k, top_p, min_p, temperature, repetition_penalty, eos_token_id, teacher_outputs, vocab_size, cg, enable_timing, streamer)
    158 if not hasattr(model, "_decoding_cache"):
    159     model._decoding_cache = None
--> 160 model._decoding_cache = update_graph_cache(
    161     model,
    162     model._decoding_cache,
    163     batch_size,
    164     seqlen_og,
    165     max_length,
    166 )
    167 inference_params = model._decoding_cache.inference_params
    168 inference_params.reset(max_length, batch_size)

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/mamba_ssm/utils/generation.py:321, in update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, decoding_seqlens, dtype, n_warmups)
    319 for decoding_seqlen in decoding_seqlens:
    320     if (batch_size, decoding_seqlen) not in cache.callables:
--> 321         cache.callables[batch_size, decoding_seqlen] = capture_graph(
    322             model,
    323             cache.inference_params,
    324             batch_size,
    325             max_seqlen,
    326             decoding_seqlen=decoding_seqlen,
    327             mempool=cache.mempool,
    328             n_warmups=n_warmups,
    329         )
    331 def dispatch(input_ids, position_ids, seqlen):
    332     batch_size, decoding_seqlen = input_ids.shape[:2]

File ~/.local/lib/python3.10/site-packages/mamba_ssm/utils/generation.py:355, in capture_graph(model, inference_params, batch_size, max_seqlen, decoding_seqlen, mempool, n_warmups)
    353 with torch.cuda.stream(s):
    354     for _ in range(n_warmups):
--> 355         logits = model(
    356             input_ids,
    357             position_ids=position_ids,
    358             inference_params=inference_params,
    359             num_last_tokens=decoding_seqlen,
    360         ).logits
    361     s.synchronize()
    362     # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
    363     # which requires that graph launch and non-captured launch to not overlap (I think,
    364     # that's how I interpret the documentation). I'm not sure if this is required.

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.local/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py:279, in MambaLMHeadModel.forward(self, input_ids, position_ids, inference_params, num_last_tokens, **mixer_kwargs)
    274 def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
    275     """
    276     "position_ids" is just to be compatible with Transformer generation. We don't use it.
    277     num_last_tokens: if > 0, only return the logits for the last n tokens
    278     """
--> 279     hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
    280     if num_last_tokens > 0:
    281         hidden_states = hidden_states[:, -num_last_tokens:]

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.local/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py:194, in MixerModel.forward(self, input_ids, inference_params, **mixer_kwargs)
    192 residual = None
    193 for layer in self.layers:
--> 194     hidden_states, residual = layer(
    195         hidden_states, residual, inference_params=inference_params, **mixer_kwargs
    196     )
    197 if not self.fused_add_norm:
    198     residual = (hidden_states + residual) if residual is not None else hidden_states

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.local/lib/python3.10/site-packages/mamba_ssm/modules/block.py:57, in Block.forward(self, hidden_states, residual, inference_params, **mixer_kwargs)
     55         residual = residual.to(torch.float32)
     56 else:
---> 57     hidden_states, residual = layer_norm_fn(
     58         hidden_states,
     59         self.norm.weight,
     60         self.norm.bias,
     61         residual=residual,
     62         prenorm=True,
     63         residual_in_fp32=self.residual_in_fp32,
     64         eps=self.norm.eps,
     65         is_rms_norm=isinstance(self.norm, RMSNorm)
     66     )
     67 hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
     69 if self.mlp is not None:

File ~/.local/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py:902, in layer_norm_fn(x, weight, bias, residual, x1, weight1, bias1, eps, dropout_p, rowscale, prenorm, residual_in_fp32, is_rms_norm, return_dropout_mask)
    886 def layer_norm_fn(
    887     x,
    888     weight,
   (...)
    900     return_dropout_mask=False,
    901 ):
--> 902     return LayerNormFn.apply(
    903         x,
    904         weight,
    905         bias,
    906         residual,
    907         x1,
    908         weight1,
    909         bias1,
    910         eps,
    911         dropout_p,
    912         rowscale,
    913         prenorm,
    914         residual_in_fp32,
    915         is_rms_norm,
    916         return_dropout_mask,
    917     )

File ~/.local/lib/python3.10/site-packages/torch/autograd/function.py:574, in Function.apply(cls, *args, **kwargs)
    571 if not torch._C._are_functorch_transforms_active():
    572     # See NOTE: [functorch vjp and autograd interaction]
    573     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 574     return super().apply(*args, **kwargs)  # type: ignore[misc]
    576 if not is_setup_ctx_defined:
    577     raise RuntimeError(
    578         "In order to use an autograd.Function with functorch transforms "
    579         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    580         "staticmethod. For more details, please see "
    581         "https://pytorch.org/docs/main/notes/extending.func.html"
    582     )

File ~/.local/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py:775, in LayerNormFn.forward(ctx, x, weight, bias, residual, x1, weight1, bias1, eps, dropout_p, rowscale, prenorm, residual_in_fp32, is_rms_norm, return_dropout_mask)
    769     rowscale = rowscale.reshape(-1).contiguous()
    770 residual_dtype = (
    771     residual.dtype
    772     if residual is not None
    773     else (torch.float32 if residual_in_fp32 else None)
    774 )
--> 775 y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
    776     x,
    777     weight,
    778     bias,
    779     eps,
    780     residual,
    781     x1,
    782     weight1,
    783     bias1,
    784     dropout_p=dropout_p,
    785     rowscale=rowscale,
    786     residual_dtype=residual_dtype,
    787     is_rms_norm=is_rms_norm,
    788     return_dropout_mask=return_dropout_mask,
    789 )
    790 ctx.save_for_backward(
    791     residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
    792 )
    793 ctx.x_shape_og = x_shape_og

File ~/.local/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py:369, in _layer_norm_fwd(x, weight, bias, eps, residual, x1, weight1, bias1, dropout_p, rowscale, out_dtype, residual_dtype, is_rms_norm, return_dropout_mask)
    367     raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    368 with torch.cuda.device(x.device.index):
--> 369     _layer_norm_fwd_1pass_kernel[(M,)](
    370         x,
    371         y,
    372         weight,
    373         bias,
    374         residual,
    375         x1,
    376         weight1,
    377         bias1,
    378         y1,
    379         residual_out,
    380         rowscale,
    381         seeds,
    382         dropout_mask,
    383         mean,
    384         rstd,
    385         x.stride(0),
    386         y.stride(0),
    387         residual.stride(0) if residual is not None else 0,
    388         residual_out.stride(0) if residual_out is not None else 0,
    389         x1.stride(0) if x1 is not None else 0,
    390         y1.stride(0) if y1 is not None else 0,
    391         M,
    392         N,
    393         eps,
    394         dropout_p,
    395         is_rms_norm,
    396         BLOCK_N,
    397         residual is not None,
    398         residual_out is not None,
    399         bias is not None,
    400         dropout_p > 0.0,
    401         dropout_mask is not None,
    402         rowscale is not None,
    403     )
    404 # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    405 if dropout_mask is not None and x1 is not None:

File ~/.local/lib/python3.10/site-packages/triton/runtime/jit.py:345, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    339 def __getitem__(self, grid) -> T:
    340     """
    341     A JIT function is launched with: fn[grid](*args, **kwargs).
    342     Hence JITFunction.__getitem__ returns a callable proxy that
    343     memorizes the grid.
    344     """
--> 345     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py:156, in Autotuner.run(self, *args, **kwargs)
    154 pruned_configs = self.prune_configs(kwargs)
    155 bench_start = time.time()
--> 156 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
    157 bench_end = time.time()
    158 self.bench_time = bench_end - bench_start

File ~/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py:156, in <dictcomp>(.0)
    154 pruned_configs = self.prune_configs(kwargs)
    155 bench_start = time.time()
--> 156 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
    157 bench_end = time.time()
    158 self.bench_time = bench_end - bench_start

File ~/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py:133, in Autotuner._bench(self, config, *args, **meta)
    131             bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median")
    132         return bench_res
--> 133     return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
    134 except (OutOfResources, CompileTimeAssertionFailure):
    135     return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")]

File ~/.local/lib/python3.10/site-packages/triton/testing.py:103, in do_bench(fn, warmup, rep, grad_to_none, quantiles, fast_flush, return_mode)
    100 assert return_mode in ["min", "max", "mean", "median"]
    101 import torch
--> 103 fn()
    104 torch.cuda.synchronize()
    106 # We maintain a buffer of 256 MB that we clear
    107 # before each kernel call to make sure that the L2
    108 # doesn't contain any input data before the run

File ~/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py:114, in Autotuner._bench.<locals>.kernel_call()
    112 self.pre_hook(args)
    113 try:
--> 114     self.fn.run(
    115         *args,
    116         **current,
    117     )
    118 except Exception as e:
    119     try:

File ~/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py:338, in Heuristics.run(self, *args, **kwargs)
    336 for v, heur in self.values.items():
    337     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py:338, in Heuristics.run(self, *args, **kwargs)
    336 for v, heur in self.values.items():
    337     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py:338, in Heuristics.run(self, *args, **kwargs)
    336 for v, heur in self.values.items():
    337     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/triton/runtime/jit.py:662, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    660     # compile the kernel
    661     src = self.ASTSource(self, signature, constants, configs[0])
--> 662     kernel = self.compile(
    663         src,
    664         target=target,
    665         options=options.__dict__,
    666     )
    667     self.cache[device][key] = kernel
    669 # Check that used global values have not changed.

File ~/.local/lib/python3.10/site-packages/triton/compiler/compiler.py:282, in compile(src, target, options)
    280 use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1"
    281 for ext, compile_ir in list(stages.items())[first_stage:]:
--> 282     next_module = compile_ir(module, metadata)
    283     ir_filename = f"{src.name}.{ext}"
    284     metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)

File ~/.local/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py:320, in CUDABackend.add_stages.<locals>.<lambda>(src, metadata)
    318 stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
    319 stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
--> 320 stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)

File ~/.local/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py:297, in CUDABackend.make_cubin(src, metadata, opt, capability)
    295     log = log_file.read()
    296 if e.returncode == 255:
--> 297     raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}')
    298 elif e.returncode == 128 + signal.SIGSEGV:
    299     raise RuntimeError(
    300         f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}')

RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/tmpxqf2tngx.ptx, line 579; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/tmpxqf2tngx.ptx, line 579; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/tmpxqf2tngx.ptx, line 582; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/tmpxqf2tngx.ptx, line 582; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/tmpxqf2tngx.ptx, line 585; error   : Feature '.bf16' requires ......
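
The ptxas output above repeats the same two complaints for many PTX lines. As a rough sketch, the log could be summarized with a small parser like the one below. The regular expression is an assumption based only on the lines shown in this report, and `PtxasFeatureError` and `parse_ptxas_feature_errors` are hypothetical names introduced for illustration. Truncated lines, such as the last one in the log, are simply skipped.

```python
import re
from typing import List, NamedTuple


class PtxasFeatureError(NamedTuple):
    ptx_line: int      # line number inside the generated .ptx file
    feature: str       # PTX feature that ptxas rejected
    min_target: str    # minimum target ptxas says is required


# Assumption: pattern derived from the log lines quoted in this issue.
_FEATURE_ERROR = re.compile(
    r"line (\d+); error\s*: Feature '([^']+)' requires \.target (sm_\d+) or higher"
)


def parse_ptxas_feature_errors(log: str) -> List[PtxasFeatureError]:
    """Extract 'Feature ... requires .target sm_XX' errors from a ptxas log."""
    return [
        PtxasFeatureError(int(m.group(1)), m.group(2), m.group(3))
        for m in _FEATURE_ERROR.finditer(log)
    ]


# Sample input copied from the error output in this report.
SAMPLE_LOG = """\
ptxas /tmp/tmpxqf2tngx.ptx, line 579; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/tmpxqf2tngx.ptx, line 579; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/tmpxqf2tngx.ptx, line 582; error   : Feature '.bf16' requires .target sm_80 or higher
"""

errors = parse_ptxas_feature_errors(SAMPLE_LOG)
required_targets = sorted({e.min_target for e in errors})
rejected_features = sorted({e.feature for e in errors})

print("required targets:", required_targets)
print("rejected features:", rejected_features)
```

On the sample lines, every error points at bf16 support and the same minimum target, which matches the `--gpu-name=sm_61` in the failing ptxas command.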

Suggested Solutions

_No response_
@andretisch andretisch added the bug Something isn't working label Aug 21, 2024