Skip to content

Commit c8b02d6

Browse files
committed
possibility of using `by` on periodic sampler
1 parent 44a80cf commit c8b02d6

File tree

1 file changed

+35
-16
lines changed

1 file changed

+35
-16
lines changed

ChildProject/pipelines/samplers.py

+35-16
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def __init__(
224224
period: int,
225225
offset: int = 0,
226226
profile: str = None,
227+
by: str = "recording_filename",
227228
recordings: Union[str, List[str], pd.DataFrame] = None,
228229
exclude: Union[str, pd.DataFrame] = None,
229230
):
@@ -233,6 +234,7 @@ def __init__(
233234
self.period = int(period)
234235
self.offset = int(offset)
235236
self.profile = profile
237+
self.by = by
236238

237239
def _sample(self):
238240
recordings = self.project.get_recordings_from_list(self.recordings)
@@ -250,23 +252,34 @@ def _sample(self):
250252

251253
recordings["duration"].fillna(0, inplace=True)
252254

253-
self.segments = recordings[["recording_filename", "duration"]].copy()
254-
self.segments["segment_onset"] = self.segments.apply(
255-
lambda row: np.arange(
256-
self.offset,
257-
row["duration"] - self.length + 1e-4,
255+
recordings = recordings.copy()
256+
recordings['start_ts'] = recordings.apply(
257+
lambda row: int(pd.Timestamp(str(row['date_iso']) + 'T' + str(row['start_time'])).timestamp()),
258+
axis=1)
259+
recordings['end_ts'] = recordings['start_ts'] + recordings['duration']
260+
segments = []
261+
# work by groups (by argument), create a singular timeline from those groups and choose periodic segments from there
262+
# this means that recordings following each other will maintain continuity in sampling period
263+
# it also means concurrent recordings in the same session will have the same samples kept time/date wise regardless of shifts in start
264+
for i, gdf in recordings.groupby(self.by):
265+
all_segments = pd.DataFrame({'segment_onset': np.arange(
266+
gdf['start_ts'].min() + self.offset,
267+
gdf['end_ts'].max() - self.length,
258268
self.period + self.length,
259-
),
260-
axis=1,
261-
)
262-
self.segments = self.segments.explode("segment_onset")
263-
# discard recordings that can't include segments (they are NA here bc explode keeps empty lists)
264-
self.segments = self.segments.dropna(subset=['segment_onset'])
265-
self.segments["segment_onset"] = self.segments["segment_onset"].astype(int)
266-
self.segments["segment_offset"] = self.segments["segment_onset"] + self.length
267-
self.segments.rename(
268-
columns={"recording_filename": "recording_filename"}, inplace=True
269-
)
269+
)})
270+
all_segments['segment_offset'] = all_segments['segment_onset'] + self.length
271+
rec_segments = []
272+
for recording in gdf.to_dict(orient='records'):
273+
tmp_segs = all_segments[(all_segments['segment_offset'] > recording['start_ts']) & (all_segments['segment_onset'] < recording['end_ts'])].copy()
274+
# cut down overflowing stamps
275+
tmp_segs['segment_onset'] = tmp_segs['segment_onset'].apply(lambda x: max(x, recording['start_ts']))
276+
tmp_segs['segment_offset'] = tmp_segs['segment_offset'].apply(lambda x: min(x, recording['end_ts']))
277+
tmp_segs['segment_onset'] = tmp_segs['segment_onset'] - recording['start_ts']
278+
tmp_segs['segment_offset'] = tmp_segs['segment_offset'] - recording['start_ts']
279+
tmp_segs['recording_filename'] = recording['recording_filename']
280+
rec_segments.append(tmp_segs)
281+
segments.append(pd.concat(rec_segments))
282+
self.segments = pd.concat(segments)
270283

271284
return self.segments
272285

@@ -297,6 +310,12 @@ def add_parser(subparsers, subcommand):
297310
default="",
298311
type=str,
299312
)
313+
parser.add_argument(
314+
"--by",
315+
help="units to sample from (default behavior is to sample by recording)",
316+
choices=["recording_filename", "session_id", "child_id"],
317+
default="recording_filename",
318+
)
300319

301320

302321
class RandomVocalizationSampler(Sampler):

0 commit comments

Comments
 (0)