diff --git a/examples/run_qwen2_5_vl_eagle3_evaluation.sh b/examples/run_qwen2_5_vl_eagle3_evaluation.sh new file mode 100644 index 000000000..4a5f75085 --- /dev/null +++ b/examples/run_qwen2_5_vl_eagle3_evaluation.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# support tp1 evaluate eagle3 for qwen2.5-vl-7b-instruct +NUM_GPUS=${1:-1} +CHECKPOINT_PATH=${CHECKPOINT_PATH:?set CHECKPOINT_PATH to the trained draft checkpoint} + +torchrun \ --standalone \ --nproc_per_node $NUM_GPUS \ $ROOT_DIR/scripts/eval_eagle3.py \ --target-model-path Qwen/Qwen2.5-VL-7B-Instruct \ --draft-model-config $ROOT_DIR/configs/qwen2-5-vl-7b-eagle3.json \ --checkpoint-path $CHECKPOINT_PATH \ --eval-data-path $ROOT_DIR/cache/dataset/allava4v_train.jsonl \ --max-length 8192 \ --dist-timeout 360 \ --chat-template qwen2-vl \ --attention-backend sdpa \ --cache-dir $ROOT_DIR/cache \ --embedding-key model.embed_tokens.weight \ --tp-size 1 \ --batch-size 1 \ --is-vlm \ --min-pixels 50176 \ --max-pixels 802816 \ --verbose \ No newline at end of file diff --git a/scripts/eval_eagle3.py b/scripts/eval_eagle3.py new file mode 100644 index 000000000..9ad4a8c5e --- /dev/null +++ b/scripts/eval_eagle3.py @@ -0,0 +1,298 @@ +import argparse +import hashlib +import math +import os +import time +from collections import defaultdict + +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from datasets import load_dataset +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer + +from specforge import ( + AutoDraftModelConfig, + AutoEagle3DraftModel, + OnlineEagle3Model, + QwenVLOnlineEagle3Model, +) +from specforge.data import ( + build_eagle3_dataset, + generate_vocab_mapping_file, + 
prepare_dp_dataloaders, +) +from specforge.distributed import ( + destroy_distributed, + get_dp_group, + get_tp_device_mesh, + init_distributed, +) +from specforge.evaluator import ( + QwenVLEagle3Evaluator +) +from specforge.optimizer import BF16Optimizer +from specforge.tracker import create_tracker, get_tracker_class +from specforge.utils import ( + create_draft_config_from_target, + get_last_checkpoint, + print_on_rank0, + print_with_rank, + rank_0_priority, +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Evaluate Eagle3 model") + + # add model-related arguments + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument( + "--draft-model-config", + type=str, + required=False, + help="Draft model config path. If not provided, will auto-generate from target model.", + ) + parser.add_argument( + "--embedding-key", + type=str, + default="model.embed_tokens.weight", + help="The key of the embedding weight to load from the target model", + ) + parser.add_argument( + "--is-vlm", action="store_true", help="Whether the target model is a VLM" + ) + + # add evaluation-related arguments + parser.add_argument("--eval-data-path", type=str, required=True) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--max-length", type=int, default=2048) + parser.add_argument("--ttt-length", type=int, default=7) + + # data processing type + parser.add_argument("--chat-template", type=str, default="llama3") + parser.add_argument( + "--is-preformatted", + action="store_true", + help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.", + ) + + # distributed evaluation + parser.add_argument("--tp-size", type=int, default=1) + parser.add_argument("--dp-size", type=int, default=1) + + # other args + parser.add_argument("--cache-key", type=str, default=None) + parser.add_argument("--cache-dir", type=str, default="./cache") + 
parser.add_argument("--checkpoint-path", type=str, required=True, help="Path to the trained model checkpoint") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--dist-timeout", + type=int, + default=20, + help="Timeout for collective communication in minutes", + ) + parser.add_argument("--attention-backend", type=str, default="flex_attention") + + # vlm related args + parser.add_argument( + "--min-pixels", type=int, default=50176 + ) # 64 * 28 * 28 for qwen2.5-vl + parser.add_argument( + "--max-pixels", type=int, default=802816 + ) # 1024 * 28 * 28 for qwen2.5-vl + + parser.add_argument("--build-dataset-num-proc", type=int, default=8) + parser.add_argument("--verbose", action="store_true") + + args = parser.parse_args() + return parser, args + + +def main(): + # initialize + parser, args = parse_args() + set_seed(args.seed) + init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_with_rank("Initialized distributed environment") + args.dp_size = dist.get_world_size() // args.tp_size + + # Handle draft model config + if args.draft_model_config is None: + # Auto-generate and save config file + auto_config_path = create_draft_config_from_target( + target_model_path=args.target_model_path, cache_dir=args.cache_dir + ) + draft_model_config = AutoDraftModelConfig.from_file(auto_config_path) + else: + # Use provided config file + draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) + + # build target model + if args.is_vlm and draft_model_config.target_model_type == "qwen2_5_vl": + from transformers import Qwen2_5_VLForConditionalGeneration + + target_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + ) + .eval() + .cuda() + ) + else: + target_model = ( + AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + 
cache_dir=args.cache_dir, + ) + .eval() + .cuda() + ) + + for p in target_model.parameters(): + p.requires_grad = False + + print_with_rank("Initialized target model") + + # load trained draft model + draft_model = AutoEagle3DraftModel.from_pretrained( + args.checkpoint_path, + attention_backend=args.attention_backend, + torch_dtype=torch.bfloat16, + ).cuda() + + # load embedding + draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) + draft_model.freeze_embedding() + print_with_rank("Loaded trained draft model") + + # build dataloaders + tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + if args.is_vlm: + processor = AutoProcessor.from_pretrained( + args.target_model_path, + min_pixels=args.min_pixels, + max_pixels=args.max_pixels, + ) + else: + processor = None + + # convert to dataloader + cache_params_string = ( + f"{args.eval_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}" + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"] + + with rank_0_priority(): + eval_eagle3_dataset = build_eagle3_dataset( + dataset=eval_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + is_vlm=args.is_vlm, + is_preformatted=args.is_preformatted, + processor=processor, + num_proc=args.build_dataset_num_proc, + ) + + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.batch_size, + num_workers=4, + shuffle=False, + process_group=get_dp_group(), + is_vlm=args.is_vlm, + ) + print_with_rank("Initialized eval dataloader") + + # build Eagle3 model + if args.is_vlm and draft_model_config.target_model_type == "qwen2_5_vl": + eagle3_model = QwenVLOnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + processor=processor, + 
length=args.ttt_length, + attention_backend=args.attention_backend, + ) + else: + eagle3_model = OnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + + # run evaluation + draft_model.eval() + + evaluator = QwenVLEagle3Evaluator(eagle3_model) + + eval_accept_length = [] + + if dist.get_rank() == 0: + progress_bar = tqdm(eval_dataloader, desc="Evaluating", leave=True) + else: + progress_bar = eval_dataloader + + last_time = time.time() + total_samples = 0 + + for data in progress_bar: + if args.is_vlm: + with torch.no_grad(): + accept_length = evaluator.evaluation( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + pixel_values=data["pixel_values"].cuda(), + image_grid_thw=data["image_grid_thw"].cuda(), + ) + + eval_accept_length.append(accept_length) + total_samples += data["input_ids"].shape[0] + + if args.verbose: + print(f"[{dist.get_rank()}] time={(time.time() - last_time):.3f}s shape={data['input_ids'].shape}") + last_time = time.time() + + if dist.get_rank() == 0: + avg_accept = sum(eval_accept_length) / len(eval_accept_length) + progress_bar.set_postfix({"accept_len": f"{avg_accept:.2f}", "samples": total_samples}) + + # Synchronize and collect results from all devices + local_accept_tensor = torch.tensor(eval_accept_length, dtype=torch.float32, device="cuda") + total_accept_tensor = torch.zeros_like(local_accept_tensor) + dist.all_reduce(local_accept_tensor, op=dist.ReduceOp.SUM) + + total_samples_tensor = torch.tensor(total_samples, dtype=torch.float32, device="cuda") + dist.all_reduce(total_samples_tensor, op=dist.ReduceOp.SUM) + + # calculate accept length + overall_accept_length = (local_accept_tensor.sum() / total_samples_tensor).item() + + # Only rank 0 prints the results + if dist.get_rank() == 0: + print("\n" + "="*70) + print("EVALUATION RESULTS — Accept Length") + print("="*70) 
+ print(f"Average Accept Length (across all samples & devices): {overall_accept_length:.4f}") + print(f"Total samples evaluated: {int(total_samples_tensor.item())}") + print("="*70) + + destroy_distributed() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index b6fcfe949..49012a321 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -163,7 +163,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - use_cache=True, + use_cache=False, ) # update hidden states for next step @@ -537,7 +537,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - use_cache=True, + use_cache=False, ) # update hidden states for next step @@ -608,4 +608,4 @@ def _compute_target_p(target, t2d, loss_mask): def _compute_metric_acc(logits, target_p, position_mask, loss_mask): return ( (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1) - ).sum() / loss_mask.sum().clamp_min(1e-6) + ).sum() / loss_mask.sum().clamp_min(1e-6) \ No newline at end of file diff --git a/specforge/evaluator/__init__.py b/specforge/evaluator/__init__.py new file mode 100644 index 000000000..59244c5a2 --- /dev/null +++ b/specforge/evaluator/__init__.py @@ -0,0 +1 @@ +from .qwenvl_evaluator import QwenVLEagle3Evaluator \ No newline at end of file diff --git a/specforge/evaluator/qwenvl_evaluator.py b/specforge/evaluator/qwenvl_evaluator.py new file mode 100644 index 000000000..d3186e529 --- /dev/null +++ b/specforge/evaluator/qwenvl_evaluator.py @@ -0,0 +1,287 @@ +import torch +from typing import List, Optional, Tuple + +from specforge.core.eagle3 import Eagle3Model +from specforge.modeling.draft.llama3_eagle import _make_causal_mask +from transformers.cache_utils import DynamicCache + +class QwenVLEagle3Evaluator(): + def __init__( + self, + eagle3_model: Eagle3Model, + draft_length: int = 
7, + ): + """Evaluator for Qwen-VL Eagle3 model performance assessment. + + Args: + eagle3_model: The Eagle3 model instance containing draft and target models + draft_length: Number of tokens to generate in draft phase + """ + self.eagle3_model = eagle3_model + self.draft_model = eagle3_model.draft_model + self.target_model = eagle3_model.target_model + self.draft_length = draft_length + self.attention_backend = eagle3_model.attention_backend + + def draft_token_valid( + self, + draft_ids: torch.Tensor, + target_ids: torch.Tensor + ) -> torch.Tensor: + """Calculate the average acceptance length between draft and target sequences. + + Args: + draft_ids: Token IDs generated by draft model [batch_size, seq_len] + target_ids: Ground truth token IDs [batch_size, seq_len] + + Returns: + Average number of consecutive matching tokens from the beginning + """ + bsz = draft_ids.shape[0] + acc_len = torch.zeros(bsz, dtype=torch.long, device=draft_ids.device) + + # Check token-by-token matching from the start + for i in range(self.draft_length): + mask = (draft_ids[:, i] == target_ids[:, i]) + still_valid = (acc_len == i) # Only increment if previous tokens matched + can_increment = mask & still_valid + acc_len[can_increment] += 1 + + return acc_len.float().mean() + + def draft_generate( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: List[List], + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + seq_length_with_past: int = 0, + **kwargs, + ) -> torch.Tensor: + """Generate candidate token sequence using draft model (autoregressive decoding). + + Performs autoregressive decoding with the draft model to generate candidate token + sequences for speculative decoding. 
+ + Args: + input_ids: Input token IDs [batch_size, seq_len] + hidden_states: Hidden states from target model for draft initialization + cache_hidden: KV cache management list for storing past key-values + past_key_values: Past key-value cache for incremental decoding + pixel_values: Image pixel values for vision-language tasks + image_grid_thw: Spatiotemporal dimensions of image grids + seq_length_with_past: Total sequence length including history for position encoding + **kwargs: Additional optional arguments + + Returns: + Generated candidate token sequence [batch_size, draft_length-1] + + Notes: + - Uses greedy decoding strategy for candidate generation + - Supports two attention backends: sdpa and flex_attention + - Uses visual embedding for first decode, text embedding for subsequent steps + - Maps generated tokens to target model vocabulary via d2t mapping + """ + + # Initialize generation sequence by cloning input IDs + generated_ids = input_ids.clone() + # Extract current step hidden states (excluding history) + hidden_states = hidden_states[:, seq_length_with_past:, :] + # Get current token IDs to process + current_ids = input_ids[:, seq_length_with_past:] + bsz, seq_len = current_ids.shape + + # Initialize cache management parameters based on attention backend + if self.attention_backend == "sdpa": + lck = len(cache_hidden[0]) # Current cache length + elif self.attention_backend == "flex_attention": + past_seen_tokens = seq_len + ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + + # Autoregressive decoding loop: generate draft_length-1 tokens + for step in range(self.draft_length - 1): + bsz, seq_len = current_ids.shape + + # Step 1: Prepare input embeddings + if seq_length_with_past == 0: + # First decode (prefill phase): use visual embedding if available + inputs_embeds = self.eagle3_model._get_input_embeds( + current_ids, pixel_values, image_grid_thw + ).to(hidden_states.dtype) + else: + # Subsequent decodes: use text 
embedding only + inputs_embeds = self.draft_model.embed_input_ids(current_ids).to(hidden_states.dtype) + + # Step 2: Create causal attention mask + if self.attention_backend == "sdpa": + causal_mask = _make_causal_mask( + input_ids_shape=current_ids.shape, + dtype=hidden_states.dtype, + device=hidden_states.device, + past_key_values_length=seq_length_with_past + ) + elif self.attention_backend == "flex_attention": + causal_mask = torch.ones( + (bsz, seq_len + seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + + # Step 3: Create position IDs for rotary encoding + position_ids, _ = self.target_model.model.get_rope_index( + generated_ids, + image_grid_thw, + None, + second_per_grid_ts=None, + attention_mask=None, + ) + # Take only position IDs for current sequence + position_ids = position_ids[:, :, -seq_len:] + + # Step 4: Forward pass through draft model backbone + hidden_states_out = self.draft_model.backbone( + input_embeds=inputs_embeds, + hidden_states=hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_hidden=cache_hidden, # Enable KV cache management + use_cache=True, + ) + + # Step 5: Update hidden states (take last token output) + hidden_states = hidden_states_out[:, -1:, :] + + # Step 6: Compute logits and select next token + logits = self.draft_model.compute_logits(hidden_states) + next_token_logits = logits[:, -1, :] # [batch_size, vocab_size] + + # Greedy decoding: select token with highest probability + next_token_ids = torch.argmax(next_token_logits, dim=-1, keepdim=True) + # Map to target model vocabulary space + next_token_ids = next_token_ids + self.draft_model.d2t[next_token_ids] + + # Step 7: Update generated sequence + generated_ids = torch.cat([generated_ids, next_token_ids], dim=1) + current_ids = next_token_ids + seq_length_with_past += seq_len + + # Step 8: Clean up KV cache (trim to valid length) + if self.attention_backend == "sdpa": + 
cache_hidden[0] = cache_hidden[0][:lck + 1] + cache_hidden[1] = cache_hidden[1][:lck + 1] + elif self.attention_backend == "flex_attention": + past_key_values.crop(past_seen_tokens) + + # Return generated candidate sequence (exclude input, keep draft portion only) + return generated_ids[:, -(self.draft_length-1):] + + def evaluation( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> float: + """Evaluate draft model performance by calculating average acceptance length. + + Simulates speculative decoding process to evaluate how many candidate tokens + generated by the draft model are accepted by the target sequence, measuring + draft model accuracy and acceleration potential. + + Args: + input_ids: Input token IDs [batch_size, seq_len] + attention_mask: Attention mask identifying valid token positions + loss_mask: Loss calculation mask identifying tokens for loss computation + past_key_values: Past key-value cache for incremental decoding + position_ids: Position encoding IDs + pixel_values: Image pixel values for multimodal tasks + image_grid_thw: Spatiotemporal dimensions of image grids + + Returns: + Average acceptance length indicating how many draft tokens are accepted by target + + Notes: + - Currently supports only batch_size=1 evaluation + - Uses sliding window approach over entire sequence + - Calculates acceptance length by comparing draft tokens with target tokens + - Supports two attention backends: sdpa and flex_attention + """ + + # Step 0: Data preparation and validation + bsz, _ = input_ids.shape + assert bsz == 1 # Currently only supports batch_size=1 evaluation + + # Step 1: Prepare hidden states using target model + hidden_states, _, _, input_ids = self.eagle3_model._prepare_data( + 
input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw + ) + # Project target model hidden states to draft model space + hidden_states = self.draft_model.project_hidden_states(hidden_states) + + # Determine decoding start position (first token to predict) + start_pos = torch.argmax(loss_mask.squeeze(-1), dim=1, keepdim=True)[-1].item() + start_idx = 1 # Sliding window start index + + # Step 2: Initialize cache based on attention backend + if self.attention_backend == "sdpa": + cache_hidden = [[], []] # First two for attention, rest for fp8 k_cache k_scale for indexer + past_key_values = None + elif self.attention_backend == "flex_attention": + cache_hidden = None + past_key_values = DynamicCache() + + seq_length_with_past = 0 # Processed sequence length (including history) + + # Step 3: Initialize evaluation metrics + round_num = 0 # Evaluation round counter + avg_accept_length = 0 # Cumulative acceptance length + + # Step 4: Sliding window evaluation loop + # Iterate through entire sequence, moving draft_length tokens each time + while start_pos + self.draft_length + start_idx - 1 < len(input_ids[0]): + # Get target token sequence for current window + target_ids = input_ids[:, start_pos + start_idx - 1:start_pos + self.draft_length + start_idx - 1] + + # Prepare hidden states and input IDs for current decoding + cur_hidden_states = hidden_states[:, :start_pos + start_idx, :].clone() + cur_input_ids = input_ids[:, :start_pos + start_idx].clone() + + # Step 5: Generate candidate token sequence using draft model + draft_ids = self.draft_generate( + input_ids=cur_input_ids, + hidden_states=cur_hidden_states, + cache_hidden=cache_hidden, + past_key_values=past_key_values, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + seq_length_with_past=seq_length_with_past + ) + + # Step 6: Construct complete draft sequence (include last token from current input) + draft_ids = torch.cat([cur_input_ids[:, -1:], draft_ids], dim=-1) + + # Update cached 
sequence length (draft_generate only caches input, not decoded draft tokens) + seq_length_with_past = start_pos + start_idx + + # Step 7: Calculate acceptance length for current window + accept_length = self.draft_token_valid(draft_ids, target_ids) + avg_accept_length += accept_length + round_num += 1 + + # Step 8: Move sliding window + start_idx += self.draft_length + + # Step 9: Calculate average acceptance length + if round_num > 0: + avg_accept_length /= round_num + else: + avg_accept_length = 0.0 + + return avg_accept_length \ No newline at end of file diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 2b7daf9d5..e842069f3 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -681,46 +681,27 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - cache_hidden[0] = cache_hidden[0] + [key_states] - cache_hidden[1] = cache_hidden[1] + [value_states] + if use_cache: + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + else: + cache_hidden[0] = [key_states] + cache_hidden[1] = [value_states] cache_k = cache_hidden[0] cache_v = cache_hidden[1] - k0 = cache_k[0] - v0 = cache_v[0] - - # causal - attn_weights = torch.matmul(query_states, k0.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - lck = len(cache_k) - + k_all = torch.cat(cache_k, dim=2) + attn_weights = torch.matmul(query_states, k_all.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = attn_weights + attention_mask - for i in range(1, lck): - ki = cache_k[i] - qi = query_states - kiq = ki - - attn_weightsi = (qi * kiq).sum(-1) / math.sqrt(self.head_dim) - attn_weights = torch.cat( - (attn_weights, attn_weightsi[..., None]), dim=-1 - ) - # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 
).to(query_states.dtype) - attn_weights0 = attn_weights[..., :q_len] - - attn_output = torch.matmul(attn_weights0, v0) - for i in range(1, lck): - vi = cache_v[i] - attn_weightsi = attn_weights[..., q_len + i - 1] - attn_outputi = attn_weightsi[..., None] * vi - attn_output = attn_output + attn_outputi + v_all = torch.cat(cache_v, dim=2) + attn_output = torch.matmul(attn_weights, v_all) # (bzs, head_num, sql, head_dim) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) @@ -1103,5 +1084,5 @@ def backbone( position_ids=position_ids, past_key_values=past_key_values, output_attentions=False, - use_cache=False, - ) + use_cache=use_cache, # training use_cache = False / draft decoding use_cache = True + ) \ No newline at end of file