Skip to content

Commit 62928eb

Browse files
authored
change order of key mapping transform (#993)
* change order of key mapping * move transform to before renaming * fix lint
1 parent 42d6669 commit 62928eb

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

src/fairchem/core/datasets/ase_datasets.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,12 @@ def __getitem__(self, idx):
135135
if self.a2g.r_energy is True and self.lin_ref is not None:
136136
data_object.energy -= sum(self.lin_ref[data_object.atomic_numbers.long()])
137137

138-
if self.key_mapping is not None:
139-
data_object = rename_data_object_keys(data_object, self.key_mapping)
140-
141138
# Transform data object
142139
data_object = self.transforms(data_object)
143140

141+
if self.key_mapping is not None:
142+
data_object = rename_data_object_keys(data_object, self.key_mapping)
143+
144144
if self.config.get("include_relaxed_energy", False):
145145
data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx])
146146

src/fairchem/core/datasets/lmdb_dataset.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,12 @@ def __getitem__(self, idx: int) -> T_co:
145145
)
146146
data_object = pyg2_data_transform(pickle.loads(datapoint_pickled))
147147

148+
data_object = self.transforms(data_object)
149+
148150
if self.key_mapping is not None:
149151
data_object = rename_data_object_keys(data_object, self.key_mapping)
150152

151-
return self.transforms(data_object)
153+
return data_object
152154

153155
def connect_db(self, lmdb_path: Path | None = None) -> lmdb.Environment:
154156
return lmdb.open(

src/fairchem/core/datasets/oc22_lmdb_dataset.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,6 @@ def __getitem__(self, idx):
192192
lin_energy = sum(self.lin_ref[data_object.atomic_numbers.long()])
193193
data_object[attr] -= lin_energy
194194

195-
if self.key_mapping is not None:
196-
data_object = rename_data_object_keys(data_object, self.key_mapping)
197-
198195
# to jointly train on oc22+oc20, need to delete these oc20-only attributes
199196
# ensure otf_graph=1 in your model configuration
200197
if "edge_index" in data_object:
@@ -204,7 +201,12 @@ def __getitem__(self, idx):
204201
if "distances" in data_object:
205202
del data_object.distances
206203

207-
return self.transforms(data_object)
204+
data_object = self.transforms(data_object)
205+
206+
if self.key_mapping is not None:
207+
data_object = rename_data_object_keys(data_object, self.key_mapping)
208+
209+
return data_object
208210

209211
def connect_db(self, lmdb_path=None):
210212
return lmdb.open(

0 commit comments

Comments
 (0)