Skip to content

Commit

Permalink
open video file only at init
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed Jul 23, 2024
1 parent e5fe078 commit e60afb5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
24 changes: 18 additions & 6 deletions dreem/datasets/microscopy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
seed,
)

self.videos = videos
self.vid_files = videos
self.tracks = tracks
self.chunk = chunk
self.clip_length = clip_length
Expand Down Expand Up @@ -92,13 +92,19 @@ def __init__(
parser(self.tracks[video_idx]) for video_idx in range(len(self.tracks))
]

self.videos = []
for vid_file in self.vid_files:
if not isinstance(vid_file, list):
self.videos.append(data_utils.LazyTiffStack(vid_file))
else:
self.videos.append([Image.open(frame_file) for frame_file in vid_file])
self.frame_idx = [
(
torch.arange(Image.open(video).n_frames)
if isinstance(video, str)
else torch.arange(len(video))
)
for video in self.videos
for video in self.vid_files
]

# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
Expand Down Expand Up @@ -128,17 +134,14 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

video = self.videos[label_idx]

if not isinstance(video, list):
video = data_utils.LazyTiffStack(self.videos[label_idx])

frames = []
for frame_id in frame_idx:
instances, gt_track_ids, centroids = [], [], []

img = (
video.get_section(frame_id)
if not isinstance(video, list)
else np.array(Image.open(video[frame_id]))
else np.array(video[frame_id])
)

lf = labels[labels["FRAME"].astype(int) == frame_id.item()]
Expand Down Expand Up @@ -202,3 +205,12 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
)

return frames

def __del__(self):
"""Handle file closing before deletion."""
for vid_reader in self.videos:
if not isinstance(vid_reader, list):
vid_reader.close()
else:
for frame_reader in vid_reader:
frame_reader.close()
9 changes: 7 additions & 2 deletions dreem/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
# if self.seed is not None:
# np.random.seed(self.seed)
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files]

self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files]
# do we need this? would need to update with sleap-io

# for label in self.labels:
Expand Down Expand Up @@ -140,7 +140,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

video_name = self.video_files[label_idx]

vid_reader = imageio.get_reader(video_name, "ffmpeg")
vid_reader = self.videos[label_idx]

img = vid_reader.get_data(0)

Expand Down Expand Up @@ -370,3 +370,8 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
frames.append(frame)

return frames

def __del__(self):
"""Handle file closing before garbage collection."""
for reader in self.videos:
reader.close()

0 comments on commit e60afb5

Please sign in to comment.