Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
144 commits
Select commit Hold shift + click to select a range
c39f56f
Fix mamba
danielhanchen Aug 9, 2025
4bd35c5
Update loader.py
danielhanchen Aug 9, 2025
1f0a4c3
Update vision.py
danielhanchen Aug 9, 2025
3cb9719
Update loader.py
danielhanchen Aug 9, 2025
a4081af
Merge branch 'main' into nightly
danielhanchen Aug 13, 2025
1432eac
Filter vLLM standby logs (#3131)
Datta0 Aug 13, 2025
fd1124a
Update loader.py
danielhanchen Aug 13, 2025
b78189b
Add scaler
danielhanchen Aug 13, 2025
cd2e284
Update llama.py
danielhanchen Aug 13, 2025
5e976a5
Update _utils.py
danielhanchen Aug 13, 2025
f451adf
Versioning
danielhanchen Aug 13, 2025
dafc7b8
Merge branch 'main' into nightly
danielhanchen Aug 13, 2025
bf5c402
Merge branch 'main' into nightly
danielhanchen Aug 13, 2025
3b82c42
GPT OSS fix
danielhanchen Aug 14, 2025
61366ef
GPT OSS fix
danielhanchen Aug 14, 2025
de043d9
Update loader.py
danielhanchen Aug 14, 2025
c1ef6f1
Update vision.py
danielhanchen Aug 14, 2025
f18cd26
Update vision.py
danielhanchen Aug 14, 2025
0215224
Update loader.py
danielhanchen Aug 14, 2025
5ed4a46
Update vision.py
danielhanchen Aug 15, 2025
e2ebb99
Merge branch 'main' into nightly
danielhanchen Aug 15, 2025
a222558
Update vision.py
danielhanchen Aug 15, 2025
cdcfe7d
Merge branch 'main' into nightly
danielhanchen Aug 15, 2025
6cffb1c
Update llama.py
danielhanchen Aug 15, 2025
15d33a5
Update llama.py
danielhanchen Aug 15, 2025
95a4daf
Update llama.py
danielhanchen Aug 15, 2025
4104bba
Versioning
danielhanchen Aug 15, 2025
8cc1999
Update mapper.py
danielhanchen Aug 15, 2025
a5dffd7
Merge branch 'main' into nightly
danielhanchen Aug 16, 2025
ffda8a7
Update vision.py
danielhanchen Aug 16, 2025
cdf2e17
Update vision.py
danielhanchen Aug 16, 2025
941d1ae
Update vision.py
danielhanchen Aug 16, 2025
73fa72c
Upcast norms
danielhanchen Aug 16, 2025
e4bbeef
Update loader.py
danielhanchen Aug 16, 2025
c8d00be
Update vision.py
danielhanchen Aug 16, 2025
564b6f8
Upcast layernorms
danielhanchen Aug 17, 2025
b8a34b4
Update llama.py
danielhanchen Aug 17, 2025
509fcb5
Update llama.py
danielhanchen Aug 17, 2025
27f1a2e
Update llama.py
danielhanchen Aug 18, 2025
931851a
Update llama.py
danielhanchen Aug 18, 2025
3b9057b
Update llama.py
danielhanchen Aug 18, 2025
3dd87bb
Update llama.py
danielhanchen Aug 18, 2025
f3f2b51
Merge branch 'main' into nightly
danielhanchen Aug 18, 2025
b757faf
Update save.py
danielhanchen Aug 18, 2025
2e86333
Update rl.py
danielhanchen Aug 18, 2025
b01e948
Update pyproject.toml
danielhanchen Aug 18, 2025
b064255
Merge branch 'main' into nightly
danielhanchen Aug 18, 2025
a751fd7
Update rl.py
danielhanchen Aug 18, 2025
c5d22e1
Merge branch 'main' into nightly
danielhanchen Aug 18, 2025
3cb6eaf
Update rl_replacements.py
danielhanchen Aug 18, 2025
de77a26
Update rl.py
danielhanchen Aug 19, 2025
27ca531
Update rl.py
danielhanchen Aug 19, 2025
6514c8e
Update rl.py
danielhanchen Aug 19, 2025
3e29ae7
Update _utils.py
danielhanchen Aug 19, 2025
a42f624
Update __init__.py
danielhanchen Aug 19, 2025
9437f9e
Torch 2.8
danielhanchen Aug 19, 2025
1dd99a2
Update rl_replacements.py
danielhanchen Aug 19, 2025
5d5ece0
Merge branch 'main' into nightly
danielhanchen Aug 19, 2025
ecd8d38
Merge branch 'main' into nightly
danielhanchen Aug 19, 2025
89b5603
Merge branch 'main' into nightly
danielhanchen Aug 19, 2025
fa68976
Merge branch 'main' into nightly
danielhanchen Aug 20, 2025
5349cd0
Update loader.py
danielhanchen Aug 20, 2025
5a344c2
UNSLOTH_ENABLE_CCE
danielhanchen Aug 20, 2025
e56363c
Fix
danielhanchen Aug 20, 2025
c79aece
Update loader.py
danielhanchen Aug 20, 2025
c4b530c
Update loader.py
danielhanchen Aug 20, 2025
0913b58
Update __init__.py
danielhanchen Aug 20, 2025
374f703
Update __init__.py
danielhanchen Aug 20, 2025
c0efbec
Update __init__.py
danielhanchen Aug 20, 2025
761a445
Update __init__.py
danielhanchen Aug 20, 2025
30ea44c
Import fixes
danielhanchen Aug 20, 2025
c45467c
Update loader.py
danielhanchen Aug 20, 2025
55e4c78
Fix aimv2 issue
danielhanchen Aug 20, 2025
a160e42
Update loader.py
danielhanchen Aug 20, 2025
675c4ef
Update import_fixes.py
danielhanchen Aug 20, 2025
a99d6b2
Update import_fixes.py
danielhanchen Aug 20, 2025
7e82623
Update loader.py
danielhanchen Aug 20, 2025
0e678d6
Update loader.py
danielhanchen Aug 20, 2025
9b82317
Update loader.py
danielhanchen Aug 20, 2025
8a76fd3
Upgrade
danielhanchen Aug 20, 2025
94bcb28
Update loader.py
danielhanchen Aug 20, 2025
7d7a115
Update loader.py
danielhanchen Aug 20, 2025
031f5e1
Update loader.py
danielhanchen Aug 20, 2025
98bee64
Update loader.py
danielhanchen Aug 20, 2025
21fa9fd
Merge branch 'main' into nightly
danielhanchen Aug 20, 2025
2ba9008
Update vision.py
danielhanchen Aug 21, 2025
ea435e6
Update vision.py
danielhanchen Aug 21, 2025
5bebfa9
custom_datatype
danielhanchen Aug 21, 2025
356789a
recheck
danielhanchen Aug 21, 2025
d0f97a9
Float16
danielhanchen Aug 21, 2025
d83767f
Update vision.py
danielhanchen Aug 21, 2025
5b575d8
Update vision.py
danielhanchen Aug 21, 2025
66eee4d
Update vision.py
danielhanchen Aug 21, 2025
27d044e
Update vision.py
danielhanchen Aug 21, 2025
34d07d8
Update vision.py
danielhanchen Aug 21, 2025
3ad7561
Update loader.py
danielhanchen Aug 21, 2025
b757297
Update loader.py
danielhanchen Aug 21, 2025
ceeca86
Update loader.py
danielhanchen Aug 21, 2025
87758b9
Update loader.py
danielhanchen Aug 21, 2025
97d34d4
Update loader.py
danielhanchen Aug 21, 2025
43bf41f
Update loader.py
danielhanchen Aug 21, 2025
6e7ad52
Update loader.py
danielhanchen Aug 21, 2025
d605aa7
Update loader.py
danielhanchen Aug 21, 2025
f417dc8
Update loader.py
danielhanchen Aug 21, 2025
05fe3d1
Update loader.py
danielhanchen Aug 21, 2025
a79d6f6
Update loader.py
danielhanchen Aug 21, 2025
59702c4
Update loader.py
danielhanchen Aug 21, 2025
1b66aee
Update loader.py
danielhanchen Aug 21, 2025
a71fa05
Update loader.py
danielhanchen Aug 21, 2025
d3e8625
Update loader.py
danielhanchen Aug 21, 2025
fb112cf
Update loader.py
danielhanchen Aug 21, 2025
5dbdcc5
Update loader.py
danielhanchen Aug 21, 2025
fdaa007
Update loader.py
danielhanchen Aug 21, 2025
ba0eb04
Bug fix
danielhanchen Aug 21, 2025
3f98262
Update loader.py
danielhanchen Aug 21, 2025
3e6511b
Update loader.py
danielhanchen Aug 21, 2025
c9e7537
Update loader.py
danielhanchen Aug 21, 2025
2e38e8a
Update loader.py
danielhanchen Aug 22, 2025
8b3a8ba
Update loader.py
danielhanchen Aug 22, 2025
f706d20
torch_dtype
danielhanchen Aug 22, 2025
bf863a8
Merge branch 'main' into nightly
danielhanchen Aug 28, 2025
84ca61f
Merge branch 'main' into nightly
danielhanchen Aug 30, 2025
e82fd70
Merge branch 'main' into nightly
danielhanchen Sep 4, 2025
c61a21d
Merge branch 'main' into nightly
danielhanchen Sep 4, 2025
b56cc1b
Update rl.py
danielhanchen Sep 4, 2025
c47f936
Fix CE Loss
danielhanchen Sep 4, 2025
6093c4c
Merge branch 'main' into nightly
danielhanchen Sep 4, 2025
0b896c5
Versioning
danielhanchen Sep 4, 2025
327f517
Merge branch 'main' into nightly
danielhanchen Sep 4, 2025
5b0c47a
Merge branch 'main' into nightly
danielhanchen Sep 8, 2025
de5c3b5
Merge branch 'main' into nightly
danielhanchen Sep 9, 2025
7234a62
Update loader.py
danielhanchen Sep 9, 2025
68c1aba
Update loader.py
danielhanchen Sep 9, 2025
d07b819
Merge branch 'main' into nightly
danielhanchen Sep 9, 2025
05fc2f2
extract_model_type_from_config
danielhanchen Sep 9, 2025
99c7afb
Model types
danielhanchen Sep 10, 2025
fc5d91d
Update loader.py
danielhanchen Sep 10, 2025
702a9ea
get_transformers_model_type
danielhanchen Sep 10, 2025
8ece4a6
Update loader.py
danielhanchen Sep 10, 2025
f3ac0e3
Update loader.py
danielhanchen Sep 10, 2025
d2b0d41
Update loader.py
danielhanchen Sep 10, 2025
e5920fe
Update rl.py
danielhanchen Sep 10, 2025
bf0367e
Update pyproject.toml
danielhanchen Sep 10, 2025
d2c2cc1
Update loader.py
danielhanchen Sep 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ triton = [
]

huggingface = [
"unsloth_zoo>=2025.9.3",
"unsloth_zoo>=2025.9.4",
"packaging",
"tyro",
"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1",
Expand Down Expand Up @@ -453,7 +453,7 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3",
]
colab-new = [
"unsloth_zoo>=2025.9.3",
"unsloth_zoo>=2025.9.4",
"packaging",
"tyro",
"transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1",
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2025.9.2"
__version__ = "2025.9.3"

__all__ = [
"SUPPORTS_BFLOAT16",
Expand Down
132 changes: 66 additions & 66 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
global FORCE_FLOAT32
FORCE_FLOAT32 = [
"gemma3",
"gemma3n",
"gpt_oss",
]

Expand Down Expand Up @@ -177,6 +178,8 @@ def from_pretrained(

autoconfig_error = None
peft_error = None
model_config = None
peft_config = None
try:
model_config = AutoConfig.from_pretrained(
model_name,
Expand All @@ -200,8 +203,12 @@ def from_pretrained(
peft_error = str(error)
is_peft = False
pass

# Both config.json and adapter_config.json should not exist!
model_types = get_transformers_model_type(model_config or peft_config)
if len(model_types) == 1:
model_type = model_types[0]
else:
# Leave as tuple if more than one arch
model_type = model_types

# Old transformers versions check
both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
Expand Down Expand Up @@ -266,8 +273,6 @@ def from_pretrained(

if not was_disabled: enable_progress_bars()

model_type = model_config.model_type

if model_type == "llama":
scaling_type = None
if getattr(model_config, "rope_scaling", None) is not None:
Expand Down Expand Up @@ -493,10 +498,11 @@ def from_pretrained(
from transformers import AutoModelForVision2Seq
pass

# Must be alphabetically sorted for each entry
DISABLE_COMPILE_MODEL_NAMES = [
"aya-vision",
"aya_vision",
"modernbert",
"granite-vision",
"granite,llava_next", # Granite-vision 3
]


Expand Down Expand Up @@ -573,20 +579,61 @@ def from_pretrained(
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)

# Check modelscope
if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass

# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
disable_progress_bars()

autoconfig_error = None
peft_error = None
model_config = None
peft_config = None
try:
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_model = True
except Exception as error:
autoconfig_error = str(error)
is_model = False
try:
peft_config = PeftConfig.from_pretrained(
model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_peft = True
except Exception as error:
peft_error = str(error)
is_peft = False
pass
model_types = get_transformers_model_type(model_config or peft_config)
model_types_all = ",".join(model_types)

# Check versions
lowered_model_name = model_name.lower()
os.environ["UNSLOTH_MODEL_NAME"] = lowered_model_name
LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`'
NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`'
# Pixtral
if "pixtral" in lowered_model_name and transformers_version < Version("4.49.0"):
if "pixtral" in model_types_all and transformers_version < Version("4.49.0"):
raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST)
# Qwen 2.5
elif "qwen2.5" in lowered_model_name and transformers_version < Version("4.49.0"):
elif "qwen2_5" in model_types_all and transformers_version < Version("4.49.0"):
raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST)
# Gemma 3
elif "gemma-3" in lowered_model_name:
if "gemma-3n" in lowered_model_name:
elif "gemma3" in model_types_all:
if "gemma3n" in model_types_all:
if transformers_version < Version("4.53.0"):
raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST)
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
Expand All @@ -604,33 +651,33 @@ def from_pretrained(
# common in both gemma-3 and gemma-3n
os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
# Cohere
elif "c4ai-command-a-03-2025" in lowered_model_name and transformers_version < Version("4.50.0.dev0"):
elif "cohere2" in model_types_all and transformers_version < Version("4.50.0.dev0"):
raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY)
# Sesame
elif "csm-1b" in lowered_model_name:
elif "csm" in model_types_all:
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" # Inference is too slow
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails
os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \
"all;torch.float32;torch.float16;"\
"if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"\
";"
# Granite 4
elif 'granite-4' in lowered_model_name:
elif 'granitemoehybrid' in model_types_all:
# Granite-4 rms norms are stored as 16 bit, but we upcast
os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
# Olmo 2
elif "olmo-2" in lowered_model_name and transformers_version < Version("4.50.0.dev0"):
elif "olmo2" in model_types_all and transformers_version < Version("4.50.0.dev0"):
raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY)
elif "falcon-h1" in lowered_model_name:
elif "falcon_h1" in model_types_all:
# Falcon must use float32 Triton ie TRITON_F32_DEFAULT = 'ieee'
# since Mamba kernels error out on using lower precision
os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \
"float16;torch.float32;torch.float16;"\
"if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)"\
";"\
"os.environ['TRITON_F32_DEFAULT'] = 'ieee'"
elif "gpt-oss" in lowered_model_name:
elif "gpt_oss" in model_types_all:
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
if not load_in_4bit:
# Only upcast MoE biases for MXFP4, not BnB
Expand Down Expand Up @@ -675,44 +722,6 @@ def from_pretrained(
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
pass

if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass

# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
disable_progress_bars()

autoconfig_error = None
peft_error = None
try:
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_model = True
except Exception as error:
autoconfig_error = str(error)
is_model = False
try:
peft_config = PeftConfig.from_pretrained(
model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_peft = True
except Exception as error:
peft_error = str(error)
is_peft = False
pass

# Both config.json and adapter_config.json should not exist!

# Old transformers versions check
both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32

Expand Down Expand Up @@ -782,24 +791,16 @@ def from_pretrained(
else:
redirector = contextlib.redirect_stdout(open(os.devnull, "w"))

# Get model types like Gemma3 etc
model_types = get_transformers_model_type(
model_name = model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
model_types = ["siglip"] + model_types

# Set forced float32 env flag
os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
do_forced_float32 = False
for model_type_arch in model_types:
if model_type_arch != "siglip": break
global FORCE_FLOAT32
for disable_name in FORCE_FLOAT32:
if (disable_name.lower() == model_type_arch.lower().replace("-", "_") or \
disable_name.lower() in model_name.lower()) and \
if (disable_name.lower() == model_type_arch.lower().replace("-", "").replace("_", "") or \
disable_name.lower() in model_types_all) and \
((dtype == torch.float16) or not SUPPORTS_BFLOAT16):
os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
dtype = torch.bfloat16 # Change to bfloat16 loading
Expand All @@ -808,7 +809,6 @@ def from_pretrained(
# Patch gradient checkpointing
if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)

with redirector:
patch_loss_functions(torch_compile = False)
model_types, supports_sdpa = unsloth_compile_transformers(
Expand Down Expand Up @@ -845,7 +845,7 @@ def from_pretrained(
)
pass
# Fix SDPA
if "gemma-3n" in lowered_model_name:
if "gemma3n" in model_types_all:
supports_sdpa = False
pass

Expand Down
42 changes: 22 additions & 20 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
}

from trl import __version__ as trl_version
from unsloth_zoo.utils import Version
trl_version = Version(trl_version)

def vLLMSamplingParams(**kwargs):
from vllm import SamplingParams
Expand Down Expand Up @@ -804,7 +806,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
" " * 12 + "if (getattr(args, 'use_vllm', False) == False):\n" + \
" " * 16 + "args.use_vllm = True\n"

if "grpo" in trainer_file and trl_version >= "0.18":
if "grpo" in trainer_file and trl_version >= Version("0.18.0"):
            # If model has vllm_engine, then use vLLM in colocate mode. Do not wait for the server.
vllm_setter += \
" " * 12 + "args.vllm_mode='colocate'\n"
Expand Down Expand Up @@ -850,26 +852,27 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
sampling_params # Add spaces

# count the indentation of last line of sampling_params.
last_line = sampling_params.split("\n")[-1]
last_prev_line = sampling_params.split("\n")[-2]
last_prev_indentation = len(last_prev_line) - len(last_prev_line.lstrip())
last_indentation = len(last_line) - len(last_line.lstrip())


# Add extra arguments to SamplingParams
extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})"
# Backwards replace
to_replace = ",\n" + " "*last_prev_indentation + extra + ",\n" + " "*last_indentation + ")"
sampling_params = to_replace.join(sampling_params.rsplit(")", 1))
# Strip multiple commas
sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params)

new_vllm_part = \
f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\
f"\n{' '*8}else:\n"
splitted_sampling_params = sampling_params.split("\n")
if len(splitted_sampling_params) >= 2:
last_line = splitted_sampling_params[-1]
last_prev_line = splitted_sampling_params[-2]
last_prev_indentation = len(last_prev_line) - len(last_prev_line.lstrip())
last_indentation = len(last_line) - len(last_line.lstrip())

# Add extra arguments to SamplingParams
extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})"
# Backwards replace
to_replace = ",\n" + " "*last_prev_indentation + extra + ",\n" + " "*last_indentation + ")"
sampling_params = to_replace.join(sampling_params.rsplit(")", 1))
# Strip multiple commas
sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params)

new_vllm_part = \
f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\
f"\n{' '*8}else:\n"
pass

if trl_version >= "0.18":
if trl_version >= Version("0.18.0"):
# Replace LLM init with already existing vLLM engine for colocate mode
vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\(.*?\)*\)\s*?\n(?!,)"
vllm_llm_replacement = "self.llm = model.vllm_engine\n"
Expand All @@ -881,7 +884,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
)

init = init.replace(vllm_part, new_vllm_part)

pass

# Search for vLLM calling in all child functions
Expand Down