diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ab1f9615b3..0c1ae6846e 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -610,7 +610,7 @@ def compute_valid_sequence_lengths_tensor(input_tokens): attn_mask = input_tokens['attention_mask'] return torch.sum(attn_mask, dim=1) valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device) - setattr(generation_config, 'valid_sequence_lengths', valid_sequence_lengths) + generation_config.valid_sequence_lengths = valid_sequence_lengths else: input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True) encode_duration = time.perf_counter() - encode_t0 diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 0a860f9bc9..ef8bccbec2 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -601,6 +601,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer): generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax generation_config.trust_remote_code = args.trust_remote_code + generation_config.valid_sequence_lengths = None return generation_config