Added cpu support for llama generate.py/eval.py (pytorch#1307)
jainapurva authored and sunjiweiswift committed Nov 25, 2024
1 parent d0f5538 commit 956b8a1
Showing 2 changed files with 9 additions and 11 deletions.
11 changes: 3 additions & 8 deletions torchao/_models/llama/eval.py
@@ -10,27 +10,22 @@
 from generate import (
     _load_model,
     device_sync,
-
 )
-from torchao.quantization.quant_api import (
+from torchao.quantization import (
     quantize_,
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
     fpx_weight_only,
     uintx_weight_only,
-    unwrap_tensor_subclass,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
-from torchao.quantization.granularity import PerRow, PerTensor
-
+from torchao.quantization import PerRow, PerTensor
 from tokenizer import get_tokenizer
 import time
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass

 def run_evaluation(
     checkpoint_path: Path,
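
For reference, the import layout eval.py ends up with after this change, collected in one place (a sketch assembled from the added lines above, not the file's full import block; it assumes a torchao build that exports these names from torchao.quantization and torchao.utils):

    from torchao.quantization import (
        quantize_,
        int8_weight_only,
        PerRow,
        PerTensor,
    )
    from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
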
9 changes: 6 additions & 3 deletions torchao/_models/llama/generate.py
@@ -69,7 +69,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
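
The removed line uses torch.backends.cuda.sdp_kernel, which lives in the CUDA backend namespace and is deprecated; torch.nn.attention.sdpa_kernel selects a scaled-dot-product-attention backend in a device-agnostic way, so forcing the math backend also works on CPU. A minimal standalone sketch of the replacement API (assumes PyTorch 2.3+, where torch.nn.attention.sdpa_kernel is available; tensor shapes are illustrative only):

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # (batch, num_heads, seq_len, head_dim) -- arbitrary example sizes
    q = k = v = torch.randn(1, 4, 16, 8)
    with sdpa_kernel(SDPBackend.MATH):  # restrict SDPA to the math backend
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
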
@@ -351,6 +351,8 @@ def main(
             torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
         elif device == "xpu":
             torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        else:
+            print("Memory profiling only works on CUDA or XPU devices")
     aggregate_metrics = {
         'tokens_per_sec': [],
     }
@@ -359,7 +361,7 @@
     for i in range(start, num_samples):
         if i==0:
             if device == "cuda":
-                torch.cuda.reset_peak_memory_stats()
+                torch.cuda.reset_peak_memory_stats() # MKG
             elif device == "xpu":
                 torch.xpu.reset_peak_memory_stats()
         device_sync(device=device) # MKG
@@ -433,6 +435,8 @@ def callback(x):
                 snapshot = torch.cuda.memory._snapshot()
             elif device == "xpu":
                 snapshot = torch.xpu.memory._snapshot()
+            else:
+                print("Memory profiling only works on CUDA or XPU devices")
             with open(f"{memory_profile}.pickle", 'wb') as f:
                 from pickle import dump
                 dump(snapshot, f)
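
Both new else branches follow the same pattern: the allocator-snapshot APIs only exist for CUDA and XPU, so other devices (including CPU) just get a warning. A small self-contained sketch of that pattern (the capture_memory_snapshot helper name is hypothetical; the calls mirror the diff above):

    import torch
    from pickle import dump

    def capture_memory_snapshot(device: str, path: str) -> None:
        # Snapshots are only available for the CUDA and XPU allocators.
        if device == "cuda":
            snapshot = torch.cuda.memory._snapshot()
        elif device == "xpu":
            snapshot = torch.xpu.memory._snapshot()
        else:
            print("Memory profiling only works on CUDA or XPU devices")
            return
        with open(path, "wb") as f:
            dump(snapshot, f)
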
@@ -441,7 +445,6 @@
                 "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
             )
             break
-
     print("==========")

     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
