11import csv
2- import os
3- import os .path
2+ import pathlib
43from typing import Any , Callable , Optional , Tuple
54
65import torch
@@ -38,7 +37,17 @@ def __init__(
3837 self ._split = verify_str_arg (split , "split" , self ._RESOURCES .keys ())
3938 super ().__init__ (root , transform = transform , target_transform = target_transform )
4039
41- with open (self ._verify_integrity (), "r" , newline = "" ) as file :
40+ base_folder = pathlib .Path (self .root ) / "fer2013"
41+ file_name , md5 = self ._RESOURCES [self ._split ]
42+ data_file = base_folder / file_name
43+ if not check_integrity (str (data_file ), md5 = md5 ):
44+ raise RuntimeError (
45+ f"{ file_name } not found in { base_folder } or corrupted. "
46+ f"You can download it from "
47+ f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
48+ )
49+
50+ with open (data_file , "r" , newline = "" ) as file :
4251 self ._samples = [
4352 (
4453 torch .tensor ([int (idx ) for idx in row ["pixels" ].split ()], dtype = torch .uint8 ).reshape (48 , 48 ),
@@ -62,17 +71,5 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
6271
6372 return image , target
6473
65- def _verify_integrity (self ):
66- base_folder = os .path .join (self .root , "fer2013" )
67- file_name , md5 = self ._RESOURCES [self ._split ]
68- file = os .path .join (base_folder , file_name )
69- if not check_integrity (file , md5 = md5 ):
70- raise RuntimeError (
71- f"{ file_name } not found in { base_folder } or corrupted. "
72- f"You can download it from "
73- f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
74- )
75- return file
76-
7774 def extra_repr (self ) -> str :
7875 return f"split={ self ._split } "
0 commit comments