Skip to content

Commit 28904ed

Browse files
woshiyyyaJimmy
authored andcommitted
[DDPMScheduler] Load alpha_cumprod to device to avoid redundant data movement. (huggingface#6704)
* load cumprod tensor to device Signed-off-by: woshiyyya <[email protected]> * fixing ci Signed-off-by: woshiyyya <[email protected]> * make fix-copies Signed-off-by: woshiyyya <[email protected]> --------- Signed-off-by: woshiyyya <[email protected]>
1 parent fc36e14 commit 28904ed

File tree

8 files changed

+42
-13
lines changed

8 files changed

+42
-13
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,10 @@ def add_noise(
477477
timesteps: torch.IntTensor,
478478
) -> torch.FloatTensor:
479479
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
480-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
480+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
481+
# for the subsequent add_noise calls
482+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
483+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
481484
timesteps = timesteps.to(original_samples.device)
482485

483486
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -498,7 +501,8 @@ def get_velocity(
498501
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
499502
) -> torch.FloatTensor:
500503
# Make sure alphas_cumprod and timestep have same device and dtype as sample
501-
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
504+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
505+
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
502506
timesteps = timesteps.to(sample.device)
503507

504508
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,10 @@ def add_noise(
602602
timesteps: torch.IntTensor,
603603
) -> torch.FloatTensor:
604604
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
605-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
605+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
606+
# for the subsequent add_noise calls
607+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
608+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
606609
timesteps = timesteps.to(original_samples.device)
607610

608611
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -623,7 +626,8 @@ def get_velocity(
623626
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
624627
) -> torch.FloatTensor:
625628
# Make sure alphas_cumprod and timestep have same device and dtype as sample
626-
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
629+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
630+
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
627631
timesteps = timesteps.to(sample.device)
628632

629633
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,10 @@ def add_noise(
503503
timesteps: torch.IntTensor,
504504
) -> torch.FloatTensor:
505505
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
506-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
506+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
507+
# for the subsequent add_noise calls
508+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
509+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
507510
timesteps = timesteps.to(original_samples.device)
508511

509512
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -523,7 +526,8 @@ def get_velocity(
523526
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
524527
) -> torch.FloatTensor:
525528
# Make sure alphas_cumprod and timestep have same device and dtype as sample
526-
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
529+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
530+
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
527531
timesteps = timesteps.to(sample.device)
528532

529533
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,10 @@ def add_noise(
594594
timesteps: torch.IntTensor,
595595
) -> torch.FloatTensor:
596596
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
597-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
597+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
598+
# for the subsequent add_noise calls
599+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
600+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
598601
timesteps = timesteps.to(original_samples.device)
599602

600603
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -615,7 +618,8 @@ def get_velocity(
615618
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
616619
) -> torch.FloatTensor:
617620
# Make sure alphas_cumprod and timestep have same device and dtype as sample
618-
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
621+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
622+
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
619623
timesteps = timesteps.to(sample.device)
620624

621625
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_lcm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,10 @@ def add_noise(
575575
timesteps: torch.IntTensor,
576576
) -> torch.FloatTensor:
577577
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
578-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
578+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
579+
# for the subsequent add_noise calls
580+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
581+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
579582
timesteps = timesteps.to(original_samples.device)
580583

581584
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -596,7 +599,8 @@ def get_velocity(
596599
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
597600
) -> torch.FloatTensor:
598601
# Make sure alphas_cumprod and timestep have same device and dtype as sample
599-
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
602+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
603+
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
600604
timesteps = timesteps.to(sample.device)
601605

602606
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,10 @@ def add_noise(
455455
timesteps: torch.IntTensor,
456456
) -> torch.FloatTensor:
457457
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
458-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
458+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
459+
# for the subsequent add_noise calls
460+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
461+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
459462
timesteps = timesteps.to(original_samples.device)
460463

461464
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_sasolver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,10 @@ def add_noise(
10691069
timesteps: torch.IntTensor,
10701070
) -> torch.FloatTensor:
10711071
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
1072-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
1072+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
1073+
# for the subsequent add_noise calls
1074+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
1075+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
10731076
timesteps = timesteps.to(original_samples.device)
10741077

10751078
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_unclip.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,10 @@ def add_noise(
332332
timesteps: torch.IntTensor,
333333
) -> torch.FloatTensor:
334334
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
335-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
335+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
336+
# for the subsequent add_noise calls
337+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
338+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
336339
timesteps = timesteps.to(original_samples.device)
337340

338341
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

0 commit comments

Comments
 (0)