Skip to content

Patch decoder layers to have field for device#198

Merged
danielhanchen merged 3 commits into
unslothai:mainfrom
Datta0:multigpu_inputs
Jul 10, 2025
Merged

Patch decoder layers to have field for device#198
danielhanchen merged 3 commits into
unslothai:mainfrom
Datta0:multigpu_inputs

Conversation

@Datta0
Copy link
Copy Markdown
Collaborator

@Datta0 Datta0 commented Jul 10, 2025

To patch PP, we need have info of which layer is on what device.

Needed for: unslothai/unsloth#2919

Comment thread unsloth_zoo/patching_utils.py Outdated
"""
Verify that all parameters of a module are on the same device.
"""
set_of_devices = set([x.device for x in module.parameters()])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be a generator set() and not set([])

def get_model(model):
found_layers = False
x = model
while True:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we instead find torch.nn.ModuleList or something with layers as the name maybe via model.named_modules() maybe

@danielhanchen danielhanchen merged commit e635503 into unslothai:main Jul 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants