Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ doc: build_doc_docker_image

clean:
find . -name "habana_log.livealloc.log_*" -type f -delete
find . -name "hl-smi_log*" -type f -delete
find . -name .lock -type f -delete
find . -name .graph_dumps -type d -exec rm -r {} +
find . -name save-hpu.pdb -type f -delete
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
Expand Down
6 changes: 5 additions & 1 deletion examples/contrastive-image-text/clip_media_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
from torch.utils.data.sampler import BatchSampler

from optimum.habana.utils import check_habana_frameworks_version
from optimum.utils import logging


Expand Down Expand Up @@ -128,7 +129,10 @@ def __next__(self):
read_image_text_from_dataset,
dtype.NDT,
)
op_class = fn.operator_add("ClipDataReader")
if check_habana_frameworks_version("1.14.0"):
op_class = fn.operator_add("ClipDataReader")
else:
op_class = fn.operator_add("ClipDataReader", False)
op_class.__module__ = fn.__name__
setattr(fn, "ClipDataReader", op_class)

Expand Down
2 changes: 1 addition & 1 deletion examples/contrastive-image-text/run_bridgetower.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
Expand Down
2 changes: 1 addition & 1 deletion examples/contrastive-image-text/run_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
Expand Down
2 changes: 1 addition & 1 deletion examples/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
Expand Down
2 changes: 1 addition & 1 deletion examples/language-modeling/run_mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
Expand Down
2 changes: 1 addition & 1 deletion examples/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
Expand Down
2 changes: 1 addition & 1 deletion examples/question-answering/run_seq2seq_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def check_optimum_habana_min_version(*a, **b):


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
Expand Down Expand Up @@ -469,6 +469,9 @@ def main():
if data_args.language is not None:
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
model.generation_config.task = data_args.task
model.generation_config.language = data_args.language
model.generation_config.forced_decoder_ids = None

# 6. Resample speech dataset if necessary
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
Expand Down
2 changes: 1 addition & 1 deletion examples/summarization/run_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
Expand Down
2 changes: 1 addition & 1 deletion examples/text-classification/run_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
Expand Down
2 changes: 1 addition & 1 deletion examples/translation/run_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def check_optimum_habana_min_version(*a, **b):
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
check_min_version("4.37.0")
check_min_version("4.38.0")
check_optimum_habana_min_version("1.10.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
Expand Down
1 change: 0 additions & 1 deletion optimum/habana/transformers/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@
from .stopping_criteria import (
gaudi_MaxLengthCriteria_call,
gaudi_MaxNewTokensCriteria_call,
gaudi_StoppingCriteriaList_call,
)
from .utils import MODELS_OPTIMIZED_WITH_STATIC_SHAPES, GaudiGenerationMixin
4 changes: 0 additions & 4 deletions optimum/habana/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,3 @@ def gaudi_MaxNewTokensCriteria_call(self, input_ids: torch.LongTensor, scores: t
return token_idx >= self.max_length
else:
return input_ids.shape[-1] >= self.max_length


def gaudi_StoppingCriteriaList_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores, **kwargs) for criteria in self)
Comment thread
libinta marked this conversation as resolved.
Loading