@@ -120,7 +120,7 @@ def __init__(
120
120
files = pd .Series ([f .removesuffix ('.npy' ) for f in os .listdir (folder ) if os .path .isfile (os .path .join (folder , f ))])
121
121
self .metadata = self .metadata [self .metadata ['original_chip_id' ].isin (files )]
122
122
self .metadata .set_index ('original_chip_id' , inplace = True )
123
-
123
+ self . df_osm_agg = pd . read_parquet ( f" { self . folder_neighbors } /osm_aggregate.parquet" )
124
124
125
125
if max_items is not None :
126
126
self .metadata = self .metadata .iloc [np .random .permutation (len (self .metadata ))[:max_items ]]
@@ -164,24 +164,25 @@ def __getitem__(self, idx):
164
164
165
165
if self .neighbor_embeddings_folder is not None :
166
166
# Aggregate neighbor OSM data
167
- item_neighbors = pd .read_parquet (f"{ self .folder_neighbors } /{ item .name } .parquet" )['chipid' ]
168
- osm_aggregate = self .metadata .loc [item_neighbors [
169
- item_neighbors .isin (self .metadata .index )], ['onehot_count' , 'onehot_area' , 'onehot_length' ]].sum ()
167
+ # item_neighbors = pd.read_parquet(f"{self.folder_neighbors}/{item.name}.parquet")['chipid']
168
+ # osm_aggregate = self.metadata.loc[item_neighbors[
169
+ # item_neighbors.isin(self.metadata.index)], ['onehot_count', 'onehot_area', 'onehot_length']].sum()
170
+ osm_aggregate = self .df_osm_agg .loc [item .name ]
170
171
if self .multilabel_threshold_osm_ohecount is not None :
171
- multilabel = osm_neighbors ['onehot_count' ].astype (int )
172
+ multilabel = osm_aggregate ['onehot_count' ].astype (int )
172
173
multilabel = (multilabel >= self .multilabel_threshold_osm_ohecount ).astype (int )
173
174
if self .multilabel_threshold_osm_ohearea is not None :
174
175
# either area or a bit less than squared length
175
176
min_ohe_length = np .sqrt (self .multilabel_threshold_osm_ohearea )* 4 / 1.5
176
- multilabel = (osm_aggregate . onehot_area > self .multilabel_threshold_osm_ohearea ) | (osm_aggregate . onehot_length > min_ohe_length )
177
+ multilabel = (osm_aggregate [ ' onehot_area' ] > self .multilabel_threshold_osm_ohearea ) | (osm_aggregate [ ' onehot_length' ] > min_ohe_length )
177
178
else :
178
179
if self .multilabel_threshold_osm_ohecount is not None :
179
- multilabel = item . onehot_count .astype (int )
180
+ multilabel = item [ ' onehot_count' ] .astype (int )
180
181
multilabel = (multilabel >= self .multilabel_threshold_osm_ohecount ).astype (int )
181
182
if self .multilabel_threshold_osm_ohearea is not None :
182
183
# either area or a bit less than squared length
183
184
min_ohe_length = np .sqrt (self .multilabel_threshold_osm_ohearea )* 4 / 1.5
184
- multilabel = (item . onehot_area > self .multilabel_threshold_osm_ohearea ) | (item . onehot_length > min_ohe_length )
185
+ multilabel = (item [ ' onehot_area' ] > self .multilabel_threshold_osm_ohearea ) | (item [ ' onehot_length' ] > min_ohe_length )
185
186
186
187
r ['multilabel' ] = torch .tensor (multilabel ).type (torch .int8 )
187
188
0 commit comments