-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathtest_hubconf.py
31 lines (22 loc) · 1.46 KB
/
test_hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import sys
import torch
from hubconf import radio_model
def version_from_argv(argv=None):
    """Return the RADIO version selected on the command line.

    Args:
        argv: Argument vector to inspect. Defaults to ``sys.argv``.

    Returns:
        ``argv[1]`` when a positional argument was given, otherwise ``''``
        (the default version string passed to ``radio_model``).
    """
    if argv is None:
        argv = sys.argv
    return argv[1] if len(argv) > 1 else ''


def check_forward_intermediates(version, x):
    """Smoke-test ``forward`` and ``forward_intermediates`` on a backbone-only model.

    Prints the norm of the difference between the final output features and
    the last intermediate returned by ``forward_intermediates(indices=[-1])``.
    The remaining ``forward_intermediates`` calls only check that each
    argument combination runs without error; their results are discarded.

    The model is local to this function so its GPU memory is released before
    any later model is built.

    Args:
        version: Version string forwarded to ``radio_model``.
        x: Input image batch, already on the CUDA device.
    """
    model = radio_model(version=version).cuda()
    with torch.no_grad():
        # Plain forward pass: only checks that it runs.
        model(x)

        op, (int0,) = model.forward_intermediates(
            x, indices=[-1], output_fmt='NLC', aggregation='sparse'
        )
        # NOTE(review): presumably expected to be ~0 when the last intermediate
        # matches the final features; it is only printed, not asserted.
        diff = (op.features - int0).norm()
        print(f'Output diff: {diff.item():.8f}')

        # Keyword arguments shared by the dense / prefix-token variants below.
        dense_prefix = dict(
            return_prefix_tokens=True,
            output_fmt='NCHW',
            aggregation='dense',
            intermediates_only=True,
        )
        model.forward_intermediates(x, indices=[1, 5, 7], output_fmt='NCHW')
        model.forward_intermediates(x, indices=[2, 4, 6], output_fmt='NLC')
        model.forward_intermediates(x, indices=[3, 5, 7], **dense_prefix)
        model.forward_intermediates(
            x, indices=[3, 5, 7], **dense_prefix, norm_alpha_scheme='pre-alpha'
        )
        model.forward_intermediates(
            x, indices=[3, 5, 7], **dense_prefix, norm_alpha_scheme='none'
        )


def check_adaptor_feature_format(version, x, adaptor_names=('clip', 'dino_v2')):
    """Check that ``feature_fmt='NCHW'`` yields 4-D features for every output head.

    Args:
        version: Version string forwarded to ``radio_model``.
        x: Input image batch, already on the CUDA device.
        adaptor_names: Adaptor heads to attach; each one, plus ``'backbone'``,
            must appear as a key in the model output.

    Raises:
        AssertionError: if the backbone or any adaptor output is not 4-D.
            Raised explicitly so the check still runs under ``python -O``.
    """
    model = radio_model(version=version, adaptor_names=list(adaptor_names)).cuda()
    with torch.no_grad():
        y = model(x, feature_fmt='NCHW')
    ndims = {key: y[key].features.ndim for key in ('backbone', *adaptor_names)}
    bad = {key: ndim for key, ndim in ndims.items() if ndim != 4}
    if bad:
        raise AssertionError(f'Expected 4-D NCHW features, got ndim={bad}')


def main(argv=None):
    """Run both hub-model smoke tests on a random 224x224 CUDA image.

    Args:
        argv: Argument vector; ``argv[1]`` (if present) selects the version.
    """
    version = version_from_argv(argv)
    x = torch.rand(1, 3, 224, 224, device='cuda')
    check_forward_intermediates(version, x)
    check_adaptor_feature_format(version, x)


if __name__ == "__main__":
    main()