-
Notifications
You must be signed in to change notification settings - Fork 5.7k
/
Copy pathtest_pipelines_common.py
2192 lines (1784 loc) · 95.9 KB
/
test_pipelines_common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import gc
import inspect
import json
import os
import tempfile
import unittest
import uuid
from typing import Any, Callable, Dict, Union
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
ConsistencyDecoderVAE,
DDIMScheduler,
DiffusionPipeline,
KolorsPipeline,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
CaptureLogger,
require_accelerate_version_greater,
require_accelerator,
require_torch,
skip_mps,
torch_device,
)
from ..models.autoencoders.vae import (
get_asym_autoencoder_kl_config,
get_autoencoder_kl_config,
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_faceid_state_dict,
create_ip_adapter_state_dict,
)
from ..others.test_utils import TOKEN, USER, is_staging_test
def to_np(tensor):
    """Convert a ``torch.Tensor`` to a NumPy array; return any other value unchanged.

    Tensors are detached from the autograd graph and copied to CPU before conversion.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return tensor.detach().cpu().numpy()
def check_same_shape(tensor_list):
    """Return True when every element of ``tensor_list`` has the same ``.shape``.

    Empty and single-element inputs trivially pass.
    """
    shapes = [t.shape for t in tensor_list]
    reference = shapes[0] if shapes else None
    for other in shapes[1:]:
        if not other == reference:
            return False
    return True
def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors):
    """Return True when ``model.attn_processors`` has as many entries as ``original_attn_processors``."""
    n_current = len(model.attn_processors)
    return n_current == len(original_attn_processors)
def check_qkv_fusion_processors_exist(model):
    """Return True when every processor in ``model.attn_processors`` has a class name starting with "Fused".

    A model with no attention processors trivially passes.
    """
    for processor in model.attn_processors.values():
        if not processor.__class__.__name__.startswith("Fused"):
            return False
    return True
class SDFunctionTesterMixin:
    """
    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
    It provides a set of common tests for PyTorch pipeline that inherit from StableDiffusionMixin, e.g. vae_slicing, vae_tiling, freeu, etc.

    The host test class must provide ``pipeline_class``, ``get_dummy_components()`` and
    ``get_dummy_inputs(device)``.
    """

    def test_vae_slicing(self, image_count=4):
        """Check that enabling VAE slicing leaves a batched generation unchanged (max abs diff < 1e-2).

        Args:
            image_count: number of copies of the dummy prompt (and dummy image, if any) in the batch.
        """
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = [inputs["prompt"]] * image_count
        if "image" in inputs:  # fix batch size mismatch in I2V_Gen pipeline
            inputs["image"] = [inputs["image"]] * image_count
        output_1 = pipe(**inputs)

        # make sure sliced vae decode yields the same result
        pipe.enable_vae_slicing()
        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = [inputs["prompt"]] * image_count
        if "image" in inputs:
            inputs["image"] = [inputs["image"]] * image_count
        inputs["return_dict"] = False
        output_2 = pipe(**inputs)

        assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2

    def test_vae_tiling(self):
        """Check that VAE tiling matches the non-tiled decode, and that tiled decode accepts odd latent shapes."""
        components = self.get_dummy_components()

        # the safety checker is not under test here; disable it when the pipeline has one
        if "safety_checker" in components:
            components["safety_checker"] = None
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        inputs["return_dict"] = False

        # reference output using the regular (non-tiled) VAE decode
        output_1 = pipe(**inputs)[0]

        # make sure tiled vae decode yields the same result
        pipe.enable_vae_tiling()
        inputs = self.get_dummy_inputs(torch_device)
        inputs["return_dict"] = False
        output_2 = pipe(**inputs)[0]

        assert np.abs(to_np(output_2) - to_np(output_1)).max() < 5e-1

        # test that tiled decode works with various (non-square, odd-sized) latent shapes;
        # use the VAE's own latent channel count instead of assuming 4
        latent_channels = getattr(pipe.vae.config, "latent_channels", 4)
        shapes = [(1, latent_channels, 73, 97), (1, latent_channels, 65, 49)]
        with torch.no_grad():
            for shape in shapes:
                zeros = torch.zeros(shape).to(torch_device)
                pipe.vae.decode(zeros)

    # MPS currently doesn't support ComplexFloats, which are required for FreeU - see https://github.com/huggingface/diffusers/issues/7569.
    @skip_mps
    def test_freeu(self):
        """Check FreeU changes the output when enabled, and that disabling it resets the up-block
        ``s1``/``s2``/``b1``/``b2`` attributes to None and restores the default output (atol 1e-2)."""
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # Normal inference
        inputs = self.get_dummy_inputs(torch_device)
        inputs["return_dict"] = False
        inputs["output_type"] = "np"
        output = pipe(**inputs)[0]

        # FreeU-enabled inference
        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
        inputs = self.get_dummy_inputs(torch_device)
        inputs["return_dict"] = False
        inputs["output_type"] = "np"
        output_freeu = pipe(**inputs)[0]

        # FreeU-disabled inference
        pipe.disable_freeu()
        freeu_keys = {"s1", "s2", "b1", "b2"}
        for upsample_block in pipe.unet.up_blocks:
            for key in freeu_keys:
                assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."

        inputs = self.get_dummy_inputs(torch_device)
        inputs["return_dict"] = False
        inputs["output_type"] = "np"
        output_no_freeu = pipe(**inputs)[0]

        assert not np.allclose(
            output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
        ), "Enabling of FreeU should lead to different results."
        assert np.allclose(
            output, output_no_freeu, atol=1e-2
        ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."

    def test_fused_qkv_projections(self):
        """Check that fusing, then unfusing, QKV projections leaves the output slice
        ``[0, -3:, -3:, -1]`` unchanged (atol/rtol 1e-2), and that fusion swaps in "Fused*" processors."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["return_dict"] = False
        image = pipe(**inputs)[0]
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.fuse_qkv_projections()
        # only modules that recorded their pre-fusion processors are checked
        for component in pipe.components.values():
            if (
                isinstance(component, nn.Module)
                and hasattr(component, "original_attn_processors")
                and component.original_attn_processors is not None
            ):
                assert check_qkv_fusion_processors_exist(
                    component
                ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
                assert check_qkv_fusion_matches_attn_procs_length(
                    component, component.original_attn_processors
                ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        inputs["return_dict"] = False
        image_fused = pipe(**inputs)[0]
        image_slice_fused = image_fused[0, -3:, -3:, -1]

        pipe.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        inputs["return_dict"] = False
        image_disabled = pipe(**inputs)[0]
        image_slice_disabled = image_disabled[0, -3:, -3:, -1]

        assert np.allclose(
            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
        ), "Fusion of QKV projections shouldn't affect the outputs."
        assert np.allclose(
            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        assert np.allclose(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."
class IPAdapterTesterMixin:
    """
    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
    It provides a set of common tests for pipelines that support IP Adapters.
    """

    def test_pipeline_signature(self):
        """Check the pipeline subclasses IPAdapterMixin and its ``__call__`` accepts the IP-Adapter arguments."""
        parameters = inspect.signature(self.pipeline_class.__call__).parameters

        assert issubclass(self.pipeline_class, IPAdapterMixin)
        self.assertIn(
            "ip_adapter_image",
            parameters,
            "`ip_adapter_image` argument must be supported by the `__call__` method",
        )
        self.assertIn(
            "ip_adapter_image_embeds",
            parameters,
            "`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
        )

    def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
        # random embeds of shape (2, 1, cross_attention_dim) on the test device
        return torch.randn((2, 1, cross_attention_dim), device=torch_device)

    def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
        # random FaceID embeds of shape (2, 1, 1, cross_attention_dim) on the test device
        return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device)

    def _get_dummy_masks(self, input_size: int = 64):
        # (1, 1, input_size, input_size) mask: ones over the first half of the last axis, zeros elsewhere
        _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
        _masks[0, :, :, : int(input_size / 2)] = 1
        return _masks

    def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
        # Mutates and returns `inputs`: pipelines whose __call__ accepts both `image` and `strength`
        # (presumably img2img-style) run 4 inference steps; all get numpy output and tuple return.
        parameters = inspect.signature(self.pipeline_class.__call__).parameters
        if "image" in parameters.keys() and "strength" in parameters.keys():
            inputs["num_inference_steps"] = 4

        inputs["output_type"] = "np"
        inputs["return_dict"] = False
        return inputs

    def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
        r"""Tests for IP-Adapter.

        The following scenarios are tested:
          - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
          - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
          - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
          - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.

        Args:
            expected_max_diff: max abs difference allowed for the scale=0 comparisons. On CPU it is
                raised to at least 9e-4, never lowered.
            expected_pipe_slice: if given, used as the no-adapter reference instead of running the
                pipeline; the adapter outputs are then reduced to their ``[0, -3:, -3:, -1]`` slice
                (flattened) before comparison.
        """
        # Raising the tolerance for this test when it's run on a CPU because we
        # compare against static slices and that can be shaky (with a VVVV low probability).
        # Only ever loosen the bound: a caller-supplied tolerance larger than 9e-4 is kept.
        if torch_device == "cpu":
            expected_max_diff = max(expected_max_diff, 9e-4)

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)

        # forward pass without ip adapter
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        if expected_pipe_slice is None:
            output_without_adapter = pipe(**inputs)[0]
        else:
            output_without_adapter = expected_pipe_slice

        # 1. Single IP-Adapter test cases
        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights(adapter_state_dict)

        # forward pass with single ip adapter, but scale=0 which should have no effect
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
        pipe.set_ip_adapter_scale(0.0)
        output_without_adapter_scale = pipe(**inputs)[0]
        if expected_pipe_slice is not None:
            output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()

        # forward pass with single ip adapter, but with scale of adapter weights
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
        pipe.set_ip_adapter_scale(42.0)
        output_with_adapter_scale = pipe(**inputs)[0]
        if expected_pipe_slice is not None:
            output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()

        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

        self.assertLess(
            max_diff_without_adapter_scale,
            expected_max_diff,
            "Output without ip-adapter must be same as normal inference",
        )
        self.assertGreater(
            max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
        )

        # 2. Multi IP-Adapter test cases
        adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
        adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])

        # forward pass with multi ip adapter, but scale=0 which should have no effect
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
        pipe.set_ip_adapter_scale([0.0, 0.0])
        output_without_multi_adapter_scale = pipe(**inputs)[0]
        if expected_pipe_slice is not None:
            output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()

        # forward pass with multi ip adapter, but with scale of adapter weights
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
        pipe.set_ip_adapter_scale([42.0, 42.0])
        output_with_multi_adapter_scale = pipe(**inputs)[0]
        if expected_pipe_slice is not None:
            output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()

        max_diff_without_multi_adapter_scale = np.abs(
            output_without_multi_adapter_scale - output_without_adapter
        ).max()
        max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
        self.assertLess(
            max_diff_without_multi_adapter_scale,
            expected_max_diff,
            "Output without multi-ip-adapter must be same as normal inference",
        )
        self.assertGreater(
            max_diff_with_multi_adapter_scale,
            1e-2,
            "Output with multi-ip-adapter scale must be different from normal inference",
        )

    def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4):
        """Check that IP-Adapter runs with and without classifier-free guidance and yields matching output shapes.

        Skipped silently (returns early) if ``__call__`` has no ``guidance_scale`` parameter.
        ``expected_max_diff`` is accepted but not used by this test.
        """
        parameters = inspect.signature(self.pipeline_class.__call__).parameters

        if "guidance_scale" not in parameters:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)

        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
        pipe.set_ip_adapter_scale(1.0)

        # forward pass with CFG not applied: only the first entry of the dummy embeds is passed
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)[0].unsqueeze(0)]
        inputs["guidance_scale"] = 1.0
        out_no_cfg = pipe(**inputs)[0]

        # forward pass with CFG applied
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
        inputs["guidance_scale"] = 7.5
        out_cfg = pipe(**inputs)[0]

        assert out_cfg.shape == out_no_cfg.shape

    def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
        """Check IP-Adapter with masks: scale=0 matches no adapter (< ``expected_max_diff``), scale=42 differs (> 1e-3)."""
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
        sample_size = pipe.unet.config.get("sample_size", 32)
        block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
        # mask side length = unet sample_size scaled by 2 per VAE down level
        input_size = sample_size * (2 ** (len(block_out_channels) - 1))

        # forward pass without ip adapter
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        output_without_adapter = pipe(**inputs)[0]
        output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()

        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights(adapter_state_dict)

        # forward pass with single ip adapter and masks, but scale=0 which should have no effect
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
        pipe.set_ip_adapter_scale(0.0)
        output_without_adapter_scale = pipe(**inputs)[0]
        output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()

        # forward pass with single ip adapter and masks, but with scale of adapter weights
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
        pipe.set_ip_adapter_scale(42.0)
        output_with_adapter_scale = pipe(**inputs)[0]
        output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()

        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

        self.assertLess(
            max_diff_without_adapter_scale,
            expected_max_diff,
            "Output without ip-adapter must be same as normal inference",
        )
        self.assertGreater(
            max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
        )

    def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
        """Check FaceID IP-Adapter: scale=0 matches no adapter (< ``expected_max_diff``), scale=42 differs (> 1e-3)."""
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)

        # forward pass without ip adapter
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        output_without_adapter = pipe(**inputs)[0]
        output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()

        adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights(adapter_state_dict)

        # forward pass with single ip adapter, but scale=0 which should have no effect
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
        pipe.set_ip_adapter_scale(0.0)
        output_without_adapter_scale = pipe(**inputs)[0]
        output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()

        # forward pass with single ip adapter, but with scale of adapter weights
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
        pipe.set_ip_adapter_scale(42.0)
        output_with_adapter_scale = pipe(**inputs)[0]
        output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()

        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

        self.assertLess(
            max_diff_without_adapter_scale,
            expected_max_diff,
            "Output without ip-adapter must be same as normal inference",
        )
        self.assertGreater(
            max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
        )
class PipelineLatentTesterMixin:
    """
    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
    It provides a set of common tests for PyTorch pipeline that has vae, e.g.
    equivalence of different input and output types, etc.

    Child classes must define ``image_params`` and ``image_latents_params``.
    """

    @property
    def image_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `image_params` in the child test class. "
            "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results"
        )

    @property
    def image_latents_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `image_latents_params` in the child test class. "
            "`image_latents_params` are tested for if passing latents directly are producing same results"
        )

    def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
        """Return dummy inputs with every ``image_params`` entry converted to ``input_image_type``.

        Each image input is first normalised to a torch tensor on ``device``, then converted to
        "pt", "np" or "pil". ``output_type`` is written into the returned inputs.

        Raises:
            ValueError: from the conversion helpers, for an unsupported source image type or an
                unknown ``input_image_type``.
        """
        inputs = self.get_dummy_inputs(device, seed)

        def convert_to_pt(image):
            # tensor passes through; numpy and PIL inputs go through VaeImageProcessor converters
            if isinstance(image, torch.Tensor):
                input_image = image
            elif isinstance(image, np.ndarray):
                input_image = VaeImageProcessor.numpy_to_pt(image)
            elif isinstance(image, PIL.Image.Image):
                input_image = VaeImageProcessor.pil_to_numpy(image)
                input_image = VaeImageProcessor.numpy_to_pt(input_image)
            else:
                raise ValueError(f"unsupported input_image_type {type(image)}")
            return input_image

        def convert_pt_to_type(image, input_image_type):
            if input_image_type == "pt":
                input_image = image
            elif input_image_type == "np":
                input_image = VaeImageProcessor.pt_to_numpy(image)
            elif input_image_type == "pil":
                input_image = VaeImageProcessor.pt_to_numpy(image)
                input_image = VaeImageProcessor.numpy_to_pil(input_image)
            else:
                raise ValueError(f"unsupported input_image_type {input_image_type}.")
            return input_image

        for image_param in self.image_params:
            if image_param in inputs.keys():
                inputs[image_param] = convert_pt_to_type(
                    convert_to_pt(inputs[image_param]).to(device), input_image_type
                )

        inputs["output_type"] = output_type

        return inputs

    def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4):
        self._test_pt_np_pil_outputs_equivalent(expected_max_diff=expected_max_diff)

    def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"):
        """Check "pt", "np" and "pil" output types agree (pt vs np < ``expected_max_diff``; pil vs np*255 < 2)."""
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        output_pt = pipe(
            **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pt")
        )[0]
        output_np = pipe(
            **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="np")
        )[0]
        output_pil = pipe(
            **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pil")
        )[0]

        # the pt output is transposed (0, 2, 3, 1) to line up with the np output before comparing
        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
        self.assertLess(
            max_diff, expected_max_diff, "`output_type=='pt'` generate different results from `output_type=='np'`"
        )

        max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max()
        self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`")

    def test_pt_np_pil_inputs_equivalent(self):
        """Check "pt", "np" and "pil" image inputs give matching outputs (pt vs np < 1e-4, pil vs np < 1e-2)."""
        if len(self.image_params) == 0:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
        out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
        out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0]

        max_diff = np.abs(out_input_pt - out_input_np).max()
        self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")
        max_diff = np.abs(out_input_pil - out_input_np).max()
        self.assertLess(max_diff, 1e-2, "`input_type=='pil'` generate different result from `input_type=='np'`")

    def test_latents_input(self):
        """Check that passing VAE-encoded latents for ``image_latents_params`` matches passing the images (< 1e-4)."""
        if len(self.image_latents_params) == 0:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        # disable resize/normalize so the image and latent paths are comparable
        pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]

        vae = components["vae"]
        inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
        generator = inputs["generator"]
        for image_param in self.image_latents_params:
            if image_param in inputs.keys():
                inputs[image_param] = (
                    vae.encode(inputs[image_param]).latent_dist.sample(generator) * vae.config.scaling_factor
                )
        out_latents_inputs = pipe(**inputs)[0]

        max_diff = np.abs(out - out_latents_inputs).max()
        self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")

    def test_multi_vae(self):
        """Check the pipeline runs with each alternative VAE class and produces the default output shape."""
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        block_out_channels = pipe.vae.config.block_out_channels
        norm_num_groups = pipe.vae.config.norm_num_groups

        vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
        configs = [
            get_autoencoder_kl_config(block_out_channels, norm_num_groups),
            get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
            get_consistency_vae_config(block_out_channels, norm_num_groups),
            get_autoencoder_tiny_config(block_out_channels),
        ]

        out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]

        for vae_cls, config in zip(vae_classes, configs):
            vae = vae_cls(**config)
            vae = vae.to(torch_device)
            components["vae"] = vae
            vae_pipe = self.pipeline_class(**components)
            out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]

            assert out_vae_np.shape == out_np.shape
@require_torch
class PipelineFromPipeTesterMixin:
@property
def original_pipeline_class(self):
    """Pick the base pipeline class that ``pipeline_class`` is built on.

    Decided by substring match on the lower-cased ``pipeline_class.__name__``: "xl" selects
    StableDiffusionXLPipeline (checked first), then "kolors" selects KolorsPipeline; anything
    else falls back to StableDiffusionPipeline.
    """
    pipeline_name = self.pipeline_class.__name__.lower()
    if "xl" in pipeline_name:
        return StableDiffusionXLPipeline
    if "kolors" in pipeline_name:
        return KolorsPipeline
    return StableDiffusionPipeline
def get_dummy_inputs_pipe(self, device, seed=0):
    """Return ``get_dummy_inputs(device, seed=seed)`` with numpy output and tuple return forced."""
    inputs = self.get_dummy_inputs(device, seed=seed)
    inputs.update(output_type="np", return_dict=False)
    return inputs
def get_dummy_inputs_for_pipe_original(self, device, seed=0):
    """Return the subset of ``get_dummy_inputs_pipe(device, seed=seed)`` that
    ``original_pipeline_class.__call__`` accepts as parameters (insertion order preserved).
    """
    # Inspect the signature once instead of rebuilding the parameter set for every input key.
    accepted_params = set(inspect.signature(self.original_pipeline_class.__call__).parameters.keys())
    return {k: v for k, v in self.get_dummy_inputs_pipe(device, seed=seed).items() if k in accepted_params}
def test_from_pipe_consistent_config(self):
    """Round-trip original -> ``pipeline_class`` -> original via ``from_pipe`` and check the
    public (non-underscore) config entries are unchanged.

    Raises:
        ValueError: if ``original_pipeline_class`` is not SD, SDXL or Kolors.
    """
    if self.original_pipeline_class == StableDiffusionPipeline:
        original_repo = "hf-internal-testing/tiny-stable-diffusion-pipe"
        original_kwargs = {"requires_safety_checker": False}
    elif self.original_pipeline_class == StableDiffusionXLPipeline:
        original_repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
        original_kwargs = {"requires_aesthetics_score": True, "force_zeros_for_empty_prompt": False}
    elif self.original_pipeline_class == KolorsPipeline:
        original_repo = "hf-internal-testing/tiny-kolors-pipe"
        original_kwargs = {"force_zeros_for_empty_prompt": False}
    else:
        raise ValueError(
            "original_pipeline_class must be either StableDiffusionPipeline, StableDiffusionXLPipeline "
            "or KolorsPipeline"
        )

    # create original_pipeline_class(sd/sdxl/kolors)
    pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)

    # original_pipeline_class(sd/sdxl/kolors) -> pipeline_class:
    # only supply the dummy components the original pipeline does not already have
    pipe_components = self.get_dummy_components()
    pipe_additional_components = {
        name: component for name, component in pipe_components.items() if name not in pipe_original.components
    }
    pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)

    # pipeline_class -> original_pipeline_class(sd/sdxl/kolors):
    # re-supply original components that `pipe` lacks or holds under the same name with a different type
    original_pipe_additional_components = {
        name: component
        for name, component in pipe_original.components.items()
        if name not in pipe.components or not isinstance(component, pipe.components[name].__class__)
    }
    pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)

    # compare the config, ignoring private ("_"-prefixed) entries
    original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
    original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
    assert original_config_2 == original_config
def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
    """Check that `from_pipe` matches `__init__` and leaves the original pipeline untouched.

    Builds `original_pipeline_class` from the dummy components it expects and builds `pipeline_class`
    directly from all dummy components. It then builds `pipeline_class` again via `from_pipe` on the
    original pipeline, and asserts that:
      * the `__init__` and `from_pipe` pipelines produce outputs within `expected_max_diff`;
      * the original pipeline's output is unchanged after `from_pipe`;
      * every attention processor in the original pipeline is still exactly `AttnProcessor`.

    Args:
        expected_max_diff: maximum allowed absolute element-wise difference between outputs.

    Raises:
        ValueError: if a non-optional module expected by `original_pipeline_class` is missing
            from the dummy components.
    """
    components = self.get_dummy_components()
    # first value: module names the original pipeline's constructor expects (second value unused here)
    original_expected_modules, _ = self.original_pipeline_class._get_signature_keys(self.original_pipeline_class)

    # pipeline components that are also expected to be in the original pipeline
    original_pipe_components = {}
    # additional components that are not in the pipeline, but expected in the original pipeline
    original_pipe_additional_components = {}
    # additional components that are in the pipeline, but not expected in the original pipeline
    current_pipe_additional_components = {}

    for name, component in components.items():
        if name in original_expected_modules:
            original_pipe_components[name] = component
        else:
            current_pipe_additional_components[name] = component

    for name in original_expected_modules:
        if name in original_pipe_components:
            continue
        if name not in self.original_pipeline_class._optional_components:
            # `.__name__`, not `.__class__`: the latter is the metaclass of the pipeline class
            raise ValueError(f"missing required module for {self.original_pipeline_class.__name__}: {name}")
        # optional modules the dummy components do not provide are passed explicitly as None
        original_pipe_additional_components[name] = None

    pipe_original = self.original_pipeline_class(**original_pipe_components, **original_pipe_additional_components)
    for component in pipe_original.components.values():
        if hasattr(component, "set_default_attn_processor"):
            component.set_default_attn_processor()
    pipe_original.to(torch_device)
    pipe_original.set_progress_bar_config(disable=None)
    inputs = self.get_dummy_inputs_for_pipe_original(torch_device)
    output_original = pipe_original(**inputs)[0]

    pipe = self.pipeline_class(**components)
    for component in pipe.components.values():
        if hasattr(component, "set_default_attn_processor"):
            component.set_default_attn_processor()
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    inputs = self.get_dummy_inputs_pipe(torch_device)
    output = pipe(**inputs)[0]

    pipe_from_original = self.pipeline_class.from_pipe(pipe_original, **current_pipe_additional_components)
    pipe_from_original.to(torch_device)
    pipe_from_original.set_progress_bar_config(disable=None)
    inputs = self.get_dummy_inputs_pipe(torch_device)
    output_from_original = pipe_from_original(**inputs)[0]

    max_diff = np.abs(to_np(output) - to_np(output_from_original)).max()
    self.assertLess(
        max_diff,
        expected_max_diff,
        "The outputs of the pipelines created with `from_pipe` and `__init__` are different.",
    )

    # re-run the original pipeline to make sure `from_pipe` did not alter it
    inputs = self.get_dummy_inputs_for_pipe_original(torch_device)
    output_original_2 = pipe_original(**inputs)[0]
    max_diff = np.abs(to_np(output_original) - to_np(output_original_2)).max()
    self.assertLess(max_diff, expected_max_diff, "`from_pipe` should not change the output of original pipeline.")

    for component in pipe_original.components.values():
        if hasattr(component, "attn_processors"):
            # exact type check on purpose: a different processor class means `from_pipe` swapped it
            assert all(
                type(proc) is AttnProcessor for proc in component.attn_processors.values()
            ), "`from_pipe` changed the attention processor in original pipeline."
@require_accelerator
@require_accelerate_version_greater("0.14.0")
def test_from_pipe_consistent_forward_pass_cpu_offload(self, expected_max_diff=1e-3):
    """Check that `from_pipe` matches `__init__` when model CPU offload is enabled.

    Builds `pipeline_class` directly with `enable_model_cpu_offload` and runs it. It then assembles an
    `original_pipeline_class` pipeline from the same dummy components and derives a `pipeline_class`
    pipeline from it with `from_pipe` (CPU offload enabled on the derived pipeline). Finally it
    asserts the two outputs differ by less than `expected_max_diff`.

    Args:
        expected_max_diff: maximum allowed absolute element-wise difference between outputs.

    Raises:
        ValueError: if a non-optional module expected by `original_pipeline_class` is missing
            from the dummy components.
    """
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    for component in pipe.components.values():
        if hasattr(component, "set_default_attn_processor"):
            component.set_default_attn_processor()
    pipe.enable_model_cpu_offload(device=torch_device)
    pipe.set_progress_bar_config(disable=None)
    inputs = self.get_dummy_inputs_pipe(torch_device)
    output = pipe(**inputs)[0]

    # first value: module names the original pipeline's constructor expects (second value unused here)
    original_expected_modules, _ = self.original_pipeline_class._get_signature_keys(self.original_pipeline_class)

    # pipeline components that are also expected to be in the original pipeline
    original_pipe_components = {}
    # additional components that are not in the pipeline, but expected in the original pipeline
    original_pipe_additional_components = {}
    # additional components that are in the pipeline, but not expected in the original pipeline
    current_pipe_additional_components = {}

    for name, component in components.items():
        if name in original_expected_modules:
            original_pipe_components[name] = component
        else:
            current_pipe_additional_components[name] = component

    for name in original_expected_modules:
        if name in original_pipe_components:
            continue
        if name not in self.original_pipeline_class._optional_components:
            # `.__name__`, not `.__class__`: the latter is the metaclass of the pipeline class
            raise ValueError(f"missing required module for {self.original_pipeline_class.__name__}: {name}")
        # optional modules the dummy components do not provide are passed explicitly as None
        original_pipe_additional_components[name] = None

    pipe_original = self.original_pipeline_class(**original_pipe_components, **original_pipe_additional_components)
    for component in pipe_original.components.values():
        if hasattr(component, "set_default_attn_processor"):
            component.set_default_attn_processor()
    pipe_original.set_progress_bar_config(disable=None)

    pipe_from_original = self.pipeline_class.from_pipe(pipe_original, **current_pipe_additional_components)
    for component in pipe_from_original.components.values():
        if hasattr(component, "set_default_attn_processor"):
            component.set_default_attn_processor()
    pipe_from_original.enable_model_cpu_offload(device=torch_device)
    pipe_from_original.set_progress_bar_config(disable=None)
    inputs = self.get_dummy_inputs_pipe(torch_device)
    output_from_original = pipe_from_original(**inputs)[0]

    max_diff = np.abs(to_np(output) - to_np(output_from_original)).max()
    self.assertLess(
        max_diff,
        expected_max_diff,
        "The outputs of the pipelines created with `from_pipe` and `__init__` are different.",
    )
@require_torch
class PipelineKarrasSchedulerTesterMixin:
    """
    This mixin is designed to be used with unittest.TestCase classes.
    It provides a set of common tests for each PyTorch pipeline that makes use of KarrasDiffusionSchedulers.
    """

    def test_karras_schedulers_shape(
        self, num_inference_steps_for_strength=4, num_inference_steps_for_strength_for_iterations=5
    ):
        """Run the pipeline once with each `KarrasDiffusionSchedulers` member and check all outputs share a shape.

        Args:
            num_inference_steps_for_strength: step count used instead of 2 when the dummy inputs
                contain a `strength` argument (strength is then fixed to 0.5).
            num_inference_steps_for_strength_for_iterations: step count used for schedulers whose
                name contains "KDPM2".
        """
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)

        # make sure that PNDM does not need warm-up
        pipe.scheduler.register_to_config(skip_prk_steps=True)

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = 2

        if "strength" in inputs:
            inputs["num_inference_steps"] = num_inference_steps_for_strength
            inputs["strength"] = 0.5

        # Step count for every non-KDPM2 scheduler. Restoring this value (rather than a hard-coded 2)
        # keeps the `strength`-specific step count for schedulers iterated after the KDPM2 ones.
        default_num_inference_steps = inputs["num_inference_steps"]

        outputs = []
        for scheduler_enum in KarrasDiffusionSchedulers:
            if "KDPM2" in scheduler_enum.name:
                inputs["num_inference_steps"] = num_inference_steps_for_strength_for_iterations
            else:
                inputs["num_inference_steps"] = default_num_inference_steps

            scheduler_cls = getattr(diffusers, scheduler_enum.name)
            pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
            output = pipe(**inputs)[0]
            outputs.append(output)

        assert check_same_shape(outputs)
@require_torch
class PipelineTesterMixin:
"""
This mixin is designed to be used with unittest.TestCase classes.
It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline,
equivalence of dict and tuple outputs, etc.
"""
# `__call__` arguments shared by every pipeline, whatever its type. Each of them
# is optional and comes with a common-sense default value.
required_optional_params = frozenset(
    {
        "num_inference_steps",
        "num_images_per_prompt",
        "generator",
        "latents",
        "output_type",
        "return_dict",
    }
)

# Feature switches: a child test class overrides these with False when its
# pipeline does not support attention slicing / xformers attention.
test_attention_slicing = True
test_xformers_attention = True
def get_generator(self, seed):
    """Return a `torch.Generator` seeded with `seed`.

    The generator is created on `torch_device`, except that "mps" is replaced by "cpu".
    """
    generator_device = "cpu" if torch_device == "mps" else torch_device
    return torch.Generator(generator_device).manual_seed(seed)
@property
def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
    """The pipeline class under test.

    Child test classes are expected to shadow this, e.g. with a plain class attribute
    `pipeline_class = ClassNameOfPipeline`; the base implementation always raises.
    """
    hint = (
        "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(hint)
def get_dummy_components(self):
    """Return the dummy components passed as keyword arguments to `pipeline_class`.

    Must be implemented by the child test class; the base implementation always raises.
    """
    hint = (
        "You need to implement `get_dummy_components(self)` in the child test class. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(hint)
def get_dummy_inputs(self, device, seed=0):
    """Return the keyword arguments for one call of the pipeline under test.

    Must be implemented by the child test class; the base implementation always raises.
    `seed` presumably seeds any generator in the returned inputs (confirm in subclasses).
    """
    hint = (
        "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(hint)
@property
def params(self) -> frozenset:
    """Names of the `__call__` parameters checked for presence in the pipeline's signature.

    Must be set by the child test class; the base implementation always raises with guidance.
    """
    # Adjacent literals are concatenated, so each fragment ends with the space the next one needs.
    raise NotImplementedError(
        "You need to set the attribute `params` in the child test class. "
        "`params` are checked to ensure all values are present in `__call__`'s signature. "
        "You can set `params` using one of the common sets of parameters defined in `pipeline_params.py`, "
        "e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
        "image pipelines, including prompts and prompt embedding overrides. "
        "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
        "do not make modifications to the existing common sets of arguments. E.g., a text to image pipeline "
        "with non-configurable height and width arguments should set the attribute as "
        "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
        "See existing pipeline tests for reference."
    )
@property
def batch_params(self) -> frozenset:
    """Names of the `__call__` parameters that must be batched when passed to the pipeline.

    Must be set by the child test class; the base implementation always raises with guidance.
    """
    # Adjacent literals are concatenated, so each fragment ends with the space the next one needs.
    raise NotImplementedError(
        "You need to set the attribute `batch_params` in the child test class. "
        "`batch_params` are the parameters required to be batched when passed to the pipeline's "
        "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
        "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc. If your pipeline's "
        "set of batch arguments has minor changes from one of the common sets of batch arguments, "
        "do not make modifications to the existing common sets of batch arguments. E.g., a text to "
        "image pipeline whose `negative_prompt` is not batched should set the attribute as "
        "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
        "See existing pipeline tests for reference."
    )
@property
def callback_cfg_params(self) -> frozenset:
    """Names of the tensors passed to the step-end callback that need CFG-specific handling.

    Only child test classes that run `test_callback_cfg` need to set this; the base
    implementation always raises with guidance.
    """
    # Adjacent literals are concatenated, so each fragment ends with the space the next one needs.
    # The example constant matches the one named earlier in the message (TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS).
    raise NotImplementedError(
        "You need to set the attribute `callback_cfg_params` in the child test class that needs to run "
        "`test_callback_cfg`. `callback_cfg_params` are the parameters that need to be passed to the "
        "pipeline's callback function when dynamically adjusting `guidance_scale`. They are variables that "
        "require special treatment when `do_classifier_free_guidance` is `True`. `pipeline_params.py` provides "
        "some common sets of parameters such as `TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS`. If your pipeline's "
        "set of cfg arguments has minor changes from one of the common sets of cfg arguments, "
        "do not make modifications to the existing common sets of cfg arguments. E.g., for an inpaint "
        "pipeline, you need to adjust the batch size of `mask` and `masked_image_latents`, so it should set "
        "the attribute as "
        "`callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({'mask', 'masked_image_latents'})`."
    )
def setUp(self):
    """Run the parent setUp, then free unused memory so each test starts with clean VRAM."""
    super().setUp()
    # Collect unreachable Python objects first, then release the cached CUDA memory they held.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def tearDown(self):
    """Run the parent tearDown, then free memory in case the test hit a CUDA runtime error."""
    super().tearDown()
    # Collect unreachable Python objects first, then release the cached CUDA memory they held.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def test_save_load_local(self, expected_max_difference=5e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
with CaptureLogger(logger) as cap_logger:
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)