Fix: Move RoPE tensors to right devices #2862

Closed

Datta0 wants to merge 8 commits into unslothai:main from Datta0:multigpu_inference

Conversation

@Datta0 (Collaborator) commented Jul 2, 2025

Multi GPU inference fails as some tensors are explicitly set to GPU 0 and sometimes RoPE values end up on the wrong GPU. This PR fixes that.

@Datta0 Datta0 requested a review from danielhanchen July 2, 2025 15:45
@danielhanchen (Contributor) commented:

Ok, your solution is fine, but moving to CPU will make things slower. We might have to replicate cos/sin on each GPU and make a tuple indexer.

@danielhanchen (Contributor) commented Jul 3, 2025

I.e.:

```python
cos = coses[X.device.index]
sin = sines[X.device.index]
```

or something.
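The per-device replication plus tuple-indexer idea sketched above could look roughly like the following. This is a minimal illustration, not the PR's code: `replicate_rope` and `rope_for` are hypothetical helper names, and since `device.index` is `None` for CPU tensors, the sketch maps that case to slot 0.

```python
import torch

def replicate_rope(cos, sin, devices):
    # Replicate cos/sin onto every device once, at setup time,
    # so no cross-device copy is needed inside the forward pass.
    coses = tuple(cos.to(d) for d in devices)
    sines = tuple(sin.to(d) for d in devices)
    return coses, sines

def rope_for(coses, sines, X):
    # Pick the replica living on X's device. device.index is None
    # for CPU tensors, so fall back to slot 0 in that case.
    idx = X.device.index if X.device.index is not None else 0
    return coses[idx], sines[idx]
```

The memory cost is one copy of cos/sin per GPU, which is exactly the trade-off Datta0 raises below.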

@Datta0 (Collaborator, Author) commented Jul 3, 2025

> Ok your solution is fine, but moving to CPU will make things slower. We might have to replicate cos / sin on each GPU and make a tuple indexer

I'm only moving position_ids to CPU because they are indices and can stay on CPU. We can move them to the right GPU as well.

And for cos/sin, we're moving them directly between GPUs; there are no explicit CPU round-trips on our side.

The reason I didn't replicate them across GPUs was to keep memory usage lean.

@danielhanchen (Contributor) commented:

I think it's best to use a tuple and not a dict.

```python
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
    decoder_device = decoder_layer.self_attn.q_proj.weight.device
    hidden_states, out_weight, position_ids = move_to_device(
```
@danielhanchen (Contributor) commented:

Wait, we shouldn't need to move these tensors, right?

@Datta0 (Collaborator, Author) commented:

We're mostly using pipeline parallelism (PP), which means some layers are on GPU 0 and some on GPU 1. The inputs and/or these tensors don't seem to move to the second GPU automatically; it has to be done explicitly.
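The pipeline-parallel pattern described here can be sketched as follows; `pp_forward` is a hypothetical illustration of the idea, not the PR's actual forward loop:

```python
import torch
import torch.nn as nn

def pp_forward(layers, hidden_states):
    # Under pipeline parallelism each layer may live on a different
    # device, so inputs are moved explicitly before each layer call;
    # PyTorch will not relocate them automatically.
    for layer in layers:
        layer_device = next(layer.parameters()).device
        if hidden_states.device != layer_device:
            hidden_states = hidden_states.to(layer_device)
        hidden_states = layer(hidden_states)
    return hidden_states
```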

@Datta0 (Collaborator, Author) commented:

Also, we do have a check inside move_to_device to ensure that we aren't moving tensors unnecessarily.
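Such a guard could be sketched as below, assuming a helper shaped like the `move_to_device` call in the diff excerpt above; the real signature and behavior in the PR may differ:

```python
import torch

def move_to_device(device, *values):
    # Hypothetical helper: only copy tensors that are on a different
    # device; everything else (None position_ids, tensors already in
    # place) passes through untouched, so no-op calls stay cheap.
    out = []
    for v in values:
        if isinstance(v, torch.Tensor) and v.device != device:
            v = v.to(device)
        out.append(v)
    return tuple(out)
```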

```python
    This means we can pass in a row of Q, but we need to
    remember K and V, which are called the KV cache.
    """
    if position_ids is not None:
```
@danielhanchen (Contributor) commented:

Here

@Datta0 force-pushed the multigpu_inference branch from 606c451 to 324b392 on July 10, 2025 05:18
@Datta0 changed the title from "Fix: Move tensors to right devices" to "Fix: Move RoPE tensors to right devices" on Jul 10, 2025
@Datta0 (Collaborator, Author) commented Jul 11, 2025

Closing this as this is handled in #2919

@Datta0 closed this on Jul 11, 2025
@Datta0 deleted the multigpu_inference branch on July 26, 2025 04:51
2 participants