test_llama_mlp.py
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import os
import ttnn

from models.demos.llama3.tt.llama_mlp import TtLlamaMLP
from models.demos.llama3.tt.model_config import TtModelArgs
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward
from models.utility_functions import (
    comp_pcc,
    comp_allclose,
)
from models.utility_functions import skip_for_grayskull
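

# Compares the TT-NN Llama MLP module (TtLlamaMLP) against the reference
# FeedForward implementation for a single layer, in decode and prefill modes,
# and checks that the outputs meet the PCC requirement.
# Typical invocation (exact test path may differ in your checkout):
#   FAKE_DEVICE=N150 pytest <path/to>/test_llama_mlp.py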
@torch.no_grad()
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "mesh_device",
    [
        # Mesh shape is chosen from the FAKE_DEVICE env var; if it is unset
        # (or unrecognized), the number of available devices is used instead.
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "seq_len",
    (
        64 * 1024,
        32 * 1024,
        32,
    ),
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
def test_llama_mlp_inference(seq_len, batch_size, mesh_device, use_program_cache, reset_seeds, ensure_gc):
    dtype = ttnn.bfloat8_b
    # Sequences of 32 tokens or fewer exercise decode mode; longer sequences run prefill.
    mode = "decode" if seq_len <= 32 else "prefill"

    mesh_device.enable_async(True)

    model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128)
    model_args.n_layers = 1  # only a single layer is needed for this unit test
    state_dict = model_args.load_state_dict()

    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    first_layer_prefix = model_args.get_state_dict_prefix("TtLlamaMLP", 0)
    partial_state_dict = {
        k[len(first_layer_prefix) + 1 :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
    }
    model_args.WEIGHTS_DTYPE = dtype

    # Reference PyTorch MLP, loaded with the same weights as the TT model.
    reference_model = FeedForward(
        dim=model_args.dim,
        hidden_dim=4 * model_args.dim,
        multiple_of=model_args.multiple_of,
        ffn_dim_multiplier=model_args.ffn_dim_multiplier,
    )
    reference_model.load_state_dict(partial_state_dict)

    # TT-NN MLP under test.
    tt_model = TtLlamaMLP(
        mesh_device=mesh_device,
        args=model_args,
        state_dict=state_dict,
        weight_cache_path=model_args.weight_cache_path(dtype),
        layer_num=0,
        dtype=dtype,
        model_config=model_args.get_model_config(),
    )
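
    # The same random activations are fed to both the reference and the TT model.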
    torch_input = torch.randn(1, 1, seq_len, model_args.dim)
    reference_output = reference_model(torch_input)
    tt_input = ttnn.from_torch(
        torch_input,
        device=mesh_device,
        mesh_mapper=ttnn.ShardTensor2dMesh(
            mesh_device, dims=(None, 3) if model_args.is_galaxy else (None, None), mesh_shape=model_args.cluster_shape
        ),  # When both dims are None, the mapper used is `ReplicateTensorToMesh`
        dtype=ttnn.bfloat8_b,
        memory_config=(
            tt_model.model_config["MLP_ACT_MEMCFG"]
            if model_args.is_galaxy
            else model_args.model_config["SHARDED_MLP_INPUT_MEMCFG"]
        )
        if mode == "decode"
        else ttnn.DRAM_MEMORY_CONFIG,
        layout=ttnn.TILE_LAYOUT,
    )
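
    # Run the TT-NN MLP in the selected mode (decode or prefill).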
    logger.info("Run Llama_MLP")
    tt_output = tt_model(tt_input, mode)

    tt_output_torch = ttnn.to_torch(
        tt_output,
        mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape),
    )

    # Only the first slice along dim 1 of the composed output is compared against the reference.
    tt_output_torch = tt_output_torch[:, :1, :, :]
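
    # Pass/fail is based on the Pearson correlation coefficient (PCC) between
    # the reference output and the TT output.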
    pcc_required = 0.99
    passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output_torch))
    logger.info(f"PCC: {pcc_message}")
    if passing:
        logger.info("Llama_MLP Passed!")
    else:
        logger.warning("Llama_MLP Failed!")

    assert passing, f"Llama_MLP output does not meet PCC requirement {pcc_required}: {pcc_message}."