diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 405956c66227..62a0ca562da0 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -915,7 +915,7 @@ }, { "filename": "sdk/cognitivelanguage/azure-ai-language-questionanswering/**", - "words": [ "qnas", "qnamaker", "ADTO", "tfidf", "ngram" ], + "words": [ "qnas", "qnamaker", "ADTO", "tfidf", "ngram", "kwoa" ], "caseSensitive": false }, { diff --git a/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/model_base.py b/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/model_base.py index 12926fa98dcf..097f8197cfd9 100644 --- a/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/model_base.py +++ b/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/model_base.py @@ -37,6 +37,7 @@ TZ_UTC = timezone.utc _T = typing.TypeVar("_T") +_NONE_TYPE = type(None) def _timedelta_as_isostr(td: timedelta) -> str: @@ -171,6 +172,21 @@ def default(self, o): # pylint: disable=too-many-return-statements r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT" ) +_ARRAY_ENCODE_MAPPING = { + "pipeDelimited": "|", + "spaceDelimited": " ", + "commaDelimited": ",", + "newlineDelimited": "\n", +} + + +def _deserialize_array_encoded(delimit: str, attr): + if isinstance(attr, str): + if attr == "": + return [] + return attr.split(delimit) + return attr + def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: """Deserialize ISO-8601 formatted string into Datetime object. @@ -202,7 +218,7 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: test_utc = date_obj.utctimetuple() if test_utc.tm_year > 9999 or test_utc.tm_year < 1: raise OverflowError("Hit max or min date") - return date_obj + return date_obj # type: ignore[no-any-return] def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime: @@ -256,7 +272,7 @@ def _deserialize_time(attr: typing.Union[str, time]) -> time: """ if isinstance(attr, time): return attr - return isodate.parse_time(attr) + return isodate.parse_time(attr) # type: ignore[no-any-return] def _deserialize_bytes(attr): @@ -315,6 +331,8 @@ def _deserialize_int_as_str(attr): def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): if annotation is int and rf and rf._format == "str": return _deserialize_int_as_str + if annotation is str and rf and rf._format in _ARRAY_ENCODE_MAPPING: + return functools.partial(_deserialize_array_encoded, _ARRAY_ENCODE_MAPPING[rf._format]) if rf and rf._format: return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore @@ -353,9 +371,39 @@ def __contains__(self, key: typing.Any) -> bool: return key in self._data def __getitem__(self, key: str) -> typing.Any: + # If this key has been deserialized (for mutable types), we need to handle serialization + if hasattr(self, "_attr_to_rest_field"): + cache_attr = f"_deserialized_{key}" + if hasattr(self, cache_attr): + rf = _get_rest_field(getattr(self, "_attr_to_rest_field"), key) + if rf: + value = self._data.get(key) + if isinstance(value, (dict, list, set)): + # For mutable types, serialize and return + # But also update _data with serialized form and clear flag + # so mutations via this returned value affect _data + serialized = _serialize(value, rf._format) + # If serialized form is same type (no transformation needed), + # return _data directly so mutations work + if isinstance(serialized, type(value)) and serialized == value: + return self._data.get(key) + # Otherwise return serialized copy and clear flag + try: + object.__delattr__(self, cache_attr) + except AttributeError: + pass + # Store serialized form back + self._data[key] = serialized + return serialized return self._data.__getitem__(key) def __setitem__(self, key: str, value: typing.Any) -> None: + # Clear any cached deserialized value when setting through dictionary access + cache_attr = f"_deserialized_{key}" + try: + object.__delattr__(self, cache_attr) + except AttributeError: + pass self._data.__setitem__(key, value) def __delitem__(self, key: str) -> None: @@ -483,6 +531,8 @@ def _is_model(obj: typing.Any) -> bool: def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements if isinstance(o, list): + if format in _ARRAY_ENCODE_MAPPING and all(isinstance(x, str) for x in o): + return _ARRAY_ENCODE_MAPPING[format].join(o) return [_serialize(x, format) for x in o] if isinstance(o, dict): return {k: _serialize(v, format) for k, v in o.items()} @@ -638,6 +688,10 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: if not rf._rest_name_input: rf._rest_name_input = attr cls._attr_to_rest_field: dict[str, _RestField] = dict(attr_to_rest_field.items()) + cls._backcompat_attr_to_rest_field: dict[str, _RestField] = { + Model._get_backcompat_attribute_name(cls._attr_to_rest_field, attr): rf + for attr, rf in cls._attr_to_rest_field.items() + } cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}") return super().__new__(cls) @@ -647,6 +701,16 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: if hasattr(base, "__mapping__"): base.__mapping__[discriminator or cls.__name__] = cls # type: ignore + @classmethod + def _get_backcompat_attribute_name(cls, attr_to_rest_field: dict[str, "_RestField"], attr_name: str) -> str: + rest_field_obj = attr_to_rest_field.get(attr_name) # pylint: disable=protected-access + if rest_field_obj is None: + return attr_name + original_tsp_name = getattr(rest_field_obj, "_original_tsp_name", None) # pylint: disable=protected-access + if original_tsp_name: + return original_tsp_name + return attr_name + @classmethod def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]: for v in cls.__dict__.values(): @@ -767,6 +831,17 @@ def _deserialize_sequence( return obj if isinstance(obj, ET.Element): obj = list(obj) + try: + if ( + isinstance(obj, str) + and isinstance(deserializer, functools.partial) + and isinstance(deserializer.args[0], functools.partial) + and deserializer.args[0].func == _deserialize_array_encoded # pylint: disable=comparison-with-callable + ): + # encoded string may be deserialized to sequence + return deserializer(obj) + except: # pylint: disable=bare-except + pass return type(obj)(_deserialize(deserializer, entry, module) for entry in obj) @@ -817,16 +892,16 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur # is it optional? try: - if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore + if any(a is _NONE_TYPE for a in annotation.__args__): # pyright: ignore if len(annotation.__args__) <= 2: # pyright: ignore if_obj_deserializer = _get_deserialize_callable_from_annotation( - next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore + next(a for a in annotation.__args__ if a is not _NONE_TYPE), module, rf # pyright: ignore ) return functools.partial(_deserialize_with_optional, if_obj_deserializer) # the type is Optional[Union[...]], we need to remove the None type from the Union annotation_copy = copy.copy(annotation) - annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore + annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a is not _NONE_TYPE] # pyright: ignore return _get_deserialize_callable_from_annotation(annotation_copy, module, rf) except AttributeError: pass @@ -972,6 +1047,7 @@ def _failsafe_deserialize_xml( return None +# pylint: disable=too-many-instance-attributes class _RestField: def __init__( self, @@ -984,6 +1060,7 @@ def __init__( format: typing.Optional[str] = None, is_multipart_file_input: bool = False, xml: typing.Optional[dict[str, typing.Any]] = None, + original_tsp_name: typing.Optional[str] = None, ): self._type = type self._rest_name_input = name @@ -995,10 +1072,15 @@ def __init__( self._format = format self._is_multipart_file_input = is_multipart_file_input self._xml = xml if xml is not None else {} + self._original_tsp_name = original_tsp_name @property def _class_type(self) -> typing.Any: - return getattr(self._type, "args", [None])[0] + result = getattr(self._type, "args", [None])[0] + # type may be wrapped by nested functools.partial so we need to check for that + if isinstance(result, functools.partial): + return getattr(result, "args", [None])[0] + return result @property def _rest_name(self) -> str: @@ -1009,14 +1091,37 @@ def _rest_name(self) -> str: def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin # by this point, type and rest_name will have a value bc we default # them in __new__ of the Model class - item = obj.get(self._rest_name) + # Use _data.get() directly to avoid triggering __getitem__ which clears the cache + item = obj._data.get(self._rest_name) if item is None: return item if self._is_model: return item - return _deserialize(self._type, _serialize(item, self._format), rf=self) + + # For mutable types, we want mutations to directly affect _data + # Check if we've already deserialized this value + cache_attr = f"_deserialized_{self._rest_name}" + if hasattr(obj, cache_attr): + # Return the value from _data directly (it's been deserialized in place) + return obj._data.get(self._rest_name) + + deserialized = _deserialize(self._type, _serialize(item, self._format), rf=self) + + # For mutable types, store the deserialized value back in _data + # so mutations directly affect _data + if isinstance(deserialized, (dict, list, set)): + obj._data[self._rest_name] = deserialized + object.__setattr__(obj, cache_attr, True) # Mark as deserialized + return deserialized + + return deserialized def __set__(self, obj: Model, value) -> None: + # Clear the cached deserialized object when setting a new value + cache_attr = f"_deserialized_{self._rest_name}" + if hasattr(obj, cache_attr): + object.__delattr__(obj, cache_attr) + if value is None: # we want to wipe out entries if users set attr to None try: @@ -1046,6 +1151,7 @@ def rest_field( format: typing.Optional[str] = None, is_multipart_file_input: bool = False, xml: typing.Optional[dict[str, typing.Any]] = None, + original_tsp_name: typing.Optional[str] = None, ) -> typing.Any: return _RestField( name=name, @@ -1055,6 +1161,7 @@ def rest_field( format=format, is_multipart_file_input=is_multipart_file_input, xml=xml, + original_tsp_name=original_tsp_name, ) @@ -1184,7 +1291,7 @@ def _get_wrapped_element( _get_element(v, exclude_readonly, meta, wrapped_element) else: wrapped_element.text = _get_primitive_type_value(v) - return wrapped_element + return wrapped_element # type: ignore[no-any-return] def _get_primitive_type_value(v) -> str: @@ -1197,7 +1304,9 @@ def _get_primitive_type_value(v) -> str: return str(v) -def _create_xml_element(tag, prefix=None, ns=None): +def _create_xml_element( + tag: typing.Any, prefix: typing.Optional[str] = None, ns: typing.Optional[str] = None +) -> ET.Element: if prefix and ns: ET.register_namespace(prefix, ns) if ns: diff --git a/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/serialization.py b/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/serialization.py index e81921cbb011..81ec1de5922b 100644 --- a/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/serialization.py +++ b/sdk/cognitivelanguage/azure-ai-language-questionanswering/azure/ai/language/questionanswering/_utils/serialization.py @@ -787,7 +787,7 @@ def serialize_data(self, data, data_type, **kwargs): # If dependencies is empty, try with current data class # It has to be a subclass of Enum anyway - enum_type = self.dependencies.get(data_type, data.__class__) + enum_type = self.dependencies.get(data_type, cast(type, data.__class__)) if issubclass(enum_type, Enum): return Serializer.serialize_enum(data, enum_obj=enum_type) @@ -821,13 +821,20 @@ def serialize_basic(cls, data, data_type, **kwargs): :param str data_type: Type of object in the iterable. :rtype: str, int, float, bool :return: serialized object + :raises TypeError: raise if data_type is not one of str, int, float, bool. """ custom_serializer = cls._get_custom_serializers(data_type, **kwargs) if custom_serializer: return custom_serializer(data) if data_type == "str": return cls.serialize_unicode(data) - return eval(data_type)(data) # nosec # pylint: disable=eval-used + if data_type == "int": + return int(data) + if data_type == "float": + return float(data) + if data_type == "bool": + return bool(data) + raise TypeError("Unknown basic data type: {}".format(data_type)) @classmethod def serialize_unicode(cls, data): @@ -1757,7 +1764,7 @@ def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return :param str data_type: deserialization data type. :return: Deserialized basic type. :rtype: str, int, float or bool - :raises TypeError: if string format is not valid. + :raises TypeError: if string format is not valid or data_type is not one of str, int, float, bool. """ # If we're here, data is supposed to be a basic type. # If it's still an XML node, take the text @@ -1783,7 +1790,11 @@ def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return if data_type == "str": return self.deserialize_unicode(attr) - return eval(data_type)(attr) # nosec # pylint: disable=eval-used + if data_type == "int": + return int(attr) + if data_type == "float": + return float(attr) + raise TypeError("Unknown basic data type: {}".format(data_type)) @staticmethod def deserialize_unicode(data): diff --git a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase.py b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase.py index 6ac76a9a16d7..62fbcad45a80 100644 --- a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase.py +++ b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase.py @@ -5,6 +5,7 @@ # ------------------------------------------------------------------------- import pytest +from testcase import QuestionAnsweringTestCase from azure.ai.language.questionanswering import QuestionAnsweringClient from azure.ai.language.questionanswering.models import ( AnswersOptions, @@ -16,11 +17,9 @@ ) from azure.core.credentials import AzureKeyCredential -from testcase import QuestionAnsweringTestCase - class TestQnAKnowledgeBase(QuestionAnsweringTestCase): - def test_query_knowledgebase(self, recorded_test, qna_creds): # standard model usage + def test_query_knowledgebase(self, recorded_test, qna_creds): # standard model usage # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) query_params = AnswersOptions( question="Ports and connectors", @@ -36,7 +35,7 @@ def test_query_knowledgebase(self, recorded_test, qna_creds): # standard model assert answer.qna_id is not None assert answer.source - def test_query_knowledgebase_with_answerspan(self, recorded_test, qna_creds): + def test_query_knowledgebase_with_answerspan(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) query_params = AnswersOptions( question="Ports and connectors", @@ -53,7 +52,7 @@ def test_query_knowledgebase_with_answerspan(self, recorded_test, qna_creds): assert answer.short_answer.text assert answer.short_answer.confidence is not None - def test_query_knowledgebase_filter(self, recorded_test, qna_creds): + def test_query_knowledgebase_filter(self, recorded_test, qna_creds): # pylint: disable=unused-argument filters = QueryFilters( metadata_filter=MetadataFilter( metadata=[ @@ -76,7 +75,7 @@ def test_query_knowledgebase_filter(self, recorded_test, qna_creds): ) assert response.answers - def test_query_knowledgebase_only_id(self, recorded_test, qna_creds): + def test_query_knowledgebase_only_id(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) with client: query_params = AnswersOptions(qna_id=19) @@ -88,11 +87,11 @@ def test_query_knowledgebase_overload_errors(self): # negative tests independen with QuestionAnsweringClient("http://fake.com", AzureKeyCredential("123")) as client: # These calls intentionally violate the method signature to ensure TypeError is raised. with pytest.raises(TypeError): - client.get_answers("positional_one", "positional_two") # type: ignore + client.get_answers("positional_one", "positional_two") # type: ignore # pylint: disable=too-many-function-args, missing-kwoa with pytest.raises(TypeError): - client.get_answers("positional_options_bag", options="options bag by name") # type: ignore + client.get_answers("positional_options_bag", options="options bag by name") # type: ignore # pylint: disable=missing-kwoa with pytest.raises(TypeError): - client.get_answers(options={"qnaId": 15}, project_name="hello", deployment_name="test") # type: ignore + client.get_answers(options={"qnaId": 15}, project_name="hello", deployment_name="test") # type: ignore # pylint: disable=no-value-for-parameter with pytest.raises(TypeError): client.get_answers({"qnaId": 15}, question="Why?", project_name="hello", deployment_name="test") # type: ignore @@ -102,4 +101,4 @@ def test_query_knowledgebase_question_or_qna_id(self): with pytest.raises(TypeError): client.get_answers(options, project_name="hello", deployment_name="test") with pytest.raises(TypeError): - client.get_answers(project_name="hello", deployment_name="test") + client.get_answers(project_name="hello", deployment_name="test") # pylint: disable=no-value-for-parameter diff --git a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase_async.py b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase_async.py index 72b11cddd0e8..b8e38dd7134f 100644 --- a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase_async.py +++ b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_knowledgebase_async.py @@ -5,6 +5,7 @@ # ------------------------------------------------------------------------- import pytest +from testcase import QuestionAnsweringTestCase from azure.ai.language.questionanswering.aio import QuestionAnsweringClient from azure.ai.language.questionanswering.models import ( AnswersOptions, @@ -16,12 +17,10 @@ ) from azure.core.credentials import AzureKeyCredential -from testcase import QuestionAnsweringTestCase - class TestQueryKnowledgeBaseAsync(QuestionAnsweringTestCase): @pytest.mark.asyncio - async def test_query_knowledgebase_basic(self, recorded_test, qna_creds): + async def test_query_knowledgebase_basic(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) params = AnswersOptions( question="Ports and connectors", @@ -39,7 +38,7 @@ async def test_query_knowledgebase_basic(self, recorded_test, qna_creds): assert answer.metadata is not None @pytest.mark.asyncio - async def test_query_knowledgebase_with_short_answer(self, recorded_test, qna_creds): + async def test_query_knowledgebase_with_short_answer(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) params = AnswersOptions( question="Ports and connectors", @@ -56,7 +55,7 @@ async def test_query_knowledgebase_with_short_answer(self, recorded_test, qna_cr assert answer.short_answer.confidence is not None @pytest.mark.asyncio - async def test_query_knowledgebase_filter(self, recorded_test, qna_creds): + async def test_query_knowledgebase_filter(self, recorded_test, qna_creds): # pylint: disable=unused-argument filters = QueryFilters( metadata_filter=MetadataFilter( metadata=[ @@ -80,14 +79,14 @@ async def test_query_knowledgebase_filter(self, recorded_test, qna_creds): deployment_name="production", ) assert response.answers - assert any( + assert any( # pylint: disable=use-a-generator [ a for a in response.answers if (a.metadata or {}).get("explicitlytaggedheading") == "check the battery level" ] ) - assert any( + assert any( # pylint: disable=use-a-generator [ a for a in response.answers diff --git a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text.py b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text.py index 8a9a55f98e8c..872f90610de2 100644 --- a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text.py +++ b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text.py @@ -4,16 +4,15 @@ # Runtime tests: text records querying (authoring removed) # ------------------------------------------------------------------------- import pytest +from testcase import QuestionAnsweringTestCase from azure.ai.language.questionanswering import QuestionAnsweringClient from azure.ai.language.questionanswering.models import AnswersFromTextOptions, TextDocument from azure.core.credentials import AzureKeyCredential -from testcase import QuestionAnsweringTestCase - class TestQueryText(QuestionAnsweringTestCase): - def test_query_text_basic(self, recorded_test, qna_creds): + def test_query_text_basic(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) params = AnswersFromTextOptions( question="What is the meaning of life?", @@ -33,7 +32,7 @@ def test_query_text_basic(self, recorded_test, qna_creds): if answer.short_answer: assert answer.short_answer.text - def test_query_text_with_str_records(self, recorded_test, qna_creds): + def test_query_text_with_str_records(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) params = { "question": "How long it takes to charge surface?", @@ -57,7 +56,7 @@ def test_query_text_with_str_records(self, recorded_test, qna_creds): def test_query_text_overload_errors(self): # negative client-side parameter validation with QuestionAnsweringClient("http://fake.com", AzureKeyCredential("123")) as client: with pytest.raises(TypeError): - client.get_answers_from_text("positional_one", "positional_two") # type: ignore[arg-type] + client.get_answers_from_text("positional_one", "positional_two") # type: ignore[arg-type] # pylint: disable=too-many-function-args with pytest.raises(TypeError): client.get_answers_from_text("positional_options_bag", options="options bag by name") # type: ignore[arg-type] params = AnswersFromTextOptions( @@ -65,13 +64,13 @@ def test_query_text_overload_errors(self): # negative client-side parameter val text_documents=[TextDocument(text="foo", id="doc1"), TextDocument(text="bar", id="doc2")], ) with pytest.raises(TypeError): - client.get_answers_from_text(options=params) # type: ignore[arg-type] + client.get_answers_from_text(options=params) # type: ignore[arg-type] # pylint: disable=no-value-for-parameter with pytest.raises(TypeError): - client.get_answers_from_text(question="why?", text_documents=["foo", "bar"], options=params) # type: ignore[arg-type] + client.get_answers_from_text(question="why?", text_documents=["foo", "bar"], options=params) # type: ignore[arg-type] # pylint: disable=no-value-for-parameter with pytest.raises(TypeError): client.get_answers_from_text(params, question="Why?") # type: ignore[arg-type] - def test_query_text_default_lang_override(self, recorded_test, qna_creds): + def test_query_text_default_lang_override(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient( qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"]), default_language="es" ) diff --git a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text_async.py b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text_async.py index 7eec7b126fe8..bc297504f87d 100644 --- a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text_async.py +++ b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/test_query_text_async.py @@ -4,17 +4,16 @@ # Inference async tests: text records querying (authoring removed) # ------------------------------------------------------------------------- import pytest +from testcase import QuestionAnsweringTestCase from azure.ai.language.questionanswering.aio import QuestionAnsweringClient from azure.ai.language.questionanswering.models import AnswersFromTextOptions, TextDocument from azure.core.credentials import AzureKeyCredential -from testcase import QuestionAnsweringTestCase - class TestQueryTextAsync(QuestionAnsweringTestCase): @pytest.mark.asyncio - async def test_query_text_basic(self, recorded_test, qna_creds): + async def test_query_text_basic(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) params = AnswersFromTextOptions( question="What is the meaning of life?", @@ -33,7 +32,7 @@ async def test_query_text_basic(self, recorded_test, qna_creds): assert answer.id is not None @pytest.mark.asyncio - async def test_query_text_with_str_records(self, recorded_test, qna_creds): + async def test_query_text_with_str_records(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient(qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"])) params = { "question": "How long it takes to charge surface?", @@ -73,7 +72,7 @@ async def test_query_text_overload_errors(self): # negative parameter validatio await client.get_answers_from_text(params, question="Why?") # type: ignore[arg-type] @pytest.mark.asyncio - async def test_query_text_default_lang_override(self, recorded_test, qna_creds): + async def test_query_text_default_lang_override(self, recorded_test, qna_creds): # pylint: disable=unused-argument client = QuestionAnsweringClient( qna_creds["qna_endpoint"], AzureKeyCredential(qna_creds["qna_key"]), default_language="es" ) diff --git a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/testcase.py b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/testcase.py index 965f9e721089..dd87faa210c9 100644 --- a/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/testcase.py +++ b/sdk/cognitivelanguage/azure-ai-language-questionanswering/tests/testcase.py @@ -8,6 +8,6 @@ class QuestionAnsweringTestCase(AzureRecordedTestCase): @property def kwargs_for_polling(self): - if self.is_playback: + if self.is_playback: # pylint: disable=using-constant-test return {"polling_interval": 0} return {}