diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ee548bc1c8..cdfdd614e9 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -462,7 +462,7 @@ def generate(size=None, reduce_recompile=False): def compute_valid_sequence_lengths_tensor(input_tokens): attn_mask = input_tokens["attention_mask"] - return torch.sum(attn_mask, dim=1) + return torch.sum(attn_mask, dim=1, dtype=torch.int32) valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device) generation_config.valid_sequence_lengths = valid_sequence_lengths