Skip to content

Commit 89274f9

Browse files
committed
loading aggregated osm labels from osm_aggregate.parquet
1 parent 6d02544 commit 89274f9

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

Diff for: src/earthtext/datamodules/components/chipmultilabel.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(
120120
files = pd.Series([f.removesuffix('.npy') for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))])
121121
self.metadata = self.metadata[self.metadata['original_chip_id'].isin(files)]
122122
self.metadata.set_index('original_chip_id', inplace=True)
123-
123+
self.df_osm_agg = pd.read_parquet(f"{self.folder_neighbors}/osm_aggregate.parquet")
124124

125125
if max_items is not None:
126126
self.metadata = self.metadata.iloc[np.random.permutation(len(self.metadata))[:max_items]]
@@ -164,24 +164,25 @@ def __getitem__(self, idx):
164164

165165
if self.neighbor_embeddings_folder is not None:
166166
# Aggregate neighbor OSM data
167-
item_neighbors = pd.read_parquet(f"{self.folder_neighbors}/{item.name}.parquet")['chipid']
168-
osm_aggregate = self.metadata.loc[item_neighbors[
169-
item_neighbors.isin(self.metadata.index)], ['onehot_count', 'onehot_area', 'onehot_length']].sum()
167+
# item_neighbors = pd.read_parquet(f"{self.folder_neighbors}/{item.name}.parquet")['chipid']
168+
# osm_aggregate = self.metadata.loc[item_neighbors[
169+
# item_neighbors.isin(self.metadata.index)], ['onehot_count', 'onehot_area', 'onehot_length']].sum()
170+
osm_aggregate = self.df_osm_agg.loc[item.name]
170171
if self.multilabel_threshold_osm_ohecount is not None:
171-
multilabel = osm_neighbors['onehot_count'].astype(int)
172+
multilabel = osm_aggregate['onehot_count'].astype(int)
172173
multilabel = (multilabel >= self.multilabel_threshold_osm_ohecount).astype(int)
173174
if self.multilabel_threshold_osm_ohearea is not None:
174175
# either area or a bit less than squared length
175176
min_ohe_length = np.sqrt(self.multilabel_threshold_osm_ohearea)*4 / 1.5
176-
multilabel = (osm_aggregate.onehot_area > self.multilabel_threshold_osm_ohearea) | (osm_aggregate.onehot_length > min_ohe_length)
177+
multilabel = (osm_aggregate['onehot_area'] > self.multilabel_threshold_osm_ohearea) | (osm_aggregate['onehot_length'] > min_ohe_length)
177178
else:
178179
if self.multilabel_threshold_osm_ohecount is not None:
179-
multilabel = item.onehot_count.astype(int)
180+
multilabel = item['onehot_count'].astype(int)
180181
multilabel = (multilabel >= self.multilabel_threshold_osm_ohecount).astype(int)
181182
if self.multilabel_threshold_osm_ohearea is not None:
182183
# either area or a bit less than squared length
183184
min_ohe_length = np.sqrt(self.multilabel_threshold_osm_ohearea)*4 / 1.5
184-
multilabel = (item.onehot_area > self.multilabel_threshold_osm_ohearea) | (item.onehot_length > min_ohe_length)
185+
multilabel = (item['onehot_area'] > self.multilabel_threshold_osm_ohearea) | (item['onehot_length'] > min_ohe_length)
185186

186187
r['multilabel'] = torch.tensor(multilabel).type(torch.int8)
187188

0 commit comments

Comments
 (0)