162 changes: 98 additions & 64 deletions models/language_model.py
@@ -113,62 +113,78 @@ def __init__(self, cfg):
print("Warning: scaled dot product attention not available, using standard attention in LM.")

def forward(self, x, cos, sin, attention_mask=None, kv_cache=None):
B, T, C = x.size()
B, T_curr, C = x.size() # T_curr is the sequence length of the current input x

q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T, head_dim)
q_curr = self.q_proj(x).view(B, T_curr, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T_curr, head_dim)
k_curr = self.k_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T_curr, head_dim)
v_curr = self.v_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T_curr, head_dim)

# Apply rotary embeddings to the current q and k
q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

# Check if we can use cached keys and values
if kv_cache is not None and kv_cache['key'] is not None:
k = kv_cache['key'] # (B, n_kv_heads, T_cached, head_dim)
v = kv_cache['value'] # (B, n_kv_heads, T_cached, head_dim)
# Compute keys and values for the new token only
new_k = self.k_proj(x[:, -1:, :]).view(B, 1, self.n_kv_heads, self.head_dim).transpose(1, 2)
new_v = self.v_proj(x[:, -1:, :]).view(B, 1, self.n_kv_heads, self.head_dim).transpose(1, 2)
# Append new keys and values to cache
k = torch.cat([k, new_k], dim=2)
v = torch.cat([v, new_v], dim=2)
# Concatenate with cached K, V
# k_rotated and v_curr are for the new token(s)
k_past = kv_cache['key']
v_past = kv_cache['value']
k = torch.cat([k_past, k_rotated], dim=2)
v = torch.cat([v_past, v_curr], dim=2)
else:
# Compute keys and values for all tokens
k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T, head_dim)
v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T, head_dim)

# Update KV cache
# No cache, this is the first pass (prefill)
k = k_rotated
v = v_curr

new_kv_cache = {'key': k, 'value': v}

# Use precomputed positional embeddings
q, k = apply_rotary_pos_embd(q, k, cos, sin)

k = k.repeat_interleave(self.n_kv_groups, dim=1)
v = v.repeat_interleave(self.n_kv_groups, dim=1)
# Repeat K, V for Grouped Query Attention
k_exp = k.repeat_interleave(self.n_kv_groups, dim=1) # (B, n_heads, T_kv, head_dim)
v_exp = v.repeat_interleave(self.n_kv_groups, dim=1) # (B, n_heads, T_kv, head_dim)

T_kv = k_exp.size(2) # Total sequence length of keys/values

# Process attention mask if provided
# Prepare attention mask for SDPA or manual path
# attention_mask is (B, T_kv_total_length), 1 for attend, 0 for pad
additive_attn_mask = None
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, T]
padding_mask = (attention_mask == 0).transpose(-1, -2)
attention_mask = (1.0 - attention_mask) * torch.finfo(q.dtype).min
# The current `attention_mask` parameter is assumed to be `[B, total_sequence_length_kv]`
# Let's make it `[B, 1, 1, T_kv]` for SDPA.
mask_for_keys = attention_mask[:, :T_kv] # Ensure mask matches key length [B, T_kv]
additive_attn_mask = (1.0 - mask_for_keys.unsqueeze(1).unsqueeze(2).float()) * torch.finfo(q.dtype).min
# This additive_attn_mask shape is [B, 1, 1, T_kv]

if self.sdpa:
is_causal_sdpa = (kv_cache is None and T_curr > 1) # True only for the prefill pass over a full sequence,
# not for single-token decode steps, where T_curr == 1 and the cache is already populated

# When T_curr == 1 (decode) and T_kv > 1, is_causal_sdpa is False and
# additive_attn_mask [B,1,1,T_kv] masks out padded KV elements.
# The single query attends to all KV entries, so SDPA needs no further causal masking within this step.
y = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=attention_mask,
q, k_exp, v_exp,
attn_mask=additive_attn_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=True # LM attention is causal (masked)
is_causal=is_causal_sdpa
)
else:
attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
attn = attn.masked_fill(causal_mask == 0, float('-inf'))
if attention_mask is not None:
attn = attn + attention_mask
# Manual attention implementation
attn = torch.matmul(q, k_exp.transpose(2, 3)) / math.sqrt(self.head_dim) # (B, n_heads, T_curr, T_kv)

# Causal mask for prefill where T_curr == T_kv and T_curr > 1
if kv_cache is None and T_curr > 1 and T_curr == T_kv:
# This creates a lower triangular mask for square attention (prefill)
causal_mask_val = torch.tril(torch.ones(T_curr, T_curr, device=x.device, dtype=torch.bool)).view(1, 1, T_curr, T_curr)
attn = attn.masked_fill(~causal_mask_val, float('-inf'))

if additive_attn_mask is not None: # Additive padding mask
# additive_attn_mask is [B,1,1,T_kv], needs to be broadcast to [B, n_heads, T_curr, T_kv]
attn = attn + additive_attn_mask

attn = F.softmax(attn, dim=-1)
attn = self.attn_dropout(attn)
y = attn @ v
y = attn @ v_exp

if attention_mask is not None:
y = y.masked_fill(padding_mask, 0.0) # Zero out the padded positions in the output

y = y.transpose(1, 2).contiguous().view(B, T, C)
y = y.transpose(1, 2).contiguous().view(B, T_curr, C)
y = self.out_proj(y)
y = self.resid_dropout(y)
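
Note (reviewer sketch): the hedged, self-contained example below mirrors how the per-block {'key', 'value'} cache returned above grows across a prefill call and one decode step, and how grouped-query attention expands the cached KV heads; batch size, head counts, and lengths are made up for illustration.

import torch

B, n_kv_heads, head_dim = 1, 4, 16

# Prefill: no cache yet, so the first forward stores rotated K/V for the whole prompt.
k_prefill = torch.randn(B, n_kv_heads, 7, head_dim)  # 7 prompt tokens
v_prefill = torch.randn(B, n_kv_heads, 7, head_dim)
kv_cache = {'key': k_prefill, 'value': v_prefill}

# Decode step: only the new token's K/V are computed and appended along the sequence dim (dim=2).
k_new = torch.randn(B, n_kv_heads, 1, head_dim)
v_new = torch.randn(B, n_kv_heads, 1, head_dim)
kv_cache = {'key': torch.cat([kv_cache['key'], k_new], dim=2),
            'value': torch.cat([kv_cache['value'], v_new], dim=2)}
print(kv_cache['key'].shape)  # torch.Size([1, 4, 8, 16]) -> T_kv grew from 7 to 8

# Grouped Query Attention: each cached KV head is shared by n_kv_groups query heads.
n_kv_groups = 2
k_exp = kv_cache['key'].repeat_interleave(n_kv_groups, dim=1)  # (1, 8, 8, 16)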

@@ -245,58 +261,76 @@ def _init_weights(self, module):
elif isinstance(module, RMSNorm):
module.weight.data.fill_(1.0)

def forward(self, x, attention_mask=None, kv_cache=None):
def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0):
if self.lm_use_tokens:
x = self.token_embedding(x) # Only embed the inputs when using tokens

B , T, _ = x.size()
x = self.token_embedding(x)

# T_curr is the length of the current input sequence
B, T_curr, _ = x.size()

# Note: You could also cache these input embeddings if you want to avoid recomputing them
position_ids = torch.arange(T, device=x.device).unsqueeze(0).expand(B, -1) # Create position ids [0, 1, 2, ..., seq_len-1]
cos, sin = self.rotary_embd(position_ids) # Get rotary position embeddings
# Create position_ids for the current sequence based on start_pos
current_position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device).unsqueeze(0).expand(B, -1)
cos, sin = self.rotary_embd(current_position_ids) # Get rotary position embeddings for current tokens

# Initialize new KV cache if none provided
new_kv_cache_list = []
if kv_cache is None:
kv_cache = [None] * len(self.blocks)
new_kv_cache = []

for i, block in enumerate(self.blocks):
x, block_kv_cache = block(x, cos, sin, attention_mask, kv_cache[i])
new_kv_cache.append(block_kv_cache)
new_kv_cache_list.append(block_kv_cache)

x = self.norm(x)

if self.lm_use_tokens:
x = self.head(x) # Compute logits if we are using tokens, otherwise stay in the embedding space
# Compute logits if we are using tokens, otherwise stay in the embedding space
if self.lm_use_tokens:
x = self.head(x)

return x, new_kv_cache
return x, new_kv_cache_list
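
Note (reviewer sketch): a small, self-contained illustration of the start_pos bookkeeping above, assuming the rotary embedding only needs the absolute positions of the tokens in the current call; values are made up.

import torch

B = 2
# Prefill: a 5-token prompt starts at absolute position 0.
T_curr, start_pos = 5, 0
prefill_ids = torch.arange(start_pos, start_pos + T_curr).unsqueeze(0).expand(B, -1)
print(prefill_ids)  # tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])

# Decode: the single new token sits at absolute position 5 (the length generated so far).
T_curr, start_pos = 1, 5
decode_ids = torch.arange(start_pos, start_pos + T_curr).unsqueeze(0).expand(B, -1)
print(decode_ids)  # tensor([[5], [5]])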


@torch.no_grad()
@torch.inference_mode()
def generate(self, inputs, max_new_tokens=20):
# Add batch dimension if needed
if inputs.dim() == 1:
inputs = inputs.unsqueeze(0)

generated = inputs.clone()

for _ in range(max_new_tokens):
# Forward pass through the model
outputs = self.forward(generated)
last_output = outputs[:, -1, :]

generated_outputs = inputs.clone()

prompt_output, kv_cache_list = self.forward(
generated_outputs,
attention_mask=None,
kv_cache=None,
start_pos=0
)
last_output = prompt_output[:, -1, :]

# Decode Phase with KV cache
for i in range(max_new_tokens):
if self.lm_use_tokens:
# Now the model outputs logits
next_token = torch.argmax(last_output, dim=-1, keepdim=True)
generated = torch.cat((generated, next_token), dim=-1)
next_output = torch.argmax(last_output, dim=-1, keepdim=True)
else:
# Now the model outputs embeddings
next_token_embedding = last_output.unsqueeze(1) # Shape: [batch_size, 1, hidden_dim]
generated = torch.cat((generated, next_token_embedding), dim=1)
next_output = last_output.unsqueeze(1)

generated_outputs = torch.cat((generated_outputs, next_output), dim=1)

# Note: Generation could stop before max_new_tokens when an EOS token is detected, but this does not work in batched generation (output tensors must all have the same size)
# The token being processed is `next_output`; its position is `generated_outputs.size(1) - 1`.
current_token_start_pos = generated_outputs.size(1) - 1

if i == max_new_tokens - 1:
break

decode_step_output, kv_cache_list = self.forward(
next_output,
attention_mask=None,
kv_cache=kv_cache_list,
start_pos=current_token_start_pos
)
last_output = decode_step_output[:, -1, :]

return generated
return generated_outputs

# Load the model from a pretrained HuggingFace model (we don't want to have to train the Language Backbone from scratch)
@classmethod
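Note (reviewer sketch): a hedged, runnable illustration of the prefill-then-decode control flow used by generate above. TinyDecoder is a made-up stand-in with the same call shape, because instantiating the real LanguageModel needs a config that is not part of this diff.

import torch

class TinyDecoder:
    # Stand-in with the same (x, attention_mask, kv_cache, start_pos) -> (output, cache)
    # signature as LanguageModel.forward, so the control flow runs without real weights.
    def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0):
        vocab_size = 11
        logits = torch.randn(x.size(0), x.size(1), vocab_size)
        return logits, kv_cache if kv_cache is not None else []

decoder = TinyDecoder()
generated = torch.randint(0, 11, (1, 4))  # [B=1, T_prompt=4] token ids

# Prefill: run the whole prompt once at start_pos=0 and keep the returned cache.
logits, cache = decoder.forward(generated, kv_cache=None, start_pos=0)
last = logits[:, -1, :]

# Decode: feed only the newest token, with start_pos set to its absolute position.
for _ in range(3):
    next_tok = torch.argmax(last, dim=-1, keepdim=True)  # [B, 1]
    generated = torch.cat((generated, next_tok), dim=-1)
    logits, cache = decoder.forward(next_tok, kv_cache=cache,
                                    start_pos=generated.size(1) - 1)
    last = logits[:, -1, :]

print(generated.shape)  # torch.Size([1, 7]) -> 4 prompt tokens + 3 generated
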
108 changes: 65 additions & 43 deletions models/vision_language_model.py
@@ -59,59 +59,81 @@ def forward(self, input_ids, image, attention_mask=None, targets=None):

return logits, loss

@torch.no_grad()
def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
# Process image through vision encoder and projection
image_embd = self.vision_encoder(image)
image_embd = self.MP(image_embd)

# Embed initial tokens
token_embd = self.decoder.token_embedding(input_ids)

# Concatenate image embeddings with token embeddings
combined_embd = torch.cat((image_embd, token_embd), dim=1)
@torch.inference_mode()
def generate(self, input_ids, image, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
# input_ids: [B, T_prompt_text]
# image: [B, C, H, W]
# max_new_tokens: int, maximum number of tokens to generate

batch_size = image_embd.size(0)
img_seq_len = image_embd.size(1)
# Adjust attention mask to account for image tokens
if attention_mask is not None:
# Create mask of 1s for image tokens (all image tokens should be attended to)
image_attention_mask = torch.ones((batch_size, img_seq_len), device=attention_mask.device, dtype=attention_mask.dtype)
attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)
B = image.size(0)

# 1. Process image
image_embd = self.vision_encoder(image) # [B, T_img, D_model]
image_embd = self.MP(image_embd) # [B, T_img, D_lm]

# 2. Embed initial text prompt tokens
prompt_token_embeds = self.decoder.token_embedding(input_ids) # [B, T_prompt_text, D_lm]

# 3. Combine image and text prompt embeddings for prefill
initial_combined_embeds = torch.cat((image_embd, prompt_token_embeds), dim=1) # [B, T_img + T_prompt_text, D_lm]
current_total_seq_len = initial_combined_embeds.size(1)

# Initialize KV cache: List to store key and value tensors for each block
kv_cache = [None] * len(self.decoder.blocks)
# --- Multimodal Prefill Phase ---
prefill_output, kv_cache_list = self.decoder.forward(
initial_combined_embeds,
attention_mask=None,
kv_cache=None,
start_pos=0
)

# Generate tokens one by one
outputs = combined_embd
generated_tokens = torch.zeros((batch_size, max_new_tokens), device=input_ids.device, dtype=input_ids.dtype)
last_token_output_from_prefill = prefill_output[:, -1, :]

for i in range(max_new_tokens):
# Pass KV cache to decoder for efficient generation
model_out, kv_cache = self.decoder(outputs, attention_mask, kv_cache=kv_cache)
if not self.decoder.lm_use_tokens:
current_logits = self.decoder.head(last_token_output_from_prefill)
else:
current_logits = last_token_output_from_prefill

# Store newly generated token IDs
newly_generated_ids_list = []

# --- Decode phase: sample tokens one by one ---
for _ in range(max_new_tokens):
if greedy:
next_token_id = torch.argmax(current_logits, dim=-1, keepdim=True)
else:
filtered_logits = top_k_top_p_filtering(current_logits, top_k=top_k, top_p=top_p)
probs = torch.softmax(filtered_logits / temperature, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)

newly_generated_ids_list.append(next_token_id)

last_token_logits = model_out[:, -1, :]
# Embed the newly generated token
next_token_embed = self.decoder.token_embedding(next_token_id) # [B, 1, D_lm]

# The start_pos for the new token is the current total sequence length *before* adding this new token
current_token_start_pos = current_total_seq_len
current_total_seq_len += 1

# Call decoder.forward with the new token's embedding and the updated KV cache
decode_step_output, kv_cache_list = self.decoder.forward(
next_token_embed,
attention_mask=None, # Autoregressive, so no explicit mask beyond KV cache structure
kv_cache=kv_cache_list, # Pass the updated cache
start_pos=current_token_start_pos
)

last_token_output = decode_step_output[:, -1, :]

# Apply head to get logits (if model is in embedding mode)
if not self.decoder.lm_use_tokens:
last_token_logits = self.decoder.head(last_token_logits)
if greedy:
next_token = torch.argmax(last_token_logits, dim=-1, keepdim=True)
current_logits = self.decoder.head(last_token_output)
else:
filtered_logits = top_k_top_p_filtering(last_token_logits, top_k=top_k, top_p=top_p)
probs = torch.softmax(filtered_logits/temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

generated_tokens[:, i] = next_token.squeeze(-1)

# Convert to embedding and append
next_embd = self.decoder.token_embedding(next_token)
outputs = torch.cat((outputs, next_embd), dim=1)

if attention_mask is not None:
attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)), dim=1)
current_logits = last_token_output

return generated_tokens
if not newly_generated_ids_list: # Handle case where max_new_tokens might be 0
return torch.empty((B,0), dtype=torch.long, device=input_ids.device)

return torch.cat(newly_generated_ids_list, dim=1)
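
Note (reviewer sketch): the decode loop above calls top_k_top_p_filtering, whose implementation is not part of this diff. The self-contained example below shows one common way such a filter combines with temperature scaling and multinomial sampling; it is an assumed stand-in written for illustration, not the repository's helper. Filtering before dividing by temperature matches the order used above (filtered_logits / temperature).

import torch

def top_k_top_p_filtering(logits, top_k=50, top_p=0.9, filter_value=float('-inf')):
    # Assumed stand-in; the repository's own helper may differ in details.
    if top_k > 0:
        kth_best = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_best, filter_value)  # keep only the top_k logits
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the first token above the threshold is kept
        remove[..., 0] = False
        remove = remove.gather(-1, sorted_idx.argsort(-1))  # map the mask back to the original token order
        logits = logits.masked_fill(remove, filter_value)
    return logits

logits = torch.randn(2, 32000)  # [B, vocab] logits for the last position
filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
probs = torch.softmax(filtered / 0.5, dim=-1)  # temperature = 0.5, matching the defaults above
next_token_id = torch.multinomial(probs, num_samples=1)  # [B, 1]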

@classmethod
def from_pretrained(