diff --git a/examples/contrastive-image-text/clip_media_pipe.py b/examples/contrastive-image-text/clip_media_pipe.py old mode 100644 new mode 100755 index 48811d4c08..a4248959c7 --- a/examples/contrastive-image-text/clip_media_pipe.py +++ b/examples/contrastive-image-text/clip_media_pipe.py @@ -16,7 +16,6 @@ import numpy as np from torch.utils.data.sampler import BatchSampler -from optimum.habana.utils import check_habana_frameworks_version from optimum.utils import logging @@ -25,29 +24,37 @@ try: from habana_frameworks.mediapipe import fn - from habana_frameworks.mediapipe.backend.nodes import opnode_tensor_info - from habana_frameworks.mediapipe.backend.operator_specs import schema from habana_frameworks.mediapipe.media_types import dtype, ftype, imgtype, randomCropType, readerOutType from habana_frameworks.mediapipe.mediapipe import MediaPipe - from habana_frameworks.mediapipe.operators.media_nodes import MediaReaderNode from habana_frameworks.mediapipe.operators.reader_nodes.read_image_from_dir import get_max_file + from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import ( + media_ext_reader_op_impl, + media_ext_reader_op_tensor_info, + ) from habana_frameworks.torch.hpu import get_device_name except ImportError: pass +read_image_text_from_dataset_params = { + "label_dtype": dtype.UINT32, + "dataset": None, +} + -class read_image_text_from_dataset(MediaReaderNode): +class read_image_text_from_dataset(media_ext_reader_op_impl): """ - Class defining read image/text from directory node. + Class defining read image/text from clip dataset. """ - def __init__(self, name, guid, device, inputs, params, cparams, node_attr): - super().__init__(name, guid, device, inputs, params, cparams, node_attr) + def __init__(self, params): + self.batch_size = 1 + params = params["priv_params"] self.meta_dtype = params["label_dtype"] self.dataset = params["dataset"] self.epoch = 0 - + self.batch_sampler_iter = None + self.iter_loc = 0 self.num_imgs_slice = len(ClipMediaPipe.batch_sampler.sampler) self.num_batches_slice = len(ClipMediaPipe.batch_sampler) @@ -63,13 +70,13 @@ def set_params(self, params): def gen_output_info(self): out_info = [] - o = opnode_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") out_info.append(o) - o = opnode_tensor_info( + o = media_ext_reader_op_tensor_info( self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" ) out_info.append(o) - o = opnode_tensor_info( + o = media_ext_reader_op_tensor_info( self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" ) out_info.append(o) @@ -113,30 +120,6 @@ def __next__(self): return img_list, input_id_list, attention_mask_list -read_image_text_from_dataset_params = { - "label_dtype": dtype.UINT64, - "dataset": None, -} -schema.add_operator( - "ClipDataReader", - None, - 0, - 0, - [], - 3, - read_image_text_from_dataset_params, - None, - read_image_text_from_dataset, - dtype.NDT, -) -if check_habana_frameworks_version("1.14.0"): - op_class = fn.operator_add("ClipDataReader") -else: - op_class = fn.operator_add("ClipDataReader", False) -op_class.__module__ = fn.__name__ -setattr(fn, "ClipDataReader", op_class) - - class ClipMediaPipe(MediaPipe): """ Class defining clip media pipe: @@ -164,8 +147,13 @@ def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False, super(ClipMediaPipe, self).__init__( device=self.device, batch_size=batch_size, prefetch_depth=queue_depth, pipe_name=pipe_name ) - - self.input = fn.ClipDataReader(label_dtype=dtype.UINT32, dataset=self.dataset) + params = read_image_text_from_dataset_params.copy() + params["dataset"] = self.dataset + self.input = fn.MediaExtReaderOp( + impl=read_image_text_from_dataset, + num_outputs=3, + priv_params=params, + ) def_output_image_size = [self.image_size, self.image_size] res_pp_filter = ftype.BICUBIC self.decode = fn.ImageDecoder(