diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 75361c6010b..2aa5790e4e5 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1466,12 +1466,12 @@ def test_train_vlm_prompt_completion_gemma(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing. # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.