Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion unsloth_zoo/empty_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def create_empty_causal_lm(config, dtype = torch.float16):
# Suppress warning on uninited weights
old_warn = os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1")
os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"
model_name = getattr(config, 'model_name')
model_name = getattr(config, 'model_name', None)
kwargs = {"torch_dtype" if HAS_TORCH_DTYPE else "dtype" : dtype_from_config(config)}
original_meta_model = None
error = None
Expand Down
3 changes: 1 addition & 2 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,8 @@ def grpo_accumulated_loss(
image_grid_thw = kwargs.get('image_grid_thw',None)
pixel_attention_mask = kwargs.get('pixel_attention_mask',None)
image_sizes = kwargs.get('image_sizes',None)
sampling_per_token_logps = kwargs.get("sampling_per_token_logps", None)
#delete this from kwargs so less issues

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The refactoring to use kwargs.pop is a great improvement for both conciseness and correctness. The comment on this line is now redundant and can be removed as the code is self-explanatory.

del kwargs["sampling_per_token_logps"]
sampling_per_token_logps = kwargs.pop("sampling_per_token_logps", None)
kwargs["vllm_importance_sampling_cap"] = trainer.vllm_importance_sampling_cap if sampling_per_token_logps is not None else None
kwargs["use_vllm"] = trainer.use_vllm
# Find closest multiple
Expand Down
18 changes: 17 additions & 1 deletion unsloth_zoo/vllm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,8 +818,24 @@ def vllm_dynamic_quant_supported(
pass


@torch.inference_mode
def get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision_model = False):
# If the vllm state dict was quantized using torchao, we will run into
# the following error when calling ops like aten.t() in inference mode.
# This is a bug in PyTorch that affects all tensor subclasses.
#
# Cannot set version_counter for inference tensor
#
# For now, we work around this issue by using torch.no_grad in this case.
# See https://github.com/pytorch/pytorch/issues/164872 for more details
if get_quant_type(config) == "torchao":
ctx_manager = torch.no_grad()
else:
ctx_manager = torch.inference_mode()
with ctx_manager:
return _get_vllm_state_dict(llm, return_state_dict, config, is_vision_model)


def _get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision_model = False):
# All Unsloth Zoo code licensed under LGPLv3
# Unmerges vLLM modules and returns HF equivalent state_dict
# vllm_state_dict = {}
Expand Down