
float8 Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape #571

Open
vkuzo opened this issue Jul 30, 2024 · 0 comments
vkuzo commented Jul 30, 2024

from @msaroufim

I wrote a toy training loop to get something working with fp8 and ran into this padding-related issue. I worked around it by replacing a single line in my code with texts = ["Example text input 1 bla bla bla bla bla bla bla bla bla.", "Example text input 2.", "Example text input 3."], but it took me about 10 minutes to track down. I assume this is a performance-related assert for tensor cores, in which case padding makes sense.

After that change, I have a working hello-world example with the loss going down:

Epoch 1, Step 1, Loss: 8.910361289978027
Epoch 1, Step 2, Loss: 4.616391658782959
Epoch 2, Step 1, Loss: 2.377967119216919
Epoch 2, Step 2, Loss: 1.4298633337020874
Epoch 3, Step 1, Loss: 1.5666098594665527
Epoch 3, Step 2, Loss: 0.8038766384124756
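The workaround above presumably works because making the first text longer changes the padded sequence length of every batch. A more general sketch, assuming the only requirement is that the token count be a multiple of 16, pads each batch up to the next multiple of 16. `round_up` and `pad_to_multiple` are hypothetical helpers, not part of float8_experimental or transformers; Hugging Face tokenizers also accept a `pad_to_multiple_of` argument (e.g. `tokenizer(texts, padding=True, pad_to_multiple_of=16, ...)`) that does the same at tokenization time.

```python
# Hypothetical helpers (not part of float8_experimental or transformers):
# pad every sequence in a batch so the padded length is a multiple of 16.

def round_up(n: int, multiple: int = 16) -> int:
    """Return the smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def pad_to_multiple(seqs: list[list[int]], pad_id: int, multiple: int = 16):
    """Right-pad token-id lists to a common length that is a multiple of `multiple`.

    Returns (input_ids, attention_mask), both as lists of lists.
    """
    target = round_up(max(len(s) for s in seqs), multiple)
    input_ids = [s + [pad_id] * (target - len(s)) for s in seqs]
    attention_mask = [[1] * len(s) + [0] * (target - len(s)) for s in seqs]
    return input_ids, attention_mask


# Two short sequences of 3 and 2 tokens are padded to length 16, so a
# batch of any size then contributes a multiple of 16 tokens.
ids, mask = pad_to_multiple([[1, 2, 3], [4, 5]], pad_id=2)
print(len(ids[0]), len(ids[1]))  # 16 16
```

Padding the sequence dimension rather than the batch dimension keeps every batch aligned regardless of how many rows it holds.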

Error

(ao) [[email protected] ~/float8_experimental/test (main)]$ HF_TOKEN="<redacted>" python fp8.py 
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.32it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/marksaroufim/float8_experimental/test/fp8.py", line 62, in <module>
[rank0]:     loss.backward()
[rank0]:   File "/home/marksaroufim/anaconda3/envs/ao/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/home/marksaroufim/anaconda3/envs/ao/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/home/marksaroufim/anaconda3/envs/ao/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/home/marksaroufim/float8_experimental/float8_experimental/float8_tensor.py", line 297, in __torch_dispatch__
[rank0]:     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank0]:   File "/home/marksaroufim/float8_experimental/float8_experimental/float8_ops.py", line 151, in float8_mm
[rank0]:     tensor_out, amax = addmm_float8_unwrapped(
[rank0]:   File "/home/marksaroufim/float8_experimental/float8_experimental/float8_python_api.py", line 55, in addmm_float8_unwrapped
[rank0]:     output, output_amax = torch._scaled_mm(
[rank0]: RuntimeError: Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (32000x14.
[rank0]:[W612 11:03:16.562113367 ProcessGroupNCCL.cpp:1158] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
(ao) [[email protected] ~/float8_experimental/test (main)]$ 
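The message comes from `torch._scaled_mm`, the fp8 matmul that `addmm_float8_unwrapped` calls, and the stack shows it fires during `loss.backward()`. In `(32000x14`, 32000 matches Llama-2's vocabulary size and 14 is small enough to be the number of tokens in the batch, so a plausible reading (an inference from the shapes, not confirmed in the issue) is that the failing call is the weight-gradient matmul of the output projection, whose shared dimension is the token count. A hypothetical pre-flight check, covering only the constraint named in the error message, surfaces the problem with a padding hint instead of failing deep inside backward:

```python
# Hypothetical pre-flight check, mirroring only the constraint named in the
# error message ("trailing dimension of mat1 ... divisible by 16"). It is not
# a float8_experimental or PyTorch API.

def check_scaled_mm_shapes(mat1_shape, mat2_shape, multiple=16):
    """Validate (m, k) @ (k, n) shapes before an fp8 matmul; return (m, n)."""
    m, k = mat1_shape
    k2, n = mat2_shape
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    if k % multiple != 0:
        padded = ((k + multiple - 1) // multiple) * multiple
        raise ValueError(
            f"trailing dimension of mat1 is {k}, not divisible by {multiple}; "
            f"pad it to {padded}"
        )
    return m, n


# mat1 shape from the traceback; the mat2 shape (k x 4096) is an assumption,
# using Llama-2-7B's hidden size of 4096.
try:
    check_scaled_mm_shapes((32000, 14), (14, 4096))
except ValueError as err:
    print(err)
print(check_scaled_mm_shapes((32000, 16), (16, 4096)))  # (32000, 4096)
```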

Code

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda:7")

# Convert all torch.nn.Linear modules to Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Wrap model with Fully Sharded Data Parallel (FSDP)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'

dist.init_process_group(backend='nccl', init_method='env://')

# model = FSDP(model, use_orig_params=True)

# optionally compile the model
# model = torch.compile(model)

# Prepare your dataset and dataloader (customize this part as needed)
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# Example text data
texts = ["Example text input 1.", "Example text input 2.", "Example text input 3."]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Set up the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# Training loop
model.train()
for epoch in range(3):  # Loop over the dataset multiple times
    for i, batch in enumerate(dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {loss.item()}')

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")

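print("Training complete!")

# Destroy the process group before exit to avoid the ProcessGroupNCCL shutdown warning
dist.destroy_process_group()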

copied from pytorch-labs/float8_experimental#279
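In the script above, three texts with batch_size=2 produce one batch of 2 rows and a final batch of 1 row (a DataLoader keeps the partial batch by default, drop_last=False), and every row is padded to the same length. The number of tokens each Linear layer sees per step is therefore rows × padded length, which suggests padding the sequence length itself to a multiple of 16 so the smaller last batch stays aligned too. `tokens_per_batch` is a hypothetical helper, and seq_len=7 is only an illustrative value consistent with 14 tokens in a 2-row batch; the actual tokenized length is not shown in the issue.

```python
# Hypothetical helper: tokens per optimizer step for a DataLoader with the
# default drop_last=False, where every sequence is padded to seq_len.

def tokens_per_batch(num_examples: int, batch_size: int, seq_len: int) -> list[int]:
    """Return rows * seq_len for each batch, including the last partial one."""
    sizes = []
    for start in range(0, num_examples, batch_size):
        rows = min(batch_size, num_examples - start)
        sizes.append(rows * seq_len)
    return sizes


# 3 texts with batch_size=2 give batches of 2 rows and 1 row.
print(tokens_per_batch(3, 2, 7))   # [14, 7]  -> neither divisible by 16
print(tokens_per_batch(3, 2, 16))  # [32, 16] -> both divisible by 16
```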

vkuzo added the float8 label on Jul 30, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this issue Dec 9, 2024