Added cpu support for llama generate.py/eval.py #1307

Merged: 2 commits, Nov 20, 2024
Changes from all commits
torchao/_models/llama/eval.py (11 changes: 3 additions & 8 deletions)

@@ -10,27 +10,22 @@
 from generate import (
     _load_model,
     device_sync,
-
 )
-from torchao.quantization.quant_api import (
+from torchao.quantization import (
     quantize_,
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
     fpx_weight_only,
     uintx_weight_only,
-    unwrap_tensor_subclass,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
-from torchao.quantization.granularity import PerRow, PerTensor
-
+from torchao.quantization import PerRow, PerTensor
 from tokenizer import get_tokenizer
 import time
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass

 def run_evaluation(
     checkpoint_path: Path,
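For context, the change above consolidates eval.py's imports onto the public `torchao.quantization` surface (and moves `unwrap_tensor_subclass` to `torchao.utils`). Below is a minimal sketch of how those entry points are typically used; the toy `nn.Linear` model and the choice of `int8_weight_only` are illustrative, not taken from this PR, and assume a torchao build where that scheme is supported on the target device.

```python
import torch
from torchao.quantization import quantize_, int8_weight_only

# Toy model standing in for the Llama Transformer that eval.py loads.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).eval()

# Replace the Linear weights with int8 weight-only quantized tensors in place.
quantize_(model, int8_weight_only())

out = model(torch.randn(1, 128))
print(out.shape)
```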
torchao/_models/llama/generate.py (31 changes: 19 additions & 12 deletions)

@@ -67,7 +67,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
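The context manager being removed, `torch.backends.cuda.sdp_kernel`, is the deprecated CUDA-scoped API; `torch.nn.attention.sdpa_kernel` selects an SDPA backend in a device-agnostic way, which is what lets this decode path run on CPU. A small standalone sketch of the new API (the tensor shapes are illustrative):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Arbitrary (batch, heads, seq_len, head_dim) tensors; this works on CPU.
q = torch.randn(1, 4, 16, 32)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)

# Force the math (reference) backend, mirroring the change in decode_n_tokens.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 16, 32])
```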
@@ -347,15 +347,19 @@ def main(
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

     if memory_profile:
-        torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        if device != "cuda":
+            print("Memory profiling only works on CUDA")
+        else:
+            torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
     aggregate_metrics = {
         'tokens_per_sec': [],
     }
     start = -1 if compile else 0

     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if device == "cuda":
+                torch.cuda.reset_peak_memory_stats() # MKG
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
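The pattern this hunk introduces, guarding `torch.cuda.*` calls behind a device check so the script still runs on CPU, can be factored into small helpers. A hedged sketch follows; the helper names are not part of the PR, and `torch.cuda.max_memory_allocated` is standard PyTorch but is not used in this diff.

```python
import torch

def reset_peak_memory(device: str) -> None:
    # Peak-memory counters exist only for the CUDA allocator; no-op elsewhere.
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

def peak_memory_gb(device: str) -> float:
    # Report 0.0 on CPU instead of raising from a torch.cuda.* call.
    if device == "cuda" and torch.cuda.is_available():
        return torch.cuda.max_memory_allocated() / 1e9
    return 0.0

reset_peak_memory("cpu")       # safe: does nothing without CUDA
print(peak_memory_gb("cpu"))   # 0.0
```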
@@ -423,15 +427,18 @@ def callback(x):
             print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

         if memory_profile and i==0:
-            snapshot = torch.cuda.memory._snapshot()
-            with open(f"{memory_profile}.pickle", 'wb') as f:
-                from pickle import dump
-                dump(snapshot, f)
-            print(
-                f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
-                "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
-            )
-            break
+            if device != "cuda":
+                print("Memory profiling only works on CUDA")
+            else:
+                snapshot = torch.cuda.memory._snapshot()
+                with open(f"{memory_profile}.pickle", 'wb') as f:
+                    from pickle import dump
+                    dump(snapshot, f)
+                print(
+                    f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
+                    "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
+                )
+                break

     print("==========")
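To round out the memory-profiling path, here is a standalone sketch of the guarded snapshot dump that this last hunk performs, with the conversion step from the PR's own print message kept as a comment. The helper name and the output filename are illustrative, not part of the PR.

```python
import pickle
import torch

def dump_cuda_memory_snapshot(path: str) -> None:
    # Allocator snapshots are a CUDA-only feature, mirroring the guard in generate.py.
    if not torch.cuda.is_available():
        print("Memory profiling only works on CUDA")
        return
    snapshot = torch.cuda.memory._snapshot()
    with open(path, "wb") as f:
        pickle.dump(snapshot, f)
    print(f"saved {path}; convert it to an interactive HTML timeline with:")
    # python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html

dump_cuda_memory_snapshot("memory_profile.pickle")
```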