Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
1073 commits
Select commit Hold shift + click to select a range
beecad0
Update llama.py
danielhanchen Mar 22, 2025
e716e15
Update llama.py
danielhanchen Mar 22, 2025
62e4ae5
Update llama.py
danielhanchen Mar 22, 2025
1eba050
Update llama.py
danielhanchen Mar 22, 2025
a015c38
Update llama.py
danielhanchen Mar 22, 2025
0c995e8
Update llama.py
danielhanchen Mar 22, 2025
5ba0878
Update llama.py
danielhanchen Mar 22, 2025
cd5e195
Update llama.py
danielhanchen Mar 22, 2025
855695d
Update llama.py
danielhanchen Mar 22, 2025
4d7e3a1
Update vision.py
danielhanchen Mar 22, 2025
33194f1
HF Transfer
danielhanchen Mar 22, 2025
ef71732
fix(utils): add missing importlib import to fix NameError (#2134)
naliazheli Mar 22, 2025
1d7b570
Add QLoRA Train and Merge16bit Test (#2130)
jeromeku Mar 22, 2025
167b482
Update pyproject.toml
danielhanchen Mar 22, 2025
d28b929
Merge branch 'main' into nightly
danielhanchen Mar 22, 2025
2064656
Merge branch 'main' into nightly
danielhanchen Mar 22, 2025
3fdfff8
Update vision.py
danielhanchen Mar 26, 2025
172fe3c
Update vision.py
danielhanchen Mar 26, 2025
da6ad9f
Update vision.py
danielhanchen Mar 26, 2025
781887f
Update vision.py
danielhanchen Mar 26, 2025
fce9e82
Update loader.py
danielhanchen Mar 26, 2025
9ceabbe
Update loader.py
danielhanchen Mar 26, 2025
87dc533
Revert
danielhanchen Mar 26, 2025
cafd05e
Update vision.py
danielhanchen Mar 26, 2025
6ebcae0
Update vision.py
danielhanchen Mar 26, 2025
9f34d47
Update vision.py
danielhanchen Mar 26, 2025
26b0c83
Update vision.py
danielhanchen Mar 26, 2025
f9dd304
Update vision.py
danielhanchen Mar 26, 2025
10cfe62
Bug fix
danielhanchen Mar 26, 2025
bfa1b9f
Update mapper.py
danielhanchen Mar 26, 2025
b3c2975
check SDPA for Mistral 3, Pixtral
danielhanchen Mar 26, 2025
75ce106
Update vision.py
danielhanchen Mar 26, 2025
86c6060
Versioning
danielhanchen Mar 26, 2025
d4c0550
Update rl_replacements.py
danielhanchen Mar 26, 2025
c493054
Merge branch 'main' into nightly
danielhanchen Mar 26, 2025
0b2b903
Update README.md
jackswl Mar 26, 2025
f6dfa80
add model registry
jeromeku Mar 28, 2025
a5e7b3a
move hf hub utils to unsloth/utils
jeromeku Mar 28, 2025
dc8f34e
refactor global model info dicts to dataclasses
jeromeku Mar 30, 2025
7cd2763
fix dataclass init
jeromeku Mar 30, 2025
9899a72
fix llama registration
jeromeku Mar 30, 2025
310c598
remove deprecated key function
jeromeku Mar 30, 2025
e70d035
start registry reog
jeromeku Mar 30, 2025
de1fe25
add llama vision
jeromeku Mar 30, 2025
7e2207c
quant types -> Enum
jeromeku Mar 30, 2025
c3a1aff
remap literal quant types to QuantType Enum
jeromeku Mar 30, 2025
03de6df
add llama model registration
jeromeku Mar 30, 2025
fa95aa0
fix quant tag mapping
jeromeku Mar 30, 2025
fdafa78
add qwen2.5 models to registry
jeromeku Mar 31, 2025
6049310
add option to include original model in registry
jeromeku Mar 31, 2025
8dc3d66
handle quant types per model size
jeromeku Mar 31, 2025
1237075
separate registration of base and instruct llama3.2
jeromeku Mar 31, 2025
baab018
add QwenQVQ to registry
jeromeku Mar 31, 2025
6b08fc3
add gemma3 to registry
jeromeku Mar 31, 2025
44e227b
add phi
jeromeku Mar 31, 2025
d633179
add deepseek v3
jeromeku Mar 31, 2025
0755b45
add deepseek r1 base
jeromeku Mar 31, 2025
17358e6
add deepseek r1 zero
jeromeku Mar 31, 2025
975d263
add deepseek distill llama
jeromeku Mar 31, 2025
229ae10
add deepseek distill models
jeromeku Mar 31, 2025
6439e88
remove redundant code when constructing model names
jeromeku Mar 31, 2025
4e1df71
add mistral small to registry
jeromeku Mar 31, 2025
6d4ede4
rename model registration methods
jeromeku Apr 1, 2025
a774726
rename deepseek registration methods
jeromeku Apr 1, 2025
a2a4366
refactor naming for mistral and phi
jeromeku Apr 1, 2025
02fbb87
add global register models
jeromeku Apr 1, 2025
7fbde42
refactor model registration tests for new registry apis
jeromeku Apr 1, 2025
a2d3ad9
add model search method
jeromeku Apr 1, 2025
13a1126
remove deprecated registration api
jeromeku Apr 1, 2025
4840a32
add quant type test
jeromeku Apr 1, 2025
7d64639
add registry readme
jeromeku Apr 1, 2025
12b0d32
make llama registration more specific
jeromeku Apr 1, 2025
ea75001
clear registry when executing individual model registration file
jeromeku Apr 1, 2025
d854070
more registry readme updates
jeromeku Apr 1, 2025
2a4a274
Merge branch 'main' into nightly
danielhanchen Apr 1, 2025
03ab51d
Merge pull request #2255 from jeromeku/registry-refactor
shimmyshimmer Apr 2, 2025
0c95691
Merge pull request #2119 from jackswl/patch-1
shimmyshimmer Apr 3, 2025
0c1b3ff
Update _auto_install.py
danielhanchen Apr 5, 2025
d5e1880
Llama4
danielhanchen Apr 6, 2025
56f14e8
Merge branch 'main' into nightly
danielhanchen Apr 10, 2025
1d10f0e
Merge branch 'main' into nightly
danielhanchen Apr 20, 2025
cc2c02c
Merge branch 'main' into nightly
danielhanchen Apr 26, 2025
6dbf677
Merge branch 'main' into nightly
danielhanchen Apr 30, 2025
98177a0
Update synthetic.py
danielhanchen Apr 30, 2025
c217c75
Update synthetic.py
danielhanchen Apr 30, 2025
49d610e
Update synthetic.py
danielhanchen Apr 30, 2025
63698fc
Update synthetic.py
danielhanchen Apr 30, 2025
5b138c7
Update synthetic.py
danielhanchen Apr 30, 2025
c5d632a
Update synthetic.py
danielhanchen Apr 30, 2025
d25f93c
Update synthetic.py
danielhanchen Apr 30, 2025
ad45d26
Update synthetic.py
danielhanchen Apr 30, 2025
4874c72
Update synthetic.py
danielhanchen Apr 30, 2025
95f595a
Update synthetic.py
danielhanchen Apr 30, 2025
de0dbc6
Update synthetic.py
danielhanchen Apr 30, 2025
0ea2279
Synthetic data
danielhanchen Apr 30, 2025
3329c77
Merge branch 'main' into nightly
danielhanchen Apr 30, 2025
d1845c7
Update mapper.py
danielhanchen Apr 30, 2025
64d21f8
Xet and Synthetic
danielhanchen Apr 30, 2025
f522381
Update synthetic.py
danielhanchen Apr 30, 2025
9687fb3
Update loader.py
danielhanchen Apr 30, 2025
0d323a3
Update synthetic.py
danielhanchen Apr 30, 2025
c49d5ff
Update synthetic.py
danielhanchen Apr 30, 2025
c48079b
Update synthetic.py
danielhanchen Apr 30, 2025
1dd6034
Update synthetic.py
danielhanchen Apr 30, 2025
9ae987c
Update synthetic.py
danielhanchen Apr 30, 2025
ccf7065
Update synthetic.py
danielhanchen Apr 30, 2025
376cb9a
Update synthetic.py
danielhanchen Apr 30, 2025
9827a68
Update synthetic.py
danielhanchen Apr 30, 2025
9e6b59e
Update synthetic.py
danielhanchen Apr 30, 2025
fd9f3dc
Update synthetic.py
danielhanchen Apr 30, 2025
3f346e7
Update synthetic.py
danielhanchen Apr 30, 2025
74f42ba
Update synthetic.py
danielhanchen Apr 30, 2025
6dc3383
Update synthetic.py
danielhanchen Apr 30, 2025
7e3849f
Update synthetic.py
danielhanchen Apr 30, 2025
afcbb2c
Update synthetic.py
danielhanchen Apr 30, 2025
49b3343
Update synthetic.py
danielhanchen Apr 30, 2025
f3475b4
Update synthetic.py
danielhanchen Apr 30, 2025
7d5a8b3
Update synthetic.py
danielhanchen Apr 30, 2025
c50c039
Update synthetic.py
danielhanchen Apr 30, 2025
e85e987
Update synthetic.py
danielhanchen Apr 30, 2025
270f02f
Update synthetic.py
danielhanchen Apr 30, 2025
5a05158
Update synthetic.py
danielhanchen Apr 30, 2025
a536173
Update synthetic.py
danielhanchen Apr 30, 2025
90783f7
Update synthetic.py
danielhanchen Apr 30, 2025
eb37b78
Update synthetic.py
danielhanchen Apr 30, 2025
ecdd496
Update synthetic.py
danielhanchen Apr 30, 2025
050306f
Update synthetic.py
danielhanchen Apr 30, 2025
b7ac229
Update pyproject.toml
danielhanchen May 1, 2025
0ee8529
Delete .gitignore
danielhanchen May 1, 2025
bd36d00
Merge branch 'main' into nightly
danielhanchen May 1, 2025
be60490
Update synthetic.py
danielhanchen May 1, 2025
b645497
Update synthetic.py
danielhanchen May 1, 2025
3840255
Update synthetic.py
danielhanchen May 1, 2025
4c4f194
Update synthetic.py
danielhanchen May 1, 2025
5dd52bf
Update synthetic.py
danielhanchen May 1, 2025
8f28047
Update synthetic.py
danielhanchen May 1, 2025
791bfdd
Update synthetic.py
danielhanchen May 1, 2025
e319c96
Update synthetic.py
danielhanchen May 1, 2025
b4798c5
Update synthetic.py
danielhanchen May 1, 2025
b32e2f9
Update synthetic.py
danielhanchen May 1, 2025
1e7ca2f
Update synthetic.py
danielhanchen May 1, 2025
dbd3089
Update synthetic.py
danielhanchen May 1, 2025
6f2d524
Update synthetic.py
danielhanchen May 1, 2025
f8db408
Update synthetic.py
danielhanchen May 1, 2025
cd170e2
Update synthetic.py
danielhanchen May 1, 2025
c874a24
Update synthetic.py
danielhanchen May 1, 2025
152cde6
Update synthetic.py
danielhanchen May 1, 2025
984ca31
Update _utils.py
danielhanchen May 1, 2025
95bc443
Update pyproject.toml
danielhanchen May 1, 2025
d11d060
Update synthetic.py
danielhanchen May 1, 2025
4f3fe1b
Update synthetic.py
danielhanchen May 1, 2025
8154746
Merge branch 'main' into nightly
danielhanchen May 13, 2025
cb02396
Update synthetic.py
danielhanchen May 13, 2025
8ae377a
Update synthetic.py
danielhanchen May 13, 2025
0f7a8b8
Merge branch 'main' into nightly
danielhanchen May 15, 2025
50860ef
Merge branch 'main' into nightly
danielhanchen May 15, 2025
6304676
Update chat_templates.py
danielhanchen May 15, 2025
70c13c4
Seasame force float16 / float32
danielhanchen May 15, 2025
40d8b88
Fix Seasame
danielhanchen May 15, 2025
5684e86
Update loader.py
danielhanchen May 15, 2025
6b6521a
Update vision.py
danielhanchen May 15, 2025
8de07a1
Update vision.py
danielhanchen May 15, 2025
9a7bc91
Update vision.py
danielhanchen May 15, 2025
7502614
Update loader.py
danielhanchen May 15, 2025
636aa9b
is_multimodal
danielhanchen May 15, 2025
fcb3aa7
Update loader.py
danielhanchen May 15, 2025
3aa8a91
Update loader.py
danielhanchen May 15, 2025
c96d7b1
Update loader.py
danielhanchen May 15, 2025
45a85eb
Update loader.py
danielhanchen May 15, 2025
f8f4589
Update vision.py
danielhanchen May 15, 2025
a8c5b6f
Update vision.py
danielhanchen May 15, 2025
8a5b99d
Update vision.py
danielhanchen May 15, 2025
1bb1174
UNSLOTH_DISABLE_STATIC_GENERATION
danielhanchen May 15, 2025
ba6fd2f
Update vision.py
danielhanchen May 15, 2025
8b8ccff
Auto vision detection
danielhanchen May 15, 2025
c504076
Sesame
danielhanchen May 15, 2025
1b142f4
Whisper
danielhanchen May 15, 2025
1ba3128
Update loader.py
danielhanchen May 15, 2025
01f50b0
Update loader.py
danielhanchen May 15, 2025
a0df20a
Update loader.py
danielhanchen May 15, 2025
86b9155
Merge branch 'main' into nightly
danielhanchen May 17, 2025
81c46ec
Update mapper.py
danielhanchen May 17, 2025
65674db
Update vision.py
danielhanchen May 17, 2025
fafb278
Update vision.py
danielhanchen May 17, 2025
48cfac6
Update vision.py
danielhanchen May 17, 2025
424d329
Update vision.py
danielhanchen May 17, 2025
edb9e83
Update vision.py
danielhanchen May 17, 2025
b7fde1c
Update vision.py
danielhanchen May 17, 2025
7650061
Update loader.py
danielhanchen May 17, 2025
3df72b9
Update loader.py
danielhanchen May 17, 2025
04a19ab
Update loader.py
danielhanchen May 17, 2025
6a894cf
Update loader.py
danielhanchen May 17, 2025
caecfe3
Merge branch 'main' into nightly
danielhanchen May 17, 2025
4d28a74
Update _utils.py
danielhanchen May 17, 2025
c2f438d
Merge branch 'main' into nightly
danielhanchen May 22, 2025
b2a1966
Merge branch 'main' into nightly
danielhanchen May 28, 2025
a5e7ca3
Merge branch 'main' into nightly
danielhanchen May 28, 2025
6db6cc6
Update rl.py
danielhanchen May 28, 2025
6ac005e
versioning
danielhanchen May 28, 2025
332b35a
Update rl.py
danielhanchen May 28, 2025
f456e25
Update rl.py
danielhanchen May 28, 2025
74e65cd
Update rl.py
danielhanchen May 28, 2025
c2782a5
Update rl.py
danielhanchen May 28, 2025
dbf185a
Update rl.py
danielhanchen May 28, 2025
1295e09
logging
danielhanchen May 28, 2025
2798e76
Update pyproject.toml
danielhanchen May 28, 2025
b3e37bc
Update rl.py
danielhanchen May 28, 2025
ab03b61
Merge branch 'main' into nightly
danielhanchen May 28, 2025
07cc9b0
Merge branch 'main' into nightly
danielhanchen May 28, 2025
2ca86b4
versioning
danielhanchen May 28, 2025
c872792
Update rl.py
danielhanchen May 28, 2025
c42f136
Update rl.py
danielhanchen May 28, 2025
8e36979
Merge branch 'main' into nightly
danielhanchen May 29, 2025
ac82936
Merge branch 'main' into nightly
danielhanchen Jun 3, 2025
c688bdd
Merge branch 'main' into nightly
danielhanchen Jun 6, 2025
166f536
Merge branch 'main' into nightly
danielhanchen Jun 22, 2025
b81018c
Merge branch 'main' into nightly
danielhanchen Jun 22, 2025
4ab232d
Update rl_replacements.py
danielhanchen Jun 22, 2025
baf8ff4
Update rl_replacements.py
danielhanchen Jun 22, 2025
ee99269
Update rl.py
danielhanchen Jun 22, 2025
794682f
Update rl_replacements.py
danielhanchen Jun 22, 2025
590bd55
Update rl_replacements.py
danielhanchen Jun 22, 2025
060f442
logits / temperature
danielhanchen Jun 22, 2025
12fcb87
Update rl_replacements.py
danielhanchen Jun 22, 2025
b02203c
Update pyproject.toml
danielhanchen Jun 22, 2025
6d5c231
Update rl_replacements.py
danielhanchen Jun 22, 2025
d5509ce
Update rl_replacements.py
danielhanchen Jun 22, 2025
b04cde9
Merge branch 'main' into nightly
danielhanchen Jun 24, 2025
aa2c4a5
Merge branch 'main' of https://github.com/unslothai/unsloth into nightly
danielhanchen Jun 24, 2025
83284d2
Merge branch 'main' into nightly
danielhanchen Jun 24, 2025
b9888c4
Debugging only
danielhanchen Jun 25, 2025
09fed61
Update llama.py
danielhanchen Jun 25, 2025
6a0ac38
Update llama.py
danielhanchen Jun 25, 2025
caa0066
Merge branch 'main' into nightly
danielhanchen Jun 25, 2025
27bce12
Update rl_replacements.py
danielhanchen Jun 26, 2025
7c791d0
Update rl_replacements.py
danielhanchen Jun 26, 2025
de150dc
Update rl_replacements.py
danielhanchen Jun 26, 2025
e87d99c
Update rl_replacements.py
danielhanchen Jun 26, 2025
ae16859
Update rl_replacements.py
danielhanchen Jun 26, 2025
9bc76ee
Generic efficient GRPO
danielhanchen Jun 26, 2025
4843ed7
Update rl_replacements.py
danielhanchen Jun 26, 2025
1c9e4b3
Update rl_replacements.py
danielhanchen Jun 26, 2025
e13fd44
Remove debugging
danielhanchen Jun 26, 2025
22e31f7
Update rl_replacements.py
danielhanchen Jun 26, 2025
74bed43
Update rl_replacements.py
danielhanchen Jun 26, 2025
b756aed
Update vision.py
danielhanchen Jun 26, 2025
2efd4a7
Merge branch 'main' into nightly
danielhanchen Jun 26, 2025
940b7d1
Update llama.py
danielhanchen Jun 26, 2025
5fc2a12
Update rl_replacements.py
danielhanchen Jun 26, 2025
10b3b83
versioning
danielhanchen Jun 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ triton = [
]

huggingface = [
"unsloth_zoo>=2025.6.4",
"unsloth_zoo>=2025.6.5",
"packaging",
"tyro",
"transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2",
"transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3",
"datasets>=3.4.1",
"sentencepiece>=0.2.0",
"tqdm",
Expand Down Expand Up @@ -381,10 +381,10 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3",
]
colab-new = [
"unsloth_zoo>=2025.6.4",
"unsloth_zoo>=2025.6.5",
"packaging",
"tyro",
"transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2",
"transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3",
"datasets>=3.4.1",
"sentencepiece>=0.2.0",
"tqdm",
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2025.6.5"
__version__ = "2025.6.6"

__all__ = [
"SUPPORTS_BFLOAT16",
Expand Down
75 changes: 47 additions & 28 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ def _move_model_to_vllm(self, *args, **kwargs): return None

# Edit _get_per_token_logps to handle mixed precision
def grpo_trainer__get_per_token_logps(function_name, function):
if function_name != "_get_per_token_logps": return function
if function_name != "_get_per_token_logps": return function

def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, calc_logprob_flag = None):
if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag:
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
return None # Unsloth efficient GRPO
# Otherwise, calculate normally:
if not hasattr(self, '_autocast_dtype'):
Expand All @@ -260,9 +260,13 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
#logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return hidden_states
logits = model(
input_ids = input_ids,
attention_mask = attention_mask,
logits_to_keep = logits_to_keep + 1,
).logits
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return logits
# input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
Expand Down Expand Up @@ -331,19 +335,24 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
if "old_per_token_logps" in inputs.keys():
old_hidden_states = inputs["old_per_token_logps"]
else:
old_hidden_states = None

old_hidden_states = inputs.get("old_per_token_logps", None)
input_ids = input_ids[:, -logits_to_keep:]

# Get logit softcapping and logit scale
logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma
if logit_softcapping is None: logit_softcapping = 0
logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere
if logit_scale_multiply is None: logit_scale_multiply = 0
logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
if logit_scale_divide is None: logit_scale_divide = 0


if per_token_logps is not None:

if ref_per_token_logps is not None:
ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

loss, completion_length, mean_kl = grpo_compute_loss_slow(
ref_per_token_logps,
per_token_logps,
Expand All @@ -358,43 +367,53 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
max_completion_length = self.args.max_completion_length,
delta = self.args.delta,
temperature = self.args.temperature,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
)
else:
if hasattr(self.args, "loss_type"):
loss, completion_length, mean_kl = grpo_accumulated_loss(
self,
_input_ids,
logits_to_keep,
completion_mask,
advantages,
old_hidden_states,
trainer = self,
input_ids = _input_ids,
logits_to_keep = logits_to_keep,
completion_mask = completion_mask,
advantages = advantages,
old_hidden_states = old_hidden_states,
n_chunks = self.args.unsloth_num_chunks,
loss_type = self.args.loss_type,
epsilon_low = self.epsilon_low,
epsilon_high = self.epsilon_high,
max_completion_length = self.args.max_completion_length,
delta = self.args.delta,
temperature = self.args.temperature,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
attention_mask = attention_mask,
)
else:
# to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
loss, completion_length, mean_kl = grpo_accumulated_loss(
self,
_input_ids,
logits_to_keep,
completion_mask,
advantages,
old_hidden_states,
trainer = self,
input_ids = _input_ids,
logits_to_keep = logits_to_keep,
completion_mask = completion_mask,
advantages = advantages,
old_hidden_states = old_hidden_states,
n_chunks = self.args.unsloth_num_chunks,
temperature = self.args.temperature,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
attention_mask = attention_mask,
)

pass
pass
# Log the metrics
# completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()

# mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

if "train" in self._metrics:
mode = "eval" if self.control.should_evaluate else "train"
self._metrics[mode]["completion_length"].append(completion_length.item())
Expand Down
2 changes: 2 additions & 0 deletions unsloth/models/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,8 @@ def _for_inference(m):
embeddings = model.get_output_embeddings()
if hasattr(embeddings, "training"): embeddings.training = False
pass
# Must disable returning hidden states in the case for GRPO
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
return model
pass

Expand Down