diff --git a/examples/contrastive-image-text/clip_media_pipe.py b/examples/contrastive-image-text/clip_media_pipe.py index 62c2a5651b..3ba1038fb8 100644 --- a/examples/contrastive-image-text/clip_media_pipe.py +++ b/examples/contrastive-image-text/clip_media_pipe.py @@ -24,28 +24,33 @@ try: from habana_frameworks.mediapipe import fn - from habana_frameworks.mediapipe.backend.nodes import opnode_tensor_info - from habana_frameworks.mediapipe.backend.operator_specs import schema from habana_frameworks.mediapipe.media_types import dtype, ftype, imgtype, randomCropType, readerOutType from habana_frameworks.mediapipe.mediapipe import MediaPipe - from habana_frameworks.mediapipe.operators.media_nodes import MediaReaderNode + from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import media_ext_reader_op_impl from habana_frameworks.mediapipe.operators.reader_nodes.read_image_from_dir import get_max_file + from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import media_ext_reader_op_tensor_info from habana_frameworks.torch.hpu import get_device_name except ImportError: pass +read_image_text_from_dataset_params = { + "label_dtype": dtype.UINT64, + "dataset": None, +} -class read_image_text_from_dataset(MediaReaderNode): +class read_clip_data(media_ext_reader_op_impl): """ - Class defining read image/text from directory node. + Class defining read image/text from clip dataset. """ - def __init__(self, name, guid, device, inputs, params, cparams, node_attr): - super().__init__(name, guid, device, inputs, params, cparams, node_attr) + def __init__(self, params): + self.batch_size = 1 + params = params['priv_params'] self.meta_dtype = params["label_dtype"] self.dataset = params["dataset"] self.epoch = 0 + self.batch_sampler_iter = None self.num_imgs_slice = len(ClipMediaPipe.batch_sampler.sampler) self.num_batches_slice = len(ClipMediaPipe.batch_sampler) @@ -57,18 +62,16 @@ def __init__(self, name, guid, device, inputs, params, cparams, node_attr): self.max_file = get_max_file([img["path"] for img in self.dataset["image"]]) logger.info(f"The largest file is {self.max_file}.") + self.iter_loc = 0 + def set_params(self, params): self.batch_size = params.batch_size def gen_output_info(self): out_info = [] - o = opnode_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") - out_info.append(o) - o = opnode_tensor_info( - self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" - ) + o = media_ext_reader_op_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") out_info.append(o) - o = opnode_tensor_info( + o = media_ext_reader_op_tensor_info( self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" ) out_info.append(o) @@ -111,28 +114,6 @@ def __next__(self): return img_list, input_id_list, attention_mask_list - -read_image_text_from_dataset_params = { - "label_dtype": dtype.UINT64, - "dataset": None, -} -schema.add_operator( - "ClipDataReader", - None, - 0, - 0, - [], - 3, - read_image_text_from_dataset_params, - None, - read_image_text_from_dataset, - dtype.NDT, -) -op_class = fn.operator_add("ClipDataReader") -op_class.__module__ = fn.__name__ -setattr(fn, "ClipDataReader", op_class) - - class ClipMediaPipe(MediaPipe): """ Class defining clip media pipe: @@ -160,8 +141,12 @@ def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False, super(ClipMediaPipe, self).__init__( device=self.device, batch_size=batch_size, prefetch_depth=queue_depth, pipe_name=pipe_name ) - - self.input = fn.ClipDataReader(label_dtype=dtype.UINT32, dataset=self.dataset) + params = read_image_text_from_dataset_params.copy() + params['dataset'] = self.dataset + self.input = fn.MediaExtReaderOp(impl=read_clip_data, + num_outputs=3, + priv_params=params, + ) def_output_image_size = [self.image_size, self.image_size] res_pp_filter = ftype.BICUBIC self.decode = fn.ImageDecoder(