
[Bug] Lookahead decoding is nondeterministic and wrong after the first call to runner.generate #2263

Open
3 of 4 tasks
tloen opened this issue Sep 27, 2024 · 2 comments
Labels
bug Something isn't working triaged Issue has been triaged by maintainers

Comments


tloen commented Sep 27, 2024

System Info

  • x86_64
  • 2TB RAM
  • 8xH100
  • TensorRT-LLM main @ 40274aa
  • CUDA 12.5

Who can help?

@kaiyux @byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

In examples/run.py, wrap the call to runner.generate and the output printing in a for loop:

for _ in range(3): # THIS IS THE ONLY CHANGE
        with torch.no_grad():
            outputs = runner.generate(
                batch_input_ids=decoder_input_ids
                if is_enc_dec else batch_input_ids,
                encoder_input_ids=encoder_input_ids if is_enc_dec else None,
                encoder_input_features=encoder_input_features
                if is_enc_dec else None,
                encoder_output_lengths=encoder_output_lengths
                if is_enc_dec else None,
                max_new_tokens=args.max_output_len,
                max_attention_window_size=args.max_attention_window_size,
                sink_token_length=args.sink_token_length,
                end_id=end_id,
                pad_id=pad_id,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                num_beams=args.num_beams,
                length_penalty=args.length_penalty,
                early_stopping=args.early_stopping,
                repetition_penalty=args.repetition_penalty,
                presence_penalty=args.presence_penalty,
                frequency_penalty=args.frequency_penalty,
                stop_words_list=stop_words_list,
                bad_words_list=bad_words_list,
                output_cum_log_probs=(args.output_cum_log_probs_npy != None),
                output_log_probs=(args.output_log_probs_npy != None),
                random_seed=args.random_seed,
                lora_uids=args.lora_task_uids,
                prompt_table=args.prompt_table_path,
                prompt_tasks=args.prompt_tasks,
                streaming=args.streaming,
                output_sequence_lengths=True,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
                return_dict=True,
                medusa_choices=args.medusa_choices,
                return_all_generated_tokens=args.return_all_generated_tokens,
                input_token_extra_ids=input_token_extra_ids)
            torch.cuda.synchronize()

        if args.streaming:
            for curr_outputs in throttle_generator(outputs,
                                                   args.streaming_interval):
                if runtime_rank == 0:
                    output_ids = curr_outputs['output_ids']
                    sequence_lengths = curr_outputs['sequence_lengths']
                    cum_log_probs = None
                    log_probs = None
                    if args.output_cum_log_probs_npy != None:
                        cum_log_probs = outputs['cum_log_probs']
                    if args.output_log_probs_npy != None:
                        log_probs = outputs['log_probs']
                    print_output(
                        tokenizer,
                        output_ids,
                        input_lengths,
                        sequence_lengths,
                        output_csv=args.output_csv,
                        output_npy=args.output_npy,
                        cum_log_probs=cum_log_probs,
                        log_probs=log_probs,
                        output_cum_log_probs_npy=args.output_cum_log_probs_npy,
                        output_log_probs_npy=args.output_log_probs_npy)
        else:
            if runtime_rank == 0:
                output_ids = outputs['output_ids']
                sequence_lengths = outputs['sequence_lengths']
                context_logits = None
                generation_logits = None
                cum_log_probs = None
                log_probs = None
                if runner.gather_context_logits:
                    context_logits = outputs['context_logits']
                if runner.gather_generation_logits:
                    generation_logits = outputs['generation_logits']
                if args.output_cum_log_probs_npy != None:
                    cum_log_probs = outputs['cum_log_probs']
                if args.output_log_probs_npy != None:
                    log_probs = outputs['log_probs']
                print_output(tokenizer,
                             output_ids,
                             input_lengths,
                             sequence_lengths,
                             output_csv=args.output_csv,
                             output_npy=args.output_npy,
                             context_logits=context_logits,
                             generation_logits=generation_logits,
                             output_logits_npy=args.output_logits_npy,
                             cum_log_probs=cum_log_probs,
                             log_probs=log_probs,
                             output_cum_log_probs_npy=args.output_cum_log_probs_npy,
                             output_log_probs_npy=args.output_log_probs_npy)

Then run:

python run.py \
    --max_output_len=50 \
    --lookahead_config='[2,2,1]' \
    --tokenizer_dir=[DIR] \
    --engine_dir=[DIR]

Expected behavior

Input [Text 0]: "<|begin▁of▁sentence|>You are a diligent AI assistant that follows commands exactly.
### Instruction:
Please say "1" a thousand times.
### Response:
1, 1, 1, 1, 1,"
Output [Text 0 Beam 0]: " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1"
Input [Text 0]: "<|begin▁of▁sentence|>You are a diligent AI assistant that follows commands exactly.
### Instruction:
Please say "1" a thousand times.
### Response:
1, 1, 1, 1, 1,"
Output [Text 0 Beam 0]: " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1"
Input [Text 0]: "<|begin▁of▁sentence|>You are a diligent AI assistant that follows commands exactly.
### Instruction:
Please say "1" a thousand times.
### Response:
1, 1, 1, 1, 1,"
Output [Text 0 Beam 0]: " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1"

Actual behavior

Output is nondeterministic and incorrect from the second iteration onward.

Input [Text 0]: "<|begin▁of▁sentence|>You are a diligent AI assistant that follows commands exactly.
### Instruction:
Please say "1" a thousand times.
### Response:
1, 1, 1, 1, 1,"
Output [Text 0 Beam 0]: " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1"
Input [Text 0]: "<|begin▁of▁sentence|>You are a diligent AI assistant that follows commands exactly.
### Instruction:
Please say "1" a thousand times.
### Response:
1, 1, 1, 1, 1,"
Output [Text 0 Beam 0]: " 1, 1, 1, 1, 11 1111111111111111111111111111111111"
Input [Text 0]: "<|begin▁of▁sentence|>You are a diligent AI assistant that follows commands exactly.
### Instruction:
Please say "1" a thousand times.
### Response:
1, 1, 1, 1, 1,"
Output [Text 0 Beam 0]: " 1, 1, 1, 1, 1111111111111111111111111111111111111"
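Comparing repeated runs by eye is error-prone. The following is a minimal sketch for locating where each later run departs from the first one, assuming the outputs of every iteration are collected into a list (decoded strings, or token-ID lists taken from `outputs['output_ids']`). `first_divergence` and `check_repeatable` are hypothetical helpers written for illustration, not TensorRT-LLM APIs.

```python
from typing import List, Optional, Sequence


def first_divergence(a: Sequence, b: Sequence) -> Optional[int]:
    """Return the first index where a and b differ, or None if they are identical."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        # One sequence is a strict prefix of the other.
        return min(len(a), len(b))
    return None


def check_repeatable(runs: List[Sequence]) -> List[Optional[int]]:
    """Compare every run against the first one.

    For each run after the first, returns the index of the first element that
    differs from run 0; None means that run matched run 0 exactly.
    """
    reference = runs[0]
    return [first_divergence(reference, run) for run in runs[1:]]


if __name__ == "__main__":
    # Decoded outputs of the three iterations, copied from the report above.
    runs = [
        " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1",
        " 1, 1, 1, 1, 11 1111111111111111111111111111111111",
        " 1, 1, 1, 1, 1111111111111111111111111111111111111",
    ]
    print(check_repeatable(runs))  # [14, 14]
```

For a deterministic decoding setup with identical inputs, every entry should be None. Here both later runs diverge from the first at character index 14, where the ", " separator after the fifth "1" is replaced by another "1".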

Additional notes

The model uses the Llama architecture.
max_draft_len is 107.
The error does not occur when the number of verification branches is zero or the window size is 1.
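For reference, the `--lookahead_config='[2,2,1]'` value in the reproduction is assumed here to follow the `[W, N, G]` ordering of TensorRT-LLM's lookahead example (window size, n-gram size, verification set size). Under that assumption, the sketch below parses the flag and checks it against the two conditions reported as unaffected (G == 0 or W == 1). `LookaheadConfig` and `in_failing_regime` are hypothetical names used only for illustration.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class LookaheadConfig:
    # Assumed ordering [W, N, G], following TensorRT-LLM's lookahead example.
    window_size: int        # W
    ngram_size: int         # N
    verification_size: int  # G (number of verification branches)

    @classmethod
    def parse(cls, flag: str) -> "LookaheadConfig":
        """Parse a string such as '[2,2,1]' into its three components."""
        values = json.loads(flag)
        if len(values) != 3 or not all(isinstance(v, int) for v in values):
            raise ValueError(f"expected three integers, got {flag!r}")
        return cls(*values)


def in_failing_regime(cfg: LookaheadConfig) -> bool:
    """True if the config lies outside the two cases observed to work
    (zero verification branches, or a window size of 1)."""
    return cfg.verification_size > 0 and cfg.window_size > 1
```

With the reproduction's `[2,2,1]`, `in_failing_regime` returns True, while `[1,2,1]` and `[2,2,0]` fall into the cases the report describes as unaffected.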

@tloen tloen added the bug Something isn't working label Sep 27, 2024
@DanBlanaru DanBlanaru added the triaged Issue has been triaged by maintainers label Oct 4, 2024

davidmlw commented Oct 8, 2024

Thank you very much! The bug was fixed recently, and the fix will be released soon.

kaiyux (Member) commented Oct 15, 2024

Hi @tloen, the issue should be addressed by this PR. Can you please try it and see whether it solves the problem? Feel free to let us know if you have any more questions, thanks!
