Skip to content

Commit 3c68a50

Browse files
author
saehoon sam kim
authored
Merge pull request #14 from jmkim0309/hotfix/ts_cpu
Hotfix/ts cpu
2 parents d5e5fc7 + 64291f8 commit 3c68a50

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

demo/product_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
def default_parser():
9090
parser = argparse.ArgumentParser()
9191
parser.add_argument("--root-dir", type=str, default=None)
92-
parser.add_argument("--max_bsz", type=int, default=4)
92+
parser.add_argument("--max_bsz", type=int, default=1)
9393
parser.add_argument(
9494
"--progressive", type=str, default="loop", choices=("loop", "stage", "final")
9595
)

karlo/modules/diffusion/respace.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ def condition_score(self, cond_fn, *args, **kwargs):
104104

105105
def _wrap_model(self, model):
106106
def wrapped(x, ts, **kwargs):
107+
ts_cpu = ts.detach().to("cpu")
107108
return model(
108-
x, self.timestep_map[ts].to(device=ts.device, dtype=ts.dtype), **kwargs
109+
x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
109110
)
110111

111112
return wrapped

0 commit comments

Comments
 (0)