You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Hi,
I have a classification model trained on inputs of shape batch_size x N (=46) x channel x height x width. How can I adapt TRAK for that? I get an error in the featurize function itself. Do I have to modify the in_dims?
/opt/conda/lib/python3.8/site-packages/trak/gradient_computers.py in compute_per_sample_grad(self, batch)
148
149 # map over batch dimensions (hence 0 for each batch dimension, and None for model params)
--> 150 grads = torch.func.vmap(
151 grads_loss,
152 in_dims=(None, None, None, *([0] * len(batch))),
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in wrapped(*args, **kwargs)
432
433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
435 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
436 )
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
40 return fn
41
Hi,
I have a classification model trained on inputs of shape batch_size x N (=46) x channel x height x width. How can I adapt TRAK for that? I get an error in the featurize function itself. Do I have to modify the in_dims?
/opt/conda/lib/python3.8/site-packages/trak/gradient_computers.py in compute_per_sample_grad(self, batch)
148
149 # map over batch dimensions (hence 0 for each batch dimension, and None for model params)
--> 150 grads = torch.func.vmap(
151 grads_loss,
152 in_dims=(None, None, None, *([0] * len(batch))),
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in wrapped(*args, **kwargs)
432
433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
435 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
436 )
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
40 return fn
41
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
617 try:
618 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619 batched_outputs = func(*batched_inputs, **kwargs)
620 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
621 finally:
/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
1378 @wraps(func)
1379 def wrapper(*args, **kwargs):
-> 1380 results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
1381 if has_aux:
1382 grad, (_, aux) = results
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
40 return fn
41
/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
1243 tree_map(partial(_create_differentiable, level=level), diff_args)
1244
-> 1245 output = func(*args, **kwargs)
1246 if has_aux:
1247 if not (isinstance(output, tuple) and len(output) == 2):
/opt/conda/lib/python3.8/site-packages/trak/modelout_functions.py in get_output(model, weights, buffers, image, label)
138 """
139 logits = ch.func.functional_call(model, (weights, buffers), image.unsqueeze(0))
--> 140 bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
141 logits_correct = logits[bindex, label.unsqueeze(0)]
142
AttributeError: 'tuple' object has no attribute 'shape'
The text was updated successfully, but these errors were encountered: