diff --git a/tests/test_litellm/llms/azure/test_azure_common_utils.py b/tests/test_litellm/llms/azure/test_azure_common_utils.py index 5a0680d66bf..b8c38fb9099 100644 --- a/tests/test_litellm/llms/azure/test_azure_common_utils.py +++ b/tests/test_litellm/llms/azure/test_azure_common_utils.py @@ -483,6 +483,7 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type): "input_file_id": "123", }, "aretrieve_batch": {"batch_id": "123"}, + "acancel_batch": {"batch_id": "123"}, "aget_assistants": {"custom_llm_provider": "azure"}, "acreate_assistants": {"custom_llm_provider": "azure"}, "adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"}, @@ -537,7 +538,7 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type): patch_target = ( "litellm.rerank_api.main.azure_rerank.initialize_azure_sdk_client" ) - elif call_type == CallTypes.acreate_batch or call_type == CallTypes.aretrieve_batch: + elif call_type == CallTypes.acreate_batch or call_type == CallTypes.aretrieve_batch or call_type == CallTypes.acancel_batch: patch_target = ( "litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client" )