Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 26 additions & 38 deletions examples/contrastive-image-text/clip_media_pipe.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import numpy as np
from torch.utils.data.sampler import BatchSampler

from optimum.habana.utils import check_habana_frameworks_version
from optimum.utils import logging


Expand All @@ -25,29 +24,37 @@

try:
from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.backend.nodes import opnode_tensor_info
from habana_frameworks.mediapipe.backend.operator_specs import schema
from habana_frameworks.mediapipe.media_types import dtype, ftype, imgtype, randomCropType, readerOutType
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.operators.media_nodes import MediaReaderNode
from habana_frameworks.mediapipe.operators.reader_nodes.read_image_from_dir import get_max_file
from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import (
media_ext_reader_op_impl,
media_ext_reader_op_tensor_info,
)
from habana_frameworks.torch.hpu import get_device_name
except ImportError:
pass

read_image_text_from_dataset_params = {
"label_dtype": dtype.UINT32,
"dataset": None,
}


class read_image_text_from_dataset(MediaReaderNode):
class read_image_text_from_dataset(media_ext_reader_op_impl):
"""
Class defining read image/text from directory node.
Class defining read image/text from clip dataset.

"""

def __init__(self, name, guid, device, inputs, params, cparams, node_attr):
super().__init__(name, guid, device, inputs, params, cparams, node_attr)
def __init__(self, params):
self.batch_size = 1
params = params["priv_params"]
self.meta_dtype = params["label_dtype"]
self.dataset = params["dataset"]
self.epoch = 0

self.batch_sampler_iter = None
self.iter_loc = 0
self.num_imgs_slice = len(ClipMediaPipe.batch_sampler.sampler)
self.num_batches_slice = len(ClipMediaPipe.batch_sampler)

Expand All @@ -63,13 +70,13 @@ def set_params(self, params):

def gen_output_info(self):
out_info = []
o = opnode_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "")
o = media_ext_reader_op_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "")
out_info.append(o)
o = opnode_tensor_info(
o = media_ext_reader_op_tensor_info(
self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), ""
)
out_info.append(o)
o = opnode_tensor_info(
o = media_ext_reader_op_tensor_info(
self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), ""
)
out_info.append(o)
Expand Down Expand Up @@ -113,30 +120,6 @@ def __next__(self):
return img_list, input_id_list, attention_mask_list


read_image_text_from_dataset_params = {
"label_dtype": dtype.UINT64,
"dataset": None,
}
schema.add_operator(
"ClipDataReader",
None,
0,
0,
[],
3,
read_image_text_from_dataset_params,
None,
read_image_text_from_dataset,
dtype.NDT,
)
if check_habana_frameworks_version("1.14.0"):
op_class = fn.operator_add("ClipDataReader")
else:
op_class = fn.operator_add("ClipDataReader", False)
op_class.__module__ = fn.__name__
setattr(fn, "ClipDataReader", op_class)


class ClipMediaPipe(MediaPipe):
"""
Class defining clip media pipe:
Expand Down Expand Up @@ -164,8 +147,13 @@ def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False,
super(ClipMediaPipe, self).__init__(
device=self.device, batch_size=batch_size, prefetch_depth=queue_depth, pipe_name=pipe_name
)

self.input = fn.ClipDataReader(label_dtype=dtype.UINT32, dataset=self.dataset)
params = read_image_text_from_dataset_params.copy()
params["dataset"] = self.dataset
self.input = fn.MediaExtReaderOp(
impl=read_image_text_from_dataset,
num_outputs=3,
priv_params=params,
)
def_output_image_size = [self.image_size, self.image_size]
res_pp_filter = ftype.BICUBIC
self.decode = fn.ImageDecoder(
Expand Down