@@ -94,29 +94,47 @@ def test_naming_conventions(model_fn):
9494 + TM .get_models_from_module (models .video )
9595 + TM .get_models_from_module (models .optical_flow ),
9696)
97+ @run_if_test_with_prototype
9798def test_schema_meta_validation (model_fn ):
9899 classification_fields = ["size" , "categories" , "acc@1" , "acc@5" ]
99100 defaults = {
100- "all" : ["interpolation" , "recipe" ],
101+ "all" : ["task" , "architecture" , "publication_year" , " interpolation" , "recipe" , "num_params " ],
101102 "models" : classification_fields ,
102103 "detection" : ["categories" , "map" ],
103104 "quantization" : classification_fields + ["backend" , "quantization" , "unquantized" ],
104105 "segmentation" : ["categories" , "mIoU" , "acc" ],
105106 "video" : classification_fields ,
106107 "optical_flow" : [],
107108 }
109+ model_name = model_fn .__name__
108110 module_name = model_fn .__module__ .split ("." )[- 2 ]
109111 fields = set (defaults ["all" ] + defaults [module_name ])
110112
111113 weights_enum = _get_model_weights (model_fn )
114+ if len (weights_enum ) == 0 :
115+ pytest .skip (f"Model '{ model_name } ' doesn't have any pre-trained weights." )
112116
113117 problematic_weights = {}
118+ incorrect_params = []
114119 for w in weights_enum :
115120 missing_fields = fields - set (w .meta .keys ())
116121 if missing_fields :
117122 problematic_weights [w ] = missing_fields
123+ if w == weights_enum .default :
124+ if module_name == "quantization" :
125+ # parametes() cound doesn't work well with quantization, so we check against the non-quantized
126+ unquantized_w = w .meta .get ("unquantized" )
127+ if unquantized_w is not None and w .meta .get ("num_params" ) != unquantized_w .meta .get ("num_params" ):
128+ incorrect_params .append (w )
129+ else :
130+ if w .meta .get ("num_params" ) != sum (p .numel () for p in model_fn (weights = w ).parameters ()):
131+ incorrect_params .append (w )
132+ else :
133+ if w .meta .get ("num_params" ) != weights_enum .default .meta .get ("num_params" ):
134+ incorrect_params .append (w )
118135
119136 assert not problematic_weights
137+ assert not incorrect_params
120138
121139
122140@pytest .mark .parametrize ("model_fn" , TM .get_models_from_module (models ))
0 commit comments