Skip to content

Commit

Permalink
update example
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 8, 2023
1 parent 1f85bf2 commit 21cd5fd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion examples/callback/custom_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.experiment.callback import Callback
from benchmarl.models.mlp import MlpConfig
from tensordict import TensorDictBase
from tensordict import TensorDict, TensorDictBase


class MyCallbackA(Callback):
def on_batch_collected(self, batch: TensorDictBase):
print(f"Callback A is doing something with the sampling batch {batch}")

def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
print(f"Callback A is computing a loss with the training tensordict {batch}")
return TensorDict({}, [])

def on_train_end(self, training_td: TensorDictBase, group: str):
print(
f"Callback A is doing something with the training tensordict {training_td}"
Expand Down

0 comments on commit 21cd5fd

Please sign in to comment.