diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py index 42b9e8a468..b54ca8e7c0 100644 --- a/examples/contrastive-image-text/run_bridgetower.py +++ b/examples/contrastive-image-text/run_bridgetower.py @@ -17,6 +17,7 @@ Training BridgeTower with a contrastive text-image loss. """ +import io import logging import os import sys @@ -28,6 +29,7 @@ import transformers from datasets import load_dataset from habana_dataloader_trainer import HabanaDataloaderTrainer +from PIL import Image from torchvision.io import ImageReadMode, read_image from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize from torchvision.transforms.functional import InterpolationMode, to_grayscale, to_tensor @@ -468,18 +470,29 @@ def get_image(image_or_path): if isinstance(image_or_path, str): # If the argument is a path to an image file, read it return read_image(image_or_path, mode=ImageReadMode.RGB) - elif isinstance(image_or_path, dict): - # Manage the case where images are a dictionary with keys 'bytes' and 'path' - return - else: - # If the argument is already an image, convert it into a tensor + elif isinstance(image_or_path, Image.Image): if len(image_or_path.getbands()) == 1: image_or_path = to_grayscale(image_or_path, num_output_channels=3) return to_tensor(image_or_path) + return None + def transform_images(examples): - images = [get_image(image_file) for image_file in examples[image_column]] + images = [] + + for item in examples[image_column]: + # Manage the case where images are a dictionary with keys 'bytes' and 'path' + if isinstance(item, dict): + encoding = "ISO-8859-1" + s = item["bytes"].decode(encoding) + b = bytearray(s, encoding) + image = Image.open(io.BytesIO(b)).convert("RGB") + images.append(to_tensor(image)) + else: + images.append(get_image(item)) + examples["pixel_values"] = [image_transformations(image) for image in images] + return examples if training_args.do_train: diff --git a/optimum/habana/transformers/gradient_checkpointing.py b/optimum/habana/transformers/gradient_checkpointing.py index 1279de3f56..f2983bab2c 100644 --- a/optimum/habana/transformers/gradient_checkpointing.py +++ b/optimum/habana/transformers/gradient_checkpointing.py @@ -181,8 +181,10 @@ def backward(ctx, *args): set_device_states(ctx.fwd_devices, ctx.fwd_device_states) detached_inputs = detach_variable(tuple(inputs)) - with torch.enable_grad(), torch.autocast(**ctx.hpu_autocast_kwargs), torch.amp.autocast( - "cpu", **ctx.cpu_autocast_kwargs + with ( + torch.enable_grad(), + torch.autocast(**ctx.hpu_autocast_kwargs), + torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs), ): outputs = ctx.run_function(*detached_inputs) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index f4c9d454ab..84ccf01095 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -453,9 +453,11 @@ def pre_attn_forward( ) else: # TODO very similar to the fp8 case above, could be merged. - with sdp_kernel( - enable_recompute=flash_attention_recompute - ) if SDPContext else contextlib.nullcontext(): + with ( + sdp_kernel(enable_recompute=flash_attention_recompute) + if SDPContext + else contextlib.nullcontext() + ): attn_output = FusedSDPA.apply( query_layer, key_layer, diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 5dfc29625a..edcaf2f631 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -362,9 +362,9 @@ def forward( ) htcore.mark_step() else: - with sdp_kernel( - enable_recompute=flash_attention_recompute - ) if SDPContext else contextlib.nullcontext(): + with ( + sdp_kernel(enable_recompute=flash_attention_recompute) if SDPContext else contextlib.nullcontext() + ): attn_output = FusedSDPA.apply( query_states, key_states, value_states, attention_mask, 0.0, False, None ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f94ebe6a4f..200f2a78a2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -87,9 +87,11 @@ def test_text_to_speech(self, model, expected_sample_rate): generate_kwargs = {"lazy_mode": True, "ignore_eos": False, "hpu_graphs": True} generator.model = wrap_in_hpu_graph(generator.model) - with torch.autocast( - "hpu", torch.bfloat16, enabled=(model_dtype == torch.bfloat16) - ), torch.no_grad(), torch.inference_mode(): + with ( + torch.autocast("hpu", torch.bfloat16, enabled=(model_dtype == torch.bfloat16)), + torch.no_grad(), + torch.inference_mode(), + ): for i in range(3): output = generator(text, forward_params=forward_params, generate_kwargs=generate_kwargs) assert isinstance(output["audio"], np.ndarray) diff --git a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py index cabd9e9785..4d37712e7b 100644 --- a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -1639,9 +1639,10 @@ def test_wav2vec2_with_lm_pool(self): self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas") # user-managed pool + num_processes should trigger a warning - with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool( - 2 - ) as pool: + with ( + CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, + multiprocessing.get_context("fork").Pool(2) as pool, + ): transcription = processor.batch_decode(logits.cpu().numpy(), pool, num_processes=2).text self.assertIn("num_process", cl.out)