diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index cf543ae094..ce635bed21 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -300,16 +300,16 @@ def __init__( for device in range(DEVICE_COUNT): self._set_cos_sin_cache( seq_len = self.current_rope_size, - device = torch.device(device), + device = torch.device(DEVICE_TYPE_TORCH, device), dtype = torch.get_default_dtype(), ) # dummy so that patch_utils doesn't fail for now self.cos_cached = torch.empty( - 1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) self.sin_cached = torch.empty( - 1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -350,7 +350,7 @@ def forward(self, x, position_ids = None, seq_len = None): def get_cached(self, seq_len = None, device_index = None): if device_index is None: - device_index = torch.cuda.current_device() + device_index = get_current_device() return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[ device_index ] @@ -362,7 +362,9 @@ def extend_rope_embedding(self, x, seq_len): self.current_rope_size = math.ceil(seq_len / 8192) * 8192 for device in range(DEVICE_COUNT): self._set_cos_sin_cache( - self.current_rope_size, device = torch.device(device), dtype = x.dtype + self.current_rope_size, + device = torch.device(DEVICE_TYPE_TORCH, device), + dtype = x.dtype, ) @@ -489,5 +491,5 @@ def post_patch(model, tokenizer, correct_dtype = None): for _ in range(3): gc.collect() - torch.cuda.empty_cache() + clean_gpu_cache() return model, tokenizer