Skip to content

Commit

Permalink
optimize tensor creation
Browse files Browse the repository at this point in the history
  • Loading branch information
hoesler authored and rlouf committed Jan 2, 2025
1 parent fddfc8f commit 2f0740e
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions outlines/processors/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,19 @@ def process_logits(

sequence_states.append(self._guide_states[curr_state_key])

mask = torch.ones_like(logits, dtype=torch.bool)

allowed_tokens_batch = []
batch_indices = []
for i, guide_state in enumerate(sequence_states):
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
mask.device
)
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
allowed_tokens_batch.append(allowed_tokens)
batch_indices.append(
torch.full_like(allowed_tokens, i)
) # Store batch index for each allowed token

allowed_tokens_concat = torch.cat(allowed_tokens_batch)
batch_indices_concat = torch.cat(batch_indices)
allowed_tokens_concat = torch.cat(allowed_tokens_batch).to(logits.device)
batch_indices_concat = torch.cat(batch_indices).to(logits.device)

mask = torch.ones_like(logits, dtype=torch.bool)
mask[batch_indices_concat, allowed_tokens_concat] = False
logits.masked_fill_(mask, float("-inf"))

Expand Down

0 comments on commit 2f0740e

Please sign in to comment.