diff --git a/torchchat/generate.py b/torchchat/generate.py index ad933687d..7f37386ac 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1192,6 +1192,8 @@ def callback(x, *, done_generating=False): max_seq_length=max_seq_length, attention_backend=self.builder_args.attention_backend, ) + if generator_args.chat_mode: + start_pos += encoded.size(0) for token_tensor, metrics in generator_func: if token_tensor is not None: start_pos += token_tensor.size(0)