From 9d1743c00d8b3de6ad62fd31790255687c3eb747 Mon Sep 17 00:00:00 2001 From: Tamoghno Kandar Date: Sun, 2 Nov 2025 12:03:04 -0800 Subject: [PATCH 1/7] Replace flash attention2 with kernels-coomunity/flash-attn2 --- docs/source/dpo_trainer.md | 2 +- docs/source/gkd_trainer.md | 2 +- docs/source/reducing_memory_usage.md | 4 ++-- tests/test_grpo_trainer.py | 2 +- tests/test_sft_trainer.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 8e9f0ac41b5..9e524c6a940 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -253,7 +253,7 @@ model = AutoModelForCausalLM.from_pretrained( "mistralai/mixtral-8x7b-v0.1", load_in_4bit=True, quantization_config=bnb_config, - attn_implementation="flash_attention_2", + attn_implementation="kernels-community/flash-attn2", dtype=torch.bfloat16, device_map="auto", ) diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md index 73be330c637..fd7f3c0c2eb 100644 --- a/docs/source/gkd_trainer.md +++ b/docs/source/gkd_trainer.md @@ -28,7 +28,7 @@ The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method. > [!WARNING] -> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. +> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. The basic API is as follows: diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index 016c493124b..cfa07bc3c52 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -164,7 +164,7 @@ Padding-free batching is an alternative approach for reducing memory usage. In t ```python from trl import DPOConfig -training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` @@ -173,7 +173,7 @@ training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_imple ```python from trl import SFTConfig -training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 3002ce2b7fa..1266dc08da8 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -2033,7 +2033,7 @@ def data_gen(num_samples): ) model = AutoModelForImageTextToText.from_pretrained( model_name, - attn_implementation="flash_attention_2", + attn_implementation="kernels-community/flash-attn2", dtype="bfloat16", device_map=get_kbit_device_map(), quantization_config=quantization_config, diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index e3429809160..500dbef4f19 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -872,7 +872,7 @@ def test_train_padding_free(self): training_args = SFTConfig( output_dir=self.tmp_dir, padding_free=True, - model_init_kwargs={"attn_implementation": "flash_attention_2"}, + model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}, bf16=True, # flash_attention_2 only supports bf16 and fp16 report_to="none", ) From 3376ef80f65cf022bed6ba1d021480bdc4e13161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 4 Nov 2025 23:21:06 +0000 Subject: [PATCH 2/7] requires kernels --- tests/test_grpo_trainer.py | 4 ++-- tests/test_sft_trainer.py | 4 ++-- tests/testing_utils.py | 19 +------------------ trl/trainer/dpo_config.py | 6 ++++-- trl/trainer/grpo_config.py | 6 ++++-- trl/trainer/kto_config.py | 6 ++++-- 6 files changed, 17 insertions(+), 28 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 1266dc08da8..1cf8ce3085a 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -42,7 +42,7 @@ from .testing_utils import ( TrlTestCase, require_bitsandbytes, - require_flash_attn, + require_kernels, require_liger_kernel, require_peft, require_torch_accelerator, @@ -1980,7 +1980,7 @@ def test_training_with_transformers_paged(self, model_name): "HuggingFaceTB/SmolVLM-Instruct", # Only test the smaller model to avoid OOM ], ) - @require_flash_attn + @require_kernels @require_bitsandbytes @require_peft def test_vlm_training(self, model_name): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 500dbef4f19..9ad35a6ed47 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -33,7 +33,7 @@ TrlTestCase, ignore_warnings, require_bitsandbytes, - require_flash_attn, + require_kernels, require_liger_kernel, require_peft, require_torch_accelerator, @@ -863,7 +863,7 @@ def test_train_with_iterable_dataset(self): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" - @require_flash_attn + @require_kernels def test_train_padding_free(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 99a6e661f5c..c61a4edc3dd 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -24,7 +24,6 @@ from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available from transformers.testing_utils import backend_device_count, torch_device from transformers.utils import ( - is_flash_attn_2_available, is_kernels_available, is_peft_available, is_rich_available, @@ -45,6 +44,7 @@ require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") +require_kernels = pytest.mark.skipif(not is_kernels_available(), reason="test requires kernels") require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel") require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify") @@ -85,23 +85,6 @@ def is_bitsandbytes_multi_backend_available() -> bool: ) -def is_flash_attn_available(): - flash_attn_available = is_flash_attn_2_available() - kernels_available = is_kernels_available() - try: - from kernels import get_kernel - - get_kernel("kernels-community/flash-attn") - except Exception: - kernels_available = False - - return kernels_available or flash_attn_available - - -# Function ported from transformers.testing_utils -require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention") - - class RandomBinaryJudge(BaseBinaryJudge): """ Random binary judge, for testing purposes. diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index bdc7e7c70b6..f803b0d4765 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -162,13 +162,15 @@ class DPOConfig(TrainingArguments): - Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` instead. + Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` + instead. base_model_attribute_name (`str`, *optional*, defaults to `"model"`): Name of the attribute in the model that contains the base model. This is used to get the base model from - the model when the model does not have a `get_decoder` method in the case when `use_liger_kernel` is `True`. + the model when the model does not have a `get_decoder` method in the case when `use_liger_kernel` is + `True`. beta (`float`, *optional*, defaults to `0.1`): Parameter controlling the deviation from the reference model. Higher β means less deviation from the reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index ea7e4d310b5..77223c02040 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -44,7 +44,8 @@ class GRPOConfig(TrainingArguments): cast_lm_head_to_fp32 (`bool`, *optional*, defaults to `False`): Whether to cast the language modeling head of the policy and reference models to float32. As recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the model - has untied word embedding and language modeling head layers i.e. `tie_word_embeddings` in the model config is False. + has untied word embedding and language modeling head layers i.e. `tie_word_embeddings` in the model config + is False. > Parameters that control the data preprocessing remove_unused_columns (`bool`, *optional*, defaults to `False`): @@ -229,7 +230,8 @@ class GRPOConfig(TrainingArguments): - Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` instead. + Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` + instead. diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py index 3eb79b6a342..3fbcbb3e284 100644 --- a/trl/trainer/kto_config.py +++ b/trl/trainer/kto_config.py @@ -86,13 +86,15 @@ class KTOConfig(TrainingArguments): - Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` instead. + Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` + instead. base_model_attribute_name (`str`, *optional*, defaults to `"model"`): Name of the attribute in the model that contains the base model. This is used to get the base model from - the model when the model does not have a `get_decoder` method in the case when `use_liger_kernel` is `True`. + the model when the model does not have a `get_decoder` method in the case when `use_liger_kernel` is + `True`. """ _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"] From 689b20393da918c8ab4b008c50a05048ed6811bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 4 Nov 2025 23:22:44 +0000 Subject: [PATCH 3/7] kernels to dev deps --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index c1a94413f64..1ad31b645b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,9 @@ judges = [ "openai>=1.23.2", "llm-blender>=0.0.2" ] +kernels = [ + "kernels" +] liger = [ "liger-kernel>=0.6.2" ] @@ -100,6 +103,8 @@ dev = [ # judges "openai>=1.23.2", "llm-blender>=0.0.2", + # kernels + "kernels", # liger "liger-kernel>=0.6.2", # peft From cd0cfe3ea29f8a14629bb89b4929d286e9f5fd01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 6 Nov 2025 18:28:33 +0000 Subject: [PATCH 4/7] issue with fa2, use fa --- docs/source/dpo_trainer.md | 2 +- docs/source/gkd_trainer.md | 2 +- docs/source/reducing_memory_usage.md | 4 ++-- tests/test_grpo_trainer.py | 2 +- tests/test_sft_trainer.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 9e524c6a940..d210644931d 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -253,7 +253,7 @@ model = AutoModelForCausalLM.from_pretrained( "mistralai/mixtral-8x7b-v0.1", load_in_4bit=True, quantization_config=bnb_config, - attn_implementation="kernels-community/flash-attn2", + attn_implementation="kernels-community/flash-attn", dtype=torch.bfloat16, device_map="auto", ) diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md index fd7f3c0c2eb..b6c65d3ca01 100644 --- a/docs/source/gkd_trainer.md +++ b/docs/source/gkd_trainer.md @@ -28,7 +28,7 @@ The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method. > [!WARNING] -> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. +> Make sure that `attn_implementation="kernels-community/flash-attn"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. The basic API is as follows: diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index ec8aa42c4fc..1727668c068 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -161,7 +161,7 @@ Padding-free batching is an alternative approach for reducing memory usage. In t ```python from trl import DPOConfig -training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) +training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn"}) ``` @@ -170,7 +170,7 @@ training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_imple ```python from trl import SFTConfig -training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) +training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn"}) ``` diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index d9f9c48e880..8ce75715841 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -2040,7 +2040,7 @@ def data_gen(num_samples): ) model = AutoModelForImageTextToText.from_pretrained( model_name, - attn_implementation="kernels-community/flash-attn2", + attn_implementation="kernels-community/flash-attn", dtype="bfloat16", device_map=get_kbit_device_map(), quantization_config=quantization_config, diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 9ad35a6ed47..dd9f8bdf5b9 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -872,7 +872,7 @@ def test_train_padding_free(self): training_args = SFTConfig( output_dir=self.tmp_dir, padding_free=True, - model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}, + model_init_kwargs={"attn_implementation": "kernels-community/flash-attn"}, bf16=True, # flash_attention_2 only supports bf16 and fp16 report_to="none", ) From b21368eceb684632da7b951a3e57d7fa67da6e26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 12 Nov 2025 20:07:10 -0700 Subject: [PATCH 5/7] Apply suggestion from @qgallouedec --- tests/test_sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 893d1224160..3e528cbc5b2 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -879,7 +879,7 @@ def test_train_padding_free(self): training_args = SFTConfig( output_dir=self.tmp_dir, padding_free=True, - model_init_kwargs={"attn_implementation": "kernels-community/flash-attn"}, + model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}, bf16=True, # flash_attention_2 only supports bf16 and fp16 report_to="none", ) From b7bbef3502640741730e1258ce3ff998281eecc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 13 Nov 2025 03:20:58 +0000 Subject: [PATCH 6/7] fa2 --- docs/source/dpo_trainer.md | 2 +- docs/source/gkd_trainer.md | 2 +- docs/source/reducing_memory_usage.md | 4 ++-- tests/test_grpo_trainer.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index d210644931d..9e524c6a940 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -253,7 +253,7 @@ model = AutoModelForCausalLM.from_pretrained( "mistralai/mixtral-8x7b-v0.1", load_in_4bit=True, quantization_config=bnb_config, - attn_implementation="kernels-community/flash-attn", + attn_implementation="kernels-community/flash-attn2", dtype=torch.bfloat16, device_map="auto", ) diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md index b6c65d3ca01..fd7f3c0c2eb 100644 --- a/docs/source/gkd_trainer.md +++ b/docs/source/gkd_trainer.md @@ -28,7 +28,7 @@ The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method. > [!WARNING] -> Make sure that `attn_implementation="kernels-community/flash-attn"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. +> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. The basic API is as follows: diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index c16ad52df93..d9a30c336ea 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -188,7 +188,7 @@ Padding-free batching is an alternative approach for reducing memory usage. In t ```python from trl import DPOConfig -training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn"}) +training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` @@ -197,7 +197,7 @@ training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_imple ```python from trl import SFTConfig -training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn"}) +training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index cc0fbf19fff..73ee7234306 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -2040,7 +2040,7 @@ def data_gen(num_samples): ) model = AutoModelForImageTextToText.from_pretrained( model_name, - attn_implementation="kernels-community/flash-attn", + attn_implementation="kernels-community/flash-attn2", dtype="bfloat16", device_map=get_kbit_device_map(), quantization_config=quantization_config, From fc42bd1eb0a705b71c34518a7a6c87c8f0809953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 13 Nov 2025 03:47:30 +0000 Subject: [PATCH 7/7] require_ampere_or_newer --- tests/test_grpo_trainer.py | 2 ++ tests/test_sft_trainer.py | 2 ++ tests/testing_utils.py | 12 ++++++++++++ 3 files changed, 16 insertions(+) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 73ee7234306..b3844a399c1 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -41,6 +41,7 @@ from .testing_utils import ( TrlTestCase, + require_ampere_or_newer, require_bitsandbytes, require_kernels, require_liger_kernel, @@ -1988,6 +1989,7 @@ def test_training_with_transformers_paged(self, model_name): ], ) @require_kernels + @require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs @require_bitsandbytes @require_peft def test_vlm_training(self, model_name): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 3e528cbc5b2..874d5304f2f 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -33,6 +33,7 @@ from .testing_utils import ( TrlTestCase, ignore_warnings, + require_ampere_or_newer, require_bitsandbytes, require_kernels, require_liger_kernel, @@ -871,6 +872,7 @@ def test_train_with_iterable_dataset(self): assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_kernels + @require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs def test_train_padding_free(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 3bc26e54b0a..8f558bd5491 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -85,6 +85,18 @@ def is_bitsandbytes_multi_backend_available() -> bool: ) +def is_ampere_or_newer(device_index=0): + if not torch.cuda.is_available(): + return False + + major, minor = torch.cuda.get_device_capability(device_index) + # Ampere starts at compute capability 8.0 (e.g., A100 = 8.0, RTX 30xx = 8.6) + return (major, minor) >= (8, 0) + + +require_ampere_or_newer = pytest.mark.skipif(not is_ampere_or_newer(), reason="test requires Ampere or newer GPU") + + class RandomBinaryJudge(BaseBinaryJudge): """ Random binary judge, for testing purposes.