File tree Expand file tree Collapse file tree 1 file changed +34
-0
lines changed
torchvision/prototype/datasets/utils Expand file tree Collapse file tree 1 file changed +34
-0
lines changed Original file line number Diff line number Diff line change 2020 extract_archive ,
2121 _decompress ,
2222 download_file_from_google_drive ,
23+ _get_redirect_url ,
24+ _get_google_drive_file_id ,
2325)
2426
2527
@@ -134,9 +136,41 @@ def __init__(
134136 super ().__init__ (file_name = file_name or pathlib .Path (urlparse (url ).path ).name , ** kwargs )
135137 self .url = url
136138 self .mirrors = mirrors
139+ self ._resolved = False
140+
141+ def resolve (self ) -> OnlineResource :
142+ if self ._resolved :
143+ return self
144+
145+ redirect_url = _get_redirect_url (self .url )
146+ if redirect_url == self .url :
147+ self ._resolved = True
148+ return self
149+
150+ meta = {
151+ attr .lstrip ("_" ): getattr (self , attr )
152+ for attr in (
153+ "file_name" ,
154+ "sha256" ,
155+ "_preprocess" ,
156+ "_loader" ,
157+ )
158+ }
159+
160+ gdrive_id = _get_google_drive_file_id (redirect_url )
161+ if gdrive_id :
162+ return GDriveResource (gdrive_id , ** meta )
163+
164+ http_resource = HttpResource (redirect_url , ** meta )
165+ http_resource ._resolved = True
166+ return http_resource
137167
138168 def _download (self , root : pathlib .Path ) -> None :
169+ if not self ._resolved :
170+ return self .resolve ()._download (root )
171+
139172 for url in itertools .chain ((self .url ,), self .mirrors ):
173+
140174 try :
141175 download_url (url , str (root ), filename = self .file_name , md5 = None )
142176 # TODO: make this more precise
You can’t perform that action at this time.
0 commit comments