Skip to content

Commit 300acc1

Browse files
committed
minor fix to remove runtime error
1 parent c0079cb commit 300acc1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

karlo/modules/diffusion/respace.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ def condition_score(self, cond_fn, *args, **kwargs):
104104

105105
def _wrap_model(self, model):
106106
def wrapped(x, ts, **kwargs):
107+
ts_cpu = ts.detach().to("cpu")
107108
return model(
108-
x, self.timestep_map[ts].to(device=ts.device, dtype=ts.dtype), **kwargs
109+
x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
109110
)
110111

111112
return wrapped

0 commit comments

Comments
 (0)