Skip to content

Commit 7c21e6e

Browse files
committed
Update test.py to fix issue #11
1 parent 6238eed commit 7c21e6e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pipeline.stablevsr_pipeline import StableVSRPipeline
2-
from diffusers import DDPMScheduler
2+
from diffusers import DDPMScheduler, ControlNetModel
33
from accelerate.utils import set_seed
44
from PIL import Image
55
import os
@@ -21,6 +21,7 @@ def center_crop(im, size=128):
2121
parser.add_argument("--out_path", default='./StableVSR_results/', type=str, help="Path to output folder.")
2222
parser.add_argument("--in_path", type=str, required=True, help="Path to input folder (containing sets of LR images).")
2323
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of sampling steps")
24+
parser.add_argument("--controlnet_ckpt", type=str, default=None, help="Path to your folder with the controlnet checkpoint.")
2425
args = parser.parse_args()
2526

2627
print("Run with arguments:")
@@ -31,7 +32,8 @@ def center_crop(im, size=128):
3132
set_seed(42)
3233
device = torch.device('cuda')
3334
model_id = 'claudiom4sir/StableVSR'
34-
pipeline = StableVSRPipeline.from_pretrained(model_id)
35+
controlnet_model = ControlNetModel.from_pretrained(args.controlnet_ckpt if args.controlnet_ckpt is not None else model_id, subfolder='controlnet') # your own controlnet model
36+
pipeline = StableVSRPipeline.from_pretrained(model_id, controlnet=controlnet_model)
3537
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder='scheduler')
3638
pipeline.scheduler = scheduler
3739
pipeline = pipeline.to(device)

0 commit comments

Comments
 (0)