From 219315049443dcab6507f42b86f3e6deafce8e02 Mon Sep 17 00:00:00 2001 From: Wilson Yan Date: Sun, 4 Aug 2024 21:41:26 +0000 Subject: [PATCH] Add better video support for codecs and container types + Fix Loader.load --- granular/formats.py | 24 +++++++++++++----------- granular/loader.py | 5 +++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/granular/formats.py b/granular/formats.py index c7f000b..6a52f1c 100644 --- a/granular/formats.py +++ b/granular/formats.py @@ -67,12 +67,12 @@ def decode_image(buffer, *args): return np.asarray(Image.open(io.BytesIO(buffer))) -def encode_mp4(array, fps=20): +def encode_video(array, fps=20, format='mp4', codec='h264'): import av T, H, W = array.shape[:3] fp = io.BytesIO() - output = av.open(fp, mode='w', format='mp4') - stream = output.add_stream('mpeg4', rate=float(fps)) + output = av.open(fp, mode='w', format=format) + stream = output.add_stream(codec, rate=float(fps)) stream.width = W stream.height = H stream.pix_fmt = 'yuv420p' @@ -85,14 +85,14 @@ def encode_mp4(array, fps=20): return fp.getvalue() -def decode_mp4(buffer, *args): +def decode_video(buffer, *args): + import numpy as np import av container = av.open(io.BytesIO(buffer)) - stream = container.streams.video[0] - T, H, W = stream.frames, stream.height, stream.width - array = np.empty((T, H, W, 3), dtype=np.uint8) - for t, frame in enumerate(container.decode(video=0)): - array[t] = frame.to_ndarray(format='rgb24') + array = [] + for frame in container.decode(video=0): + array.append(frame.to_ndarray(format='rgb24')) + array = np.stack(array) container.close() return array @@ -106,7 +106,8 @@ def decode_mp4(buffer, *args): 'tree': encode_tree, 'jpg': bind(encode_image, format='jpg'), 'png': bind(encode_image, format='png'), - 'mp4': encode_mp4, + 'mp4': bind(encode_video, format='mp4', codec='h264'), + 'webm': bind(encode_video, format='webm', codec='vp9'), } @@ -119,5 +120,6 @@ def decode_mp4(buffer, *args): 'tree': decode_tree, 'jpg': decode_image, 'png': decode_image, - 'mp4': decode_mp4, + 'mp4': decode_video, + 'webm': decode_video, } diff --git a/granular/loader.py b/granular/loader.py index 2f0120f..6387126 100644 --- a/granular/loader.py +++ b/granular/loader.py @@ -89,8 +89,9 @@ def load(self, d): self._receive() self.consumed = self.step = d['step'] self.seed = d['seed'] - for _ in range(self.prefetch): - self._request() + if self.started: + for _ in range(self.prefetch): + self._request() def close(self): self.stop.set()