diff --git a/models/language_model.py b/models/language_model.py
index 668330e4..7dff478d 100644
--- a/models/language_model.py
+++ b/models/language_model.py
@@ -113,62 +113,78 @@ def __init__(self, cfg):
             print("Warning: scaled dot product attention not available, using standard attention in LM.")
 
     def forward(self, x, cos, sin, attention_mask=None, kv_cache=None):
-        B, T, C = x.size()
+        B, T_curr, C = x.size() # T_curr is the sequence length of the current input x
 
-        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T, head_dim)
+        q_curr = self.q_proj(x).view(B, T_curr, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T_curr, head_dim)
+        k_curr = self.k_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T_curr, head_dim)
+        v_curr = self.v_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T_curr, head_dim)
+
+        # Apply rotary embeddings to the current q and k
+        q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)
 
         # Check if we can use cached keys and values
         if kv_cache is not None and kv_cache['key'] is not None:
-            k = kv_cache['key'] # (B, n_kv_heads, T_cached, head_dim)
-            v = kv_cache['value'] # (B, n_kv_heads, T_cached, head_dim)
-            # Compute keys and values for the new token only
-            new_k = self.k_proj(x[:, -1:, :]).view(B, 1, self.n_kv_heads, self.head_dim).transpose(1, 2)
-            new_v = self.v_proj(x[:, -1:, :]).view(B, 1, self.n_kv_heads, self.head_dim).transpose(1, 2)
-            # Append new keys and values to cache
-            k = torch.cat([k, new_k], dim=2)
-            v = torch.cat([v, new_v], dim=2)
+            # Concatenate with cached K, V
+            # k_rotated and v_curr are for the new token(s)
+            k_past = kv_cache['key']
+            v_past = kv_cache['value']
+            k = torch.cat([k_past, k_rotated], dim=2)
+            v = torch.cat([v_past, v_curr], dim=2)
         else:
-            # Compute keys and values for all tokens
-            k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T, head_dim)
-            v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T, head_dim)
-
-        # Update KV cache
+            # No cache, this is the first pass (prefill)
+            k = k_rotated
+            v = v_curr
+
         new_kv_cache = {'key': k, 'value': v}
 
-        # Use precomputed positional embeddings
-        q, k = apply_rotary_pos_embd(q, k, cos, sin)
-
-        k = k.repeat_interleave(self.n_kv_groups, dim=1)
-        v = v.repeat_interleave(self.n_kv_groups, dim=1)
+        # Repeat K, V for Grouped Query Attention
+        k_exp = k.repeat_interleave(self.n_kv_groups, dim=1) # (B, n_heads, T_kv, head_dim)
+        v_exp = v.repeat_interleave(self.n_kv_groups, dim=1) # (B, n_heads, T_kv, head_dim)
+
+        T_kv = k_exp.size(2) # Total sequence length of keys/values
 
-        # Process attention mask if provided
+        # Prepare attention mask for SDPA or manual path
+        # attention_mask is (B, T_kv_total_length), 1 for attend, 0 for pad
+        additive_attn_mask = None
         if attention_mask is not None:
-            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, T]
-            padding_mask = (attention_mask == 0).transpose(-1, -2)
-            attention_mask = (1.0 - attention_mask) * torch.finfo(q.dtype).min
+            # The `attention_mask` parameter is assumed to be `[B, total_sequence_length_kv]`;
+            # reshape it to `[B, 1, 1, T_kv]` for SDPA.
+            mask_for_keys = attention_mask[:, :T_kv] # Ensure mask matches key length [B, T_kv]
+            additive_attn_mask = (1.0 - mask_for_keys.unsqueeze(1).unsqueeze(2).float()) * torch.finfo(q.dtype).min
+            # This additive_attn_mask shape is [B, 1, 1, T_kv]
 
         if self.sdpa:
+            is_causal_sdpa = (kv_cache is None and T_curr > 1) # True only for the prefill of a sequence,
+            # not for single-token decode steps (where T_curr == 1).
+
+            # When T_curr == 1 (decode) and T_kv > 1, is_causal_sdpa is False.
+            # additive_attn_mask [B, 1, 1, T_kv] will mask out padded KV elements.
+            # Attention is from q (1 token) to all KVs; no further causal masking is needed within this step by SDPA.
             y = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v,
-                attn_mask=attention_mask,
+                q, k_exp, v_exp,
+                attn_mask=additive_attn_mask,
                 dropout_p=self.dropout if self.training else 0.0,
-                is_causal=True # LM attention is causal (masked)
+                is_causal=is_causal_sdpa
             )
         else:
-            attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
-            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
-            attn = attn.masked_fill(causal_mask == 0, float('-inf'))
-            if attention_mask is not None:
-                attn = attn + attention_mask
+            # Manual attention implementation
+            attn = torch.matmul(q, k_exp.transpose(2, 3)) / math.sqrt(self.head_dim) # (B, n_heads, T_curr, T_kv)
+
+            # Causal mask for prefill where T_curr == T_kv and T_curr > 1
+            if kv_cache is None and T_curr > 1 and T_curr == T_kv:
+                # This creates a lower triangular mask for square attention (prefill)
+                causal_mask_val = torch.tril(torch.ones(T_curr, T_curr, device=x.device, dtype=torch.bool)).view(1, 1, T_curr, T_curr)
+                attn = attn.masked_fill(~causal_mask_val, float('-inf'))
+
+            if additive_attn_mask is not None: # Additive padding mask
+                # additive_attn_mask is [B, 1, 1, T_kv]; it broadcasts to [B, n_heads, T_curr, T_kv]
+                attn = attn + additive_attn_mask
 
             attn = F.softmax(attn, dim=-1)
             attn = self.attn_dropout(attn)
-            y = attn @ v
+            y = attn @ v_exp
 
-        if attention_mask is not None:
-            y = y.masked_fill(padding_mask, 0.0) # Zero out the padded positions in the output
-
-        y = y.transpose(1, 2).contiguous().view(B, T, C)
+        y = y.transpose(1, 2).contiguous().view(B, T_curr, C)
         y = self.out_proj(y)
         y = self.resid_dropout(y)
@@ -245,58 +261,76 @@ def _init_weights(self, module):
         elif isinstance(module, RMSNorm):
             module.weight.data.fill_(1.0)
 
-    def forward(self, x, attention_mask=None, kv_cache=None):
+    def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0):
         if self.lm_use_tokens:
-            x = self.token_embedding(x) # Only embed the inputs when using tokens
-
-        B , T, _ = x.size()
+            x = self.token_embedding(x)
+
+        # T_curr is the length of the current input sequence
+        B, T_curr, _ = x.size()
 
-        # Note: You could also cache these input embeddings if you want to avoid recomputing them
-        position_ids = torch.arange(T, device=x.device).unsqueeze(0).expand(B, -1) # Create position ids [0, 1, 2, ..., seq_len-1]
-        cos, sin = self.rotary_embd(position_ids) # Get rotary position embeddings
+        # Create position_ids for the current sequence based on start_pos
+        current_position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device).unsqueeze(0).expand(B, -1)
+        cos, sin = self.rotary_embd(current_position_ids) # Get rotary position embeddings for current tokens
 
         # Initialize new KV cache if none provided
+        new_kv_cache_list = []
         if kv_cache is None:
             kv_cache = [None] * len(self.blocks)
-        new_kv_cache = []
 
         for i, block in enumerate(self.blocks):
             x, block_kv_cache = block(x, cos, sin, attention_mask, kv_cache[i])
-            new_kv_cache.append(block_kv_cache)
+            new_kv_cache_list.append(block_kv_cache)
 
         x = self.norm(x)
 
-        if self.lm_use_tokens:
-            x = self.head(x) # Compute logits if we are using tokens, otherwise stay in the embedding space
+        # Compute logits if we are using tokens, otherwise stay in the embedding space
+        if self.lm_use_tokens:
+            x = self.head(x)
 
-        return x, new_kv_cache
+        return x, new_kv_cache_list
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def generate(self, inputs, max_new_tokens=20):
         # Add batch dimension if needed
         if inputs.dim() == 1:
             inputs = inputs.unsqueeze(0)
-
-        generated = inputs.clone()
-
-        for _ in range(max_new_tokens):
-            # Forward pass through the model
-            outputs = self.forward(generated)
-            last_output = outputs[:, -1, :]
-
+        generated_outputs = inputs.clone()
+
+        # Prefill Phase: process the whole prompt once and build the KV cache
+        prompt_output, kv_cache_list = self.forward(
+            generated_outputs,
+            attention_mask=None,
+            kv_cache=None,
+            start_pos=0
+        )
+        last_output = prompt_output[:, -1, :]
+
+        # Decode Phase with KV cache
+        for i in range(max_new_tokens):
             if self.lm_use_tokens:
                 # Now the model outputs logits
-                next_token = torch.argmax(last_output, dim=-1, keepdim=True)
-                generated = torch.cat((generated, next_token), dim=-1)
+                next_output = torch.argmax(last_output, dim=-1, keepdim=True)
             else:
                 # Now the model outputs embeddings
-                next_token_embedding = last_output.unsqueeze(1) # Shape: [batch_size, 1, hidden_dim]
-                generated = torch.cat((generated, next_token_embedding), dim=1)
+                next_output = last_output.unsqueeze(1)
+
+            generated_outputs = torch.cat((generated_outputs, next_output), dim=1)
 
-            #Note: You could enable the generation to break earlier than max_new_tokens when it detects a eos token, but this does not work in batched generation (output tensors need to have the same size)
+            # The token being processed is `next_output`. Its position is `generated_outputs.size(1) - 1`.
+            current_token_start_pos = generated_outputs.size(1) - 1
+
+            if i == max_new_tokens - 1:
+                break
+
+            decode_step_output, kv_cache_list = self.forward(
+                next_output,
+                attention_mask=None,
+                kv_cache=kv_cache_list,
+                start_pos=current_token_start_pos
+            )
+            last_output = decode_step_output[:, -1, :]
 
-        return generated
+        return generated_outputs
 
     # Load the model from a pretrained HuggingFace model (we don't want to have to train the Language Backbone from scratch)
     @classmethod
diff --git a/models/vision_language_model.py b/models/vision_language_model.py
index 545545c5..67254137 100644
--- a/models/vision_language_model.py
+++ b/models/vision_language_model.py
@@ -59,59 +59,81 @@ def forward(self, input_ids, image, attention_mask=None, targets=None):
 
         return logits, loss
 
-    @torch.no_grad()
-    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
-        # Process image through vision encoder and projection
-        image_embd = self.vision_encoder(image)
-        image_embd = self.MP(image_embd)
-
-        # Embed initial tokens
-        token_embd = self.decoder.token_embedding(input_ids)
-
-        # Concatenate image embeddings with token embeddings
-        combined_embd = torch.cat((image_embd, token_embd), dim=1)
+    @torch.inference_mode()
+    def generate(self, input_ids, image, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
+        # input_ids: [B, T_prompt_text]
+        # image: [B, C, H, W]
+        # max_new_tokens: int, maximum number of tokens to generate
 
-        batch_size = image_embd.size(0)
-        img_seq_len = image_embd.size(1)
-        # Adjust attention mask to account for image tokens
-        if attention_mask is not None:
-            # Create mask of 1s for image tokens (all image tokens should be attended to)
-            image_attention_mask = torch.ones((batch_size, img_seq_len), device=attention_mask.device, dtype=attention_mask.dtype)
-            attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)
+        B = image.size(0)
+
+        # 1. Process image
+        image_embd = self.vision_encoder(image) # [B, T_img, D_model]
+        image_embd = self.MP(image_embd)        # [B, T_img, D_lm]
+
+        # 2. Embed initial text prompt tokens
+        prompt_token_embeds = self.decoder.token_embedding(input_ids) # [B, T_prompt_text, D_lm]
+
+        # 3. Combine image and text prompt embeddings for prefill
+        initial_combined_embeds = torch.cat((image_embd, prompt_token_embeds), dim=1) # [B, T_img + T_prompt_text, D_lm]
+        current_total_seq_len = initial_combined_embeds.size(1)
 
-        # Initialize KV cache: List to store key and value tensors for each block
-        kv_cache = [None] * len(self.decoder.blocks)
+        # --- Multimodal Prefill Phase ---
+        prefill_output, kv_cache_list = self.decoder.forward(
+            initial_combined_embeds,
+            attention_mask=None,
+            kv_cache=None,
+            start_pos=0
+        )
 
-        # Generate tokens one by one
-        outputs = combined_embd
-        generated_tokens = torch.zeros((batch_size, max_new_tokens), device=input_ids.device, dtype=input_ids.dtype)
+        last_token_output_from_prefill = prefill_output[:, -1, :]
 
-        for i in range(max_new_tokens):
-            # Pass KV cache to decoder for efficient generation
-            model_out, kv_cache = self.decoder(outputs, attention_mask, kv_cache=kv_cache)
+        if not self.decoder.lm_use_tokens:
+            current_logits = self.decoder.head(last_token_output_from_prefill)
+        else:
+            current_logits = last_token_output_from_prefill
+
+        # Store newly generated token IDs
+        newly_generated_ids_list = []
+
+        # --- Decode Phase, sampling tokens one by one ---
+        for _ in range(max_new_tokens):
+            if greedy:
+                next_token_id = torch.argmax(current_logits, dim=-1, keepdim=True)
+            else:
+                filtered_logits = top_k_top_p_filtering(current_logits, top_k=top_k, top_p=top_p)
+                probs = torch.softmax(filtered_logits / temperature, dim=-1)
+                next_token_id = torch.multinomial(probs, num_samples=1)
+
+            newly_generated_ids_list.append(next_token_id)
 
-            last_token_logits = model_out[:, -1, :]
+            # Embed the newly generated token
+            next_token_embed = self.decoder.token_embedding(next_token_id) # [B, 1, D_lm]
+
+            # The start_pos for the new token is the current total sequence length *before* adding this new token
+            current_token_start_pos = current_total_seq_len
+            current_total_seq_len += 1
+
+            # Call decoder.forward with the new token's embedding and the updated KV cache
+            decode_step_output, kv_cache_list = self.decoder.forward(
+                next_token_embed,
+                attention_mask=None, # Autoregressive, so no explicit mask beyond the KV cache structure
+                kv_cache=kv_cache_list, # Pass the updated cache
+                start_pos=current_token_start_pos
+            )
+
+            last_token_output = decode_step_output[:, -1, :]
 
             # Apply head to get logits (if model is in embedding mode)
             if not self.decoder.lm_use_tokens:
-                last_token_logits = self.decoder.head(last_token_logits)
-
-            if greedy:
-                next_token = torch.argmax(last_token_logits, dim=-1, keepdim=True)
+                current_logits = self.decoder.head(last_token_output)
             else:
-                filtered_logits = top_k_top_p_filtering(last_token_logits, top_k=top_k, top_p=top_p)
-                probs = torch.softmax(filtered_logits/temperature, dim=-1)
-                next_token = torch.multinomial(probs, num_samples=1)
-
-            generated_tokens[:, i] = next_token.squeeze(-1)
-
-            # Convert to embedding and append
-            next_embd = self.decoder.token_embedding(next_token)
-            outputs = torch.cat((outputs, next_embd), dim=1)
-
-            if attention_mask is not None:
-                attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)), dim=1)
+                current_logits = last_token_output
 
-        return generated_tokens
+        if not newly_generated_ids_list: # Handle case where max_new_tokens might be 0
+            return torch.empty((B, 0), dtype=torch.long, device=input_ids.device)
+
+        return torch.cat(newly_generated_ids_list, dim=1)
 
     @classmethod
     def from_pretrained(