Commit e5f8de5

Guang Yang (guangy10) authored and committed
fix seq_len dim for models using hybrid cache
1 parent a1a1968 commit e5f8de5

File tree

5 files changed: +43 -27 lines

install_dev.py
optimum/executorch/modeling.py
optimum/exporters/executorch/integrations.py
optimum/exporters/executorch/utils.py
tests/models/test_modeling_phi4.py
install_dev.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def install_dep_from_source():
             "-m",
             "pip",
             "install",
-            "git+https://github.com/huggingface/transformers@ea013348737fbd0efdefa38f9cad30443a810fd3#egg=transformers",
+            "git+https://github.com/huggingface/transformers@37367c7d9fd23401c26e79a2b26253ab2d1b7236#egg=transformers",
         ]
     )
     subprocess.check_call(
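
For reference, the script above pins transformers to a specific upstream commit. A minimal standalone sketch of the same pin, using only the hash visible in the diff (the constant name and the use of sys.executable are illustrative, not part of the commit):

import subprocess
import sys

# Pin transformers to the commit that the dev install now targets.
TRANSFORMERS_PIN = (
    "git+https://github.com/huggingface/transformers"
    "@37367c7d9fd23401c26e79a2b26253ab2d1b7236#egg=transformers"
)
subprocess.check_call([sys.executable, "-m", "pip", "install", TRANSFORMERS_PIN])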

optimum/executorch/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -687,18 +687,18 @@ def generate(
                 cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
             )
             self.stats.on_sampling_end()
+            next_token = torch.argmax(logits, dim=-1).item()
         else:
             self.stats.on_sampling_begin()
             logits = self.forward(
                 input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
                 cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
             )
             self.stats.on_sampling_end()
-
+            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         self.stats.on_prompt_eval_end()
         first_token_generated = False

-        next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         generated_tokens = prompt_tokens + [next_token]

         while len(generated_tokens) < max_seq_len:
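
Moving the argmax into each branch reflects the different logits shapes of the two prefill paths: the per-token path produces logits for a single position, while the parallel-prefill path produces logits for every prompt position and only the last one seeds generation. A small sketch of that shape assumption (shapes here are illustrative, not read from the runtime):

import torch

vocab_size = 32000  # illustrative

# Parallel prefill: logits for the whole prompt, shape [batch=1, prompt_len, vocab].
prefill_logits = torch.randn(1, 5, vocab_size)
next_from_prefill = torch.argmax(prefill_logits, dim=-1)[0, -1].item()  # keep only the last position

# Per-token prefill / decode step: logits for a single position, shape [1, vocab].
step_logits = torch.randn(1, vocab_size)
next_from_decode = torch.argmax(step_logits, dim=-1).item()  # already a single position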

optimum/exporters/executorch/integrations.py

Lines changed: 14 additions & 3 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 from typing import Dict, Optional

 import torch
@@ -44,6 +45,7 @@ def __init__(self, model, use_custom_kv_cache=False):
         self.config = model.config
         self.use_custom_kv_cache = use_custom_kv_cache
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
+        logging.info(f"Metadata to be recorded in PTE: {self.metadata}")

     def export(
         self,
@@ -58,21 +60,30 @@ def export(
         )

         max_batch_size = 1
-        max_cache_len = 4094
         seq_length = 3  # Make the sequence length dim dynamic in order to leverage parallel prefill in the ExecuTorch runtime.
         example_input_ids = input_ids if input_ids is not None else torch.zeros((1, seq_length), dtype=torch.long)
         example_cache_position = (
             cache_position if cache_position is not None else torch.arange(seq_length, dtype=torch.long)
         )
         seq_len_dim = torch.export.Dim(
-            "seq_length_dim", max=min(self.metadata["get_max_seq_len"], max_cache_len) - 1
+            "seq_length_dim",
+            max=min(
+                self.metadata.get("get_max_seq_len"),
+                self.metadata.get("sliding_window", float("inf")),
+            )
+            - 1,
         )
         dynamic_shapes = {"input_ids": {1: seq_len_dim}, "cache_position": {0: seq_len_dim}}
         strict = parse(torch.__version__) != parse(
             "2.7.0"
         )  # Due to bug https://github.com/pytorch/pytorch/issues/150994

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            self.model,
+            max_batch_size=max_batch_size,
+            max_cache_len=self.metadata.get("get_max_seq_len"),
+        )
+
         if self.use_custom_kv_cache:
             from optimum.executorch.attentions.custom_kv_cache import (
                 replace_with_et_custom_kv_cache,
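
This hunk is the core of the fix for hybrid-cache models: when a model uses sliding-window attention, the static cache for those layers only holds sliding_window entries, so the dynamic prompt-length dimension must be capped by the window as well as by the model's maximum sequence length. A minimal sketch of the resulting bound (the helper name is illustrative, and it assumes metadata entries whose value is None are not present for models without a sliding window):

import torch

def prefill_seq_len_dim(metadata):
    # Dynamic prompt-length bound: the model's max sequence length, additionally
    # capped by the sliding window for hybrid-cache models. float("inf") leaves
    # non-hybrid models bounded only by max_seq_len.
    bound = min(
        metadata["get_max_seq_len"],
        metadata.get("sliding_window", float("inf")),
    )
    return torch.export.Dim("seq_length_dim", max=bound - 1)

# Illustrative hybrid-cache metadata: 4096 max positions, 512-token sliding window.
dim = prefill_seq_len_dim({"get_max_seq_len": 4096, "sliding_window": 512})  # max = 511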

optimum/exporters/executorch/utils.py

Lines changed: 2 additions & 4 deletions
@@ -42,16 +42,14 @@ def save_config_to_constant_methods(
         "get_vocab_size": getattr(config, "vocab_size", None),
         "get_max_batch_size": 1,
         "get_max_seq_len": getattr(config, "max_position_embeddings", None),
+        "use_kv_cache": getattr(generation_config, "use_cache", None),
+        "sliding_window": getattr(config, "sliding_window", None),
         "decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
         "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
     }

     # Safely access fields from generation_config if it exists
     if generation_config is not None:
-        # Get use_cache with default value
-        use_cache = getattr(generation_config, "use_cache", None)
-        metadata["use_kv_cache"] = use_cache
-
         # Check for cache_config and its attributes
         cache_config = getattr(generation_config, "cache_config", None)
         if cache_config is not None:
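
With use_kv_cache and sliding_window now recorded up front via getattr, models that do not define a sliding window simply yield None for that entry. A tiny sketch of the pattern with stand-in configs (the SimpleNamespace objects below are placeholders, not the real transformers configs):

from types import SimpleNamespace

# Stand-in configs: a hybrid-cache model with a sliding window, and one without.
hybrid_cfg = SimpleNamespace(max_position_embeddings=4096, sliding_window=512)
plain_cfg = SimpleNamespace(max_position_embeddings=4096)

print(getattr(hybrid_cfg, "sliding_window", None))  # 512
print(getattr(plain_cfg, "sliding_window", None))   # None: nothing to cap the seq_len dim with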

tests/models/test_modeling_phi4.py

Lines changed: 24 additions & 17 deletions
@@ -20,7 +20,7 @@

 import pytest
 import torchao
-from executorch import version as executorch_version
+import transformers
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from packaging.version import parse
 from transformers import AutoConfig, AutoTokenizer
@@ -43,41 +43,49 @@ def __init__(self, *args, **kwargs):
     @slow
     @pytest.mark.run_slow
     @pytest.mark.skipif(
-        is_ci,
-        reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM",
+        parse(transformers.__version__) < parse("4.52.0") or parse(torchao.__version__) < parse("0.11.0"),
+        reason="Only available on transformers >= 4.52.0 and torchao >= 0.11.0",
     )
-    def test_phi4_text_generation(self):
+    def test_phi4_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
         model_id = "microsoft/Phi-4-mini-instruct"
         config = AutoConfig.from_pretrained(model_id)
         # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
         # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
         # that function to avoid the data-dependent control flow.
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             config.rope_scaling["type"] = "default"
-        model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack", config=config)
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            config=config,
+            attn_implementation="custom_sdpa",
+            use_custom_kv_cache=True,
+            **{"qlinear": True, "qembedding": True},
+        )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

         tokenizer = AutoTokenizer.from_pretrained(model_id)
         generated_text = model.text_generation(
             tokenizer=tokenizer,
             prompt="My favourite condiment is ",
-            max_seq_len=32,
+            max_seq_len=64,
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
-        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

-        # Free memory before loading eager for quality check
-        del model
-        del tokenizer
-        gc.collect()
+        if not is_ci:
+            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+            # Free memory before loading eager for quality check
+            del model
+            del tokenizer
+            gc.collect()

-        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+            self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

     @slow
     @pytest.mark.run_slow
-    @pytest.mark.skipif(
-        parse(executorch_version.__version__) > parse("0.6.0"),
+    @pytest.mark.skip(
         reason="Require cache_position support in executorch runtime. Re-enable when available.",
     )
     def test_phi4_text_generation_with_quantized_pte_from_hub(self):
@@ -119,9 +127,8 @@ def test_phi4_text_generation_with_quantized_pte_from_hub(self):

     @slow
     @pytest.mark.run_slow
-    @pytest.mark.skipif(
-        parse(torchao.__version__) < parse("0.11.0.dev0"),
-        reason="Only available on torchao >= 0.11.0.dev0",
+    @pytest.mark.skip(
+        reason="Require cache_position support in executorch runtime. Re-enable when available.",
     )
     def test_phi4_text_generation_with_quantized_ckp(self):
         model_id = "pytorch/Phi-4-mini-instruct-8da4w"
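
The renamed test now exercises export with custom SDPA, the custom KV cache, and 8-bit dynamic-activation / 4-bit-weight linear plus 8-bit-weight embedding quantization (the 8da4w_8we suffix). A condensed standalone sketch of the same flow, with arguments taken from the test body (prompt and max_seq_len values are illustrative, and the import path is assumed to match the package's public API):

from transformers import AutoConfig, AutoTokenizer
from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "microsoft/Phi-4-mini-instruct"
config = AutoConfig.from_pretrained(model_id)
if getattr(config, "rope_scaling", None) is not None:
    config.rope_scaling["type"] = "default"  # avoid data-dependent control flow during export

model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    config=config,
    attn_implementation="custom_sdpa",
    use_custom_kv_cache=True,
    qlinear=True,
    qembedding=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(model.text_generation(tokenizer=tokenizer, prompt="My favourite condiment is ", max_seq_len=64))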
