
Commit 74f278d

patrickvonplaten authored and Jimmy committed
[Sigmas] Keep sigmas on CPU (huggingface#6173)
* correct
* Apply suggestions from code review
* make style
1 parent b4c0d52 commit 74f278d

13 files changed, +26 -0 lines changed
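
Every file below receives the same two-line change: right after the schedule is built in `__init__`, and again at the end of `set_timesteps`, the sigma tensor is moved to the CPU. A minimal sketch of why this helps, not taken from the diff; the tensors, sizes, and step index below are illustrative, and the GPU branch only runs when a CUDA device is available.

# Illustrative sketch (not from the commit): per-step lookups into a GPU-resident
# schedule force a device-to-host copy and a stream sync, while the same lookup
# into a CPU-resident schedule is essentially free. Names and values are made up.
import torch

sigmas_gpu = torch.linspace(14.6, 0.0, 50)
if torch.cuda.is_available():
    sigmas_gpu = sigmas_gpu.to("cuda")
sigmas_cpu = sigmas_gpu.to("cpu")  # the pattern this commit applies to self.sigmas

step_index = 7  # schedulers track this as a plain Python int on the host

# GPU lookup: turning the scalar into a Python float (for control flow, timestep
# math, logging, ...) blocks until the copy back to the host completes.
sigma = sigmas_gpu[step_index].item()

# CPU lookup: the schedule is at most a few thousand floats, so keeping it on
# the host makes every per-step read cheap and reserves the GPU for the samples.
sigma = sigmas_cpu[step_index].item()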

src/diffusers/schedulers/scheduling_consistency_models.py

Lines changed: 2 additions & 0 deletions
@@ -98,6 +98,7 @@ def __init__(
         self.custom_timesteps = False
         self.is_scale_input_called = False
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
@@ -230,6 +231,7 @@ def set_timesteps(
         self.timesteps = torch.from_numpy(timesteps).to(device=device)
 
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Modified _convert_to_karras implementation that takes in ramp as argument
     def _convert_to_karras(self, ramp):
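
One note on the added line, which applies to every hunk in this commit: `torch.Tensor.to()` is out-of-place and returns a new tensor, so the result must be assigned back to `self.sigmas` for the move to take effect. A standalone check in plain PyTorch (no diffusers objects; the GPU branch only runs when CUDA is available):

# Plain-PyTorch check: Tensor.to() does not move a tensor in place, which is why
# the added line assigns the result back instead of calling .to("cpu") bare.
import torch

sigmas = torch.linspace(1.0, 0.0, 10)
if torch.cuda.is_available():
    sigmas = sigmas.to("cuda")
    sigmas.to("cpu")               # result discarded: `sigmas` is still on the GPU
    print(sigmas.device)           # cuda:0
    sigmas = sigmas.to("cpu")      # the assignment is what actually moves it
    print(sigmas.device)           # cpu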

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 2 additions & 0 deletions
@@ -187,6 +187,7 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
     def step_index(self):
@@ -254,6 +255,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 2 additions & 0 deletions
@@ -214,6 +214,7 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
     def step_index(self):
@@ -290,6 +291,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
 
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 2 additions & 0 deletions
@@ -209,6 +209,7 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.use_karras_sigmas = use_karras_sigmas
 
     @property
@@ -289,6 +290,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
 
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 2 additions & 0 deletions
@@ -198,6 +198,7 @@ def __init__(
         self.noise_sampler = None
         self.noise_sampler_seed = noise_sampler_seed
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
     def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -347,6 +348,7 @@ def set_timesteps(
         self.mid_point_sigma = None
 
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.noise_sampler = None
 
     # for exp beta schedules, such as the one for `pipeline_shap_e.py`

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 2 additions & 0 deletions
@@ -197,6 +197,7 @@ def __init__(
         self.sample = None
         self.order_list = self.get_order_list(num_train_timesteps)
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     def get_order_list(self, num_inference_steps: int) -> List[int]:
         """
@@ -288,6 +289,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 2 additions & 0 deletions
@@ -166,6 +166,7 @@ def __init__(
         self.is_scale_input_called = False
 
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
     def init_noise_sigma(self):
@@ -249,6 +250,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.timesteps = torch.from_numpy(timesteps).to(device=device)
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
     def _init_step_index(self, timestep):

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 2 additions & 0 deletions
@@ -237,6 +237,7 @@ def __init__(
         self.use_karras_sigmas = use_karras_sigmas
 
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
     def init_noise_sigma(self):
@@ -341,6 +342,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
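
A hedged usage sketch of the resulting behaviour, shown with `EulerDiscreteScheduler` because its `set_timesteps` hunk appears above: after `set_timesteps`, the timesteps live on the requested device while the sigmas stay on the host. The scheduler construction and step count below are illustrative defaults, and the snippet assumes a diffusers install that includes this commit.

# Usage sketch: timesteps follow the requested device, sigmas stay on the CPU.
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler.set_timesteps(num_inference_steps=30, device=device)

print(scheduler.timesteps.device)  # cuda:0 on a GPU machine, otherwise cpu
print(scheduler.sigmas.device)     # cpu, so per-step sigma lookups stay host-side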

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 2 additions & 0 deletions
@@ -148,6 +148,7 @@ def __init__(
         self.use_karras_sigmas = use_karras_sigmas
 
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
@@ -269,6 +270,7 @@ def set_timesteps(
         self.dt = None
 
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
     # for exp beta schedules, such as the one for `pipeline_shap_e.py`

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

Lines changed: 2 additions & 0 deletions
@@ -140,6 +140,7 @@ def __init__(
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
     def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -295,6 +296,7 @@ def set_timesteps(
         self._index_counter = defaultdict(int)
 
         self._step_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
