-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
45 lines (33 loc) · 1.27 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# -*- coding: utf-8 -*-
"""dataset
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Vpu4okr1eyEWYUIcuCXWKx1EF_fHeaj9
"""
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
class VideoDataset(Dataset):
    """Dataset of videos stored on disk as directories of frame images.

    Expects ``root_dir`` to contain one sub-directory per video, each
    sub-directory holding that video's frames as image files openable by PIL.

    Args:
        root_dir: Path to the directory containing one sub-directory per video.
        transform: Optional callable applied to each frame (a PIL RGB Image).
            It should return a tensor (e.g. ``torchvision.transforms.ToTensor``);
            otherwise ``torch.stack`` in ``__getitem__`` will fail.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort for a deterministic index -> video mapping (os.listdir order is
        # filesystem-dependent), and keep only directories so stray files in
        # root_dir (e.g. .DS_Store) don't crash __getitem__ later.
        self.video_list = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of video sub-directories found under root_dir."""
        return len(self.video_list)

    def __getitem__(self, idx):
        """Load video ``idx`` as a tensor of its stacked frames.

        Returns:
            Tensor of shape ``(num_frames, ...)`` where ``...`` is the
            per-frame shape produced by ``transform`` (e.g. C x H x W).

        Raises:
            RuntimeError: if the video directory is empty (``torch.stack`` on
                an empty list) or frames have mismatched shapes after the
                transform.
        """
        video_name = self.video_list[idx]
        video_path = os.path.join(self.root_dir, video_name)
        frames = []
        # Frame files are sorted so temporal order is deterministic; assumes
        # frame filenames sort in playback order (e.g. zero-padded numbers).
        for frame_name in sorted(os.listdir(video_path)):
            frame_path = os.path.join(video_path, frame_name)
            frame = Image.open(frame_path).convert('RGB')
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)
        # Stack frames along a new leading (time) dimension.
        return torch.stack(frames, dim=0)
# Example usage:
# transform = torchvision.transforms.ToTensor()
# dataset = VideoDataset(root_dir='path/to/your/data', transform=transform)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)