TRAK on 5D shapes #74

Open
Bas-2k opened this issue Sep 24, 2024 · 0 comments

Bas-2k commented Sep 24, 2024

Hi,
I have a classification model trained on inputs of shape batch_size x N (=46) x channel x height x width. How can I adapt TRAK to work with that? I get an error in the featurize function itself. Do I have to modify the in_dims?

/opt/conda/lib/python3.8/site-packages/trak/gradient_computers.py in compute_per_sample_grad(self, batch)
148
149 # map over batch dimensions (hence 0 for each batch dimension, and None for model params)
--> 150 grads = torch.func.vmap(
151 grads_loss,
152 in_dims=(None, None, None, *([0] * len(batch))),

/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in wrapped(*args, **kwargs)
432
433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
435 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
436 )

/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
40 return fn
41

/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
617 try:
618 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619 batched_outputs = func(*batched_inputs, **kwargs)
620 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
621 finally:

/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
1378 @wraps(func)
1379 def wrapper(*args, **kwargs):
-> 1380 results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
1381 if has_aux:
1382 grad, (_, aux) = results

/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
40 return fn
41

/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
   1243 tree_map_(partial(_create_differentiable, level=level), diff_args)
1244
-> 1245 output = func(*args, **kwargs)
1246 if has_aux:
1247 if not (isinstance(output, tuple) and len(output) == 2):

/opt/conda/lib/python3.8/site-packages/trak/modelout_functions.py in get_output(model, weights, buffers, image, label)
138 """
139 logits = ch.func.functional_call(model, (weights, buffers), image.unsqueeze(0))
--> 140 bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
141 logits_correct = logits[bindex, label.unsqueeze(0)]
142

AttributeError: 'tuple' object has no attribute 'shape'

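Judging from the traceback, the 5D input itself may not be the immediate problem: the AttributeError fires because `logits` is a tuple, which suggests the model's forward() returns a tuple (e.g. logits plus auxiliary outputs) rather than a bare logits tensor, so TRAK's built-in image_classification model output function cannot index into it. Below is a minimal, untested sketch of a custom model output adapted from TRAK's `ImageClassificationModelOutput` (visible in the traceback above). The tuple-unwrapping and the (N, C, H, W) sample handling are assumptions inferred from the error, and the class name is made up:

```python
from trak.modelout_functions import AbstractModelOutput
import torch as ch


class TupleOutputClassificationModelOutput(AbstractModelOutput):
    """Hypothetical variant of TRAK's ImageClassificationModelOutput for a
    model that (a) takes samples of shape (N, C, H, W) and (b) may return a
    tuple from forward() rather than a bare logits tensor."""

    def __init__(self, temperature: float = 1.0) -> None:
        super().__init__()
        self.softmax = ch.nn.Softmax(-1)
        self.loss_temperature = temperature

    @staticmethod
    def get_output(model, weights, buffers, images, label):
        # vmap slices away the leading batch dim, so `images` is (N, C, H, W)
        # here; unsqueeze(0) restores the batch dim the model expects,
        # giving (1, N, C, H, W)
        out = ch.func.functional_call(model, (weights, buffers), images.unsqueeze(0))
        # assumption: if forward() returns a tuple, the logits come first
        logits = out[0] if isinstance(out, tuple) else out
        bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
        logits_correct = logits[bindex, label.unsqueeze(0)]
        cloned_logits = logits.clone()
        cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(
            -ch.inf, device=logits.device, dtype=logits.dtype
        )
        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
        return margins.sum()

    def get_out_to_loss_grad(self, model, weights, buffers, batch):
        images, labels = batch
        out = ch.func.functional_call(model, (weights, buffers), images)
        logits = out[0] if isinstance(out, tuple) else out
        ps = self.softmax(logits / self.loss_temperature)[
            ch.arange(logits.shape[0]), labels
        ]
        return (1 - ps).clone().detach().unsqueeze(-1)
```

If that diagnosis is right, in_dims should not need changing: in_dims=0 only slices the leading batch dimension and is agnostic to how many trailing dimensions each tensor has. A custom instance would then be passed via TRAKer's task argument in place of the "image_classification" string:

```python
from trak import TRAKer

# hypothetical usage; model and train_ds are your model and training dataset
traker = TRAKer(model=model,
                task=TupleOutputClassificationModelOutput(),
                train_set_size=len(train_ds))
```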