Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ def selective_log_softmax(logits, index):
# More memory efficient by chunking on (bsz+qlen) dimension
# Exactly equivalent to the above
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
def chunked_selective_log_softmax(logits, index, temperature: float = 1.0):
# Split into 4 chunks only
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
all_per_token_logps = []
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
chunk_logits = chunk_logits.to(torch.float32)
if temperature != 1.0:
chunk_logits = chunk_logits / temperature
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
per_token_logps = selected_logits - logsumexp_values
Expand Down Expand Up @@ -940,6 +942,17 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32,

new_hidden_states_chunk = new_hidden_states_chunk[:, -(logits_to_keep + max_left_pad + 1): , :]
new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :]
logprobs_chunk = efficient_log_softmax(
new_hidden_states_chunk,
lm_head,
completion_ids,
chunks=input_ids_chunk.shape[0]*multiplier,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
logit_softcapping=logit_softcapping,
temperature=temperature,
batch_size = B
)
else:
new_hidden_states_chunk = unwrapped_model(
input_ids = input_ids_chunk,
Expand All @@ -952,18 +965,22 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32,
).logits

new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :]

logprobs_chunk = efficient_log_softmax(
new_hidden_states_chunk,
lm_head,
completion_ids,
chunks=input_ids_chunk.shape[0]*multiplier,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
logit_softcapping=logit_softcapping,
temperature=temperature,
batch_size = B
)
# Guard: check if model returned hidden states or logits
if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]:
logprobs_chunk = efficient_log_softmax(
new_hidden_states_chunk,
lm_head,
completion_ids,
chunks=input_ids_chunk.shape[0]*multiplier,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
logit_softcapping=logit_softcapping,
temperature=temperature,
batch_size = B
)
else:
# Model returned logits directly - scaling/softcapping already applied by model forward
logprobs_chunk = chunked_selective_log_softmax(new_hidden_states_chunk, completion_ids, temperature)
#This is needed to avoid race conditions with GPT OSS offload_embbed=True
#However, it seems that this line does not slow down or disrupt models.
device_synchronize()
Expand Down
8 changes: 8 additions & 0 deletions unsloth_zoo/vllm_lora_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ def set_active_adapters(self, requests: Set[Any],
if mapping is not None:
self._adapter_manager.set_adapter_mapping(mapping)

def supports_tower_connector_lora(self) -> bool:
    """Report whether the adapter manager supports tower/connector LoRA.

    Returns:
        ``True`` only when ``self._adapter_manager`` is present and both its
        ``supports_mm`` and ``supports_tower_connector_lora`` attributes are
        truthy. A missing manager, or a manager lacking either attribute,
        is treated as "not supported" and yields ``False``.
    """
    manager = getattr(self, '_adapter_manager', None)
    if manager is None:
        return False
    # `a and b` evaluates to one of its operands, not necessarily a bool;
    # coerce so the result always matches the declared `-> bool` return type.
    return bool(
        getattr(manager, 'supports_mm', False)
        and getattr(manager, 'supports_tower_connector_lora', False)
    )

def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
if apply_adapters_worker:
Expand Down
Loading