We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f55e37a commit 3e4af0aCopy full SHA for 3e4af0a
tests/test_models.py
@@ -50,6 +50,14 @@ def test_encoder_timm():
50
input_tensor = torch.rand(b, c, h, w)
51
backend = "timm"
52
53
+ encoder = VisualEncoder(
54
+ model_name="off", in_chans=c, d_model=features, backend=backend
55
+ )
56
+ output = encoder(input_tensor)
57
+
58
+ assert output.shape == (b, features)
59
+ assert not torch.is_nonzero.any()
60
61
encoder = VisualEncoder(
62
model_name="resnet18", in_chans=c, d_model=features, backend=backend
63
)
@@ -93,6 +101,14 @@ def test_encoder_torch():
93
101
94
102
backend = "torch"
95
103
104
105
106
107
108
109
110
111
96
112
97
113
98
114
0 commit comments