refactor: allowing multi task subtasks to reuse shared embedding
laserkelvin committed Jul 1, 2024
1 parent b888ee8 commit c182d90
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions matsciml/models/base.py
@@ -2689,8 +2689,10 @@ def forward(
         if self.is_multidata:
             for key, data in batch.items():
                 data["embeddings"] = self.encoder(data)
+                embeddings = data["embeddings"]
         else:
             batch["embeddings"] = self.encoder(batch)
+            embeddings = batch["embeddings"]
         # for single dataset usage, we assume the nested structure isn't used
         if self.is_multidata:
             for key, data in batch.items():
@@ -2699,13 +2701,13 @@
                 results[key] = {}
                 # finally call the task with the data
                 for task_type, subtask in subtasks.items():
-                    results[key][task_type] = subtask(data)
+                    results[key][task_type] = subtask.process_embedding(embeddings)
         else:
             # in the single dataset case, we can skip the outer loop
             # and just pass the batch into the subtask
             tasks = list(self.task_map.values()).pop(0)
             for task_type, subtask in tasks.items():
-                results[task_type] = subtask(batch)
+                results[task_type] = subtask.process_embedding(embeddings)
         return results
 
     def predict(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
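The net effect of the change is that the shared encoder runs once per batch, and each subtask consumes the resulting embedding through its process_embedding hook instead of re-encoding the batch itself. Below is a minimal, self-contained sketch of that pattern; process_embedding mirrors the method name used in the diff above, while DummyEncoder, ScalarRegressionHead, and the batch layout are hypothetical stand-ins rather than matsciml's actual classes.

# Minimal sketch of the shared-embedding pattern (illustrative names only;
# DummyEncoder and ScalarRegressionHead are not real matsciml classes).
import torch
import torch.nn as nn


class DummyEncoder(nn.Module):
    """Stand-in for the shared encoder that maps a batch to an embedding."""

    def __init__(self, in_dim: int = 8, embed_dim: int = 16) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim, embed_dim)

    def forward(self, batch: dict) -> torch.Tensor:
        return self.proj(batch["features"])


class ScalarRegressionHead(nn.Module):
    """Stand-in subtask head that only consumes a precomputed embedding."""

    def __init__(self, embed_dim: int = 16) -> None:
        super().__init__()
        self.out = nn.Linear(embed_dim, 1)

    def process_embedding(self, embeddings: torch.Tensor) -> torch.Tensor:
        # The subtask never touches the encoder; it just projects the
        # embedding that the parent multitask module computed once.
        return self.out(embeddings)


encoder = DummyEncoder()
subtasks = {"energy": ScalarRegressionHead(), "band_gap": ScalarRegressionHead()}

batch = {"features": torch.randn(4, 8)}
embeddings = encoder(batch)  # shared encoder runs exactly once per batch
results = {
    name: head.process_embedding(embeddings)  # every head reuses the same tensor
    for name, head in subtasks.items()
}

Compared with the pre-refactor call of subtask(data) or subtask(batch) per head, this avoids one encoder pass per subtask and guarantees that all heads operate on an identical representation.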
