Hello!
Today I was trying to train a model using the JAX backend with multiple losses while passing the sample_weight argument, which seems to expect something like:
However, I got a KeyError raised from a call in trainers/compile_utils.py. It was looking up key 0 and failing to find it in this line:
_sample_weight = resolve_path(path, sample_weight)
Looking at it, it seems to me that resolve_path was written with self._flat_losses in mind, taking for granted that path can be resolved in whatever structure it is given. However, in my case, sample_weight had not been preprocessed so that path would resolve in it; its structure was exactly as I had passed it to fit (in my case, a vanilla dict with the traced values inside).
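To make the mismatch concrete, here is a minimal sketch in plain Python. It assumes a path is a tuple of keys or indices that gets followed one step at a time. resolve_path_sketch is a hypothetical stand-in, not the actual Keras code, and the output names and values are made up for illustration.

```python
def resolve_path_sketch(path, struct):
    """Hypothetical stand-in: follow each key or index in `path` into `struct`."""
    for key in path:
        struct = struct[key]
    return struct


# Made-up data: two outputs held in a list, so the loss path is an index.
y_pred = [[0.1, 0.9], [0.4, 0.6]]
# Sample weights passed as a plain dict keyed by (hypothetical) output names.
sample_weight = {"out_a": [1.0, 0.5], "out_b": [0.2, 0.8]}

path = (0,)

# Works: index 0 exists in the list of predictions.
first_pred = resolve_path_sketch(path, y_pred)

# Fails: the dict is keyed by name, so there is no key 0.
try:
    resolve_path_sketch(path, sample_weight)
    lookup_error = None
except KeyError as err:
    lookup_error = err
```

The same path that works for the predictions raises KeyError on the dict, which matches the symptom I saw.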
I suppose a good way to fix it would be to add the path to sample_weight in an appropriate spot, but for now I resorted to editing the library code to do something like:
for (path, loss_fn, loss_weight, loss_name), metric in zip(
    self._flat_losses, metrics
):
    y_t, y_p = resolve_path(path, y_true), resolve_path(path, y_pred)
    if sample_weight is not None and tree.is_nested(sample_weight):
        _sample_weight = sample_weight[loss_name]
    else:
        _sample_weight = sample_weight
Which is pretty coarse but works for me.
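A slightly less coarse variant of the same idea, written as a standalone sketch so it can be tried without touching the library: look the weight up by loss name when sample_weight is a dict that has that name, fall back to following the path for other containers, and pass a single weight through unchanged. select_sample_weight is a hypothetical helper of mine, not part of Keras, and treating only dict/list/tuple as nested containers is an assumption.

```python
def select_sample_weight(sample_weight, path, loss_name):
    """Hypothetical helper: pick the sample weight for one loss."""
    if sample_weight is None:
        return None
    # Dict keyed by loss/output name: prefer the name lookup.
    if isinstance(sample_weight, dict) and loss_name in sample_weight:
        return sample_weight[loss_name]
    # Other containers (assumed nested): follow the path step by step.
    # This still raises KeyError/IndexError if the path does not fit.
    if isinstance(sample_weight, (dict, list, tuple)):
        current = sample_weight
        for key in path:
            current = current[key]
        return current
    # Anything else (e.g. a single array) is shared by all losses.
    return sample_weight


# Made-up values for illustration.
by_name = {"out_a": [1.0, 0.5], "out_b": [0.2, 0.8]}
by_position = [[1.0, 0.5], [0.2, 0.8]]

picked_by_name = select_sample_weight(by_name, (0,), "out_a")
picked_by_position = select_sample_weight(by_position, (1,), "out_b")
```

With the dict, the name lookup succeeds even though the path (0,) would not resolve; with the list, the path is used as before.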
Just wanted to give a heads-up on it if it's a real issue, otherwise if I'm doing something comically wrong I apologize for the time wasted!
python 3.11.11
keras 3.9
jax 0.5.2
Have a great day!