Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 22 additions & 37 deletions examples/contrastive-image-text/clip_media_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,33 @@

try:
from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.backend.nodes import opnode_tensor_info
from habana_frameworks.mediapipe.backend.operator_specs import schema
from habana_frameworks.mediapipe.media_types import dtype, ftype, imgtype, randomCropType, readerOutType
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.operators.media_nodes import MediaReaderNode
from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import media_ext_reader_op_impl
from habana_frameworks.mediapipe.operators.reader_nodes.read_image_from_dir import get_max_file
from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import media_ext_reader_op_tensor_info
from habana_frameworks.torch.hpu import get_device_name
except ImportError:
pass

read_image_text_from_dataset_params = {
"label_dtype": dtype.UINT64,
"dataset": None,
}

class read_image_text_from_dataset(MediaReaderNode):
class read_clip_data(media_ext_reader_op_impl):
"""
Class defining read image/text from directory node.
Class defining read image/text from clip dataset.

"""

def __init__(self, name, guid, device, inputs, params, cparams, node_attr):
super().__init__(name, guid, device, inputs, params, cparams, node_attr)
def __init__(self, params):
self.batch_size = 1
params = params['priv_params']
self.meta_dtype = params["label_dtype"]
self.dataset = params["dataset"]
self.epoch = 0
self.batch_sampler_iter = None

self.num_imgs_slice = len(ClipMediaPipe.batch_sampler.sampler)
self.num_batches_slice = len(ClipMediaPipe.batch_sampler)
Expand All @@ -57,18 +62,16 @@ def __init__(self, name, guid, device, inputs, params, cparams, node_attr):
self.max_file = get_max_file([img["path"] for img in self.dataset["image"]])
logger.info(f"The largest file is {self.max_file}.")

self.iter_loc = 0

def set_params(self, params):
self.batch_size = params.batch_size

def gen_output_info(self):
out_info = []
o = opnode_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "")
out_info.append(o)
o = opnode_tensor_info(
self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), ""
)
o = media_ext_reader_op_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "")
out_info.append(o)
o = opnode_tensor_info(
o = media_ext_reader_op_tensor_info(
self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), ""
)
out_info.append(o)
Expand Down Expand Up @@ -111,28 +114,6 @@ def __next__(self):

return img_list, input_id_list, attention_mask_list


read_image_text_from_dataset_params = {
"label_dtype": dtype.UINT64,
"dataset": None,
}
schema.add_operator(
"ClipDataReader",
None,
0,
0,
[],
3,
read_image_text_from_dataset_params,
None,
read_image_text_from_dataset,
dtype.NDT,
)
op_class = fn.operator_add("ClipDataReader")
op_class.__module__ = fn.__name__
setattr(fn, "ClipDataReader", op_class)


class ClipMediaPipe(MediaPipe):
"""
Class defining clip media pipe:
Expand Down Expand Up @@ -160,8 +141,12 @@ def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False,
super(ClipMediaPipe, self).__init__(
device=self.device, batch_size=batch_size, prefetch_depth=queue_depth, pipe_name=pipe_name
)

self.input = fn.ClipDataReader(label_dtype=dtype.UINT32, dataset=self.dataset)
params = read_image_text_from_dataset_params.copy()
params['dataset'] = self.dataset
self.input = fn.MediaExtReaderOp(impl=read_clip_data,
num_outputs=3,
priv_params=params,
)
def_output_image_size = [self.image_size, self.image_size]
res_pp_filter = ftype.BICUBIC
self.decode = fn.ImageDecoder(
Expand Down