diff --git a/sdk/contentsafety/azure-ai-contentsafety/_meta.json b/sdk/contentsafety/azure-ai-contentsafety/_meta.json new file mode 100644 index 000000000000..17d5598ea51f --- /dev/null +++ b/sdk/contentsafety/azure-ai-contentsafety/_meta.json @@ -0,0 +1,6 @@ +{ + "commit": "924a98b9c8d71496dccd02b92fecf3902d2ea025", + "repository_url": "https://github.com/Azure/azure-rest-api-specs", + "typespec_src": "specification/cognitiveservices/ContentSafety", + "@azure-tools/typespec-python": "0.13.5" +} \ No newline at end of file diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_configuration.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_configuration.py index 8e25e529d1c5..6f53a586ffa6 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_configuration.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_configuration.py @@ -15,7 +15,7 @@ from ._version import VERSION -class ContentSafetyClientConfiguration(Configuration): # pylint: disable=too-many-instance-attributes +class ContentSafetyClientConfiguration(Configuration): # pylint: disable=too-many-instance-attributes,name-too-long """Configuration for ContentSafetyClient. Note that all parameters used to create this instance are saved as instance diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_model_base.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_model_base.py index a7ae06682f33..3c87287b19cb 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_model_base.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_model_base.py @@ -7,6 +7,7 @@ # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except # pyright: reportGeneralTypeIssues=false +import calendar import functools import sys import logging @@ -14,13 +15,14 @@ import re import copy import typing +import email from datetime import datetime, date, time, timedelta, timezone from json import JSONEncoder import isodate from azure.core.exceptions import DeserializationError from azure.core import CaseInsensitiveEnumMeta from azure.core.pipeline import PipelineResponse -from azure.core.serialization import _Null # pylint: disable=protected-access +from azure.core.serialization import _Null if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -31,7 +33,6 @@ __all__ = ["AzureJSONEncoder", "Model", "rest_field", "rest_discriminator"] - TZ_UTC = timezone.utc @@ -59,69 +60,53 @@ def _timedelta_as_isostr(td: timedelta) -> str: if days: date_str = "%sD" % days - # Build time - time_str = "T" + if hours or minutes or seconds: + # Build time + time_str = "T" - # Hours - bigger_exists = date_str or hours - if bigger_exists: - time_str += "{:02}H".format(hours) + # Hours + bigger_exists = date_str or hours + if bigger_exists: + time_str += "{:02}H".format(hours) - # Minutes - bigger_exists = bigger_exists or minutes - if bigger_exists: - time_str += "{:02}M".format(minutes) + # Minutes + bigger_exists = bigger_exists or minutes + if bigger_exists: + time_str += "{:02}M".format(minutes) - # Seconds - try: - if seconds.is_integer(): - seconds_string = "{:02}".format(int(seconds)) - else: - # 9 chars long w/ leading 0, 6 digits after decimal - seconds_string = "%09.6f" % seconds - # Remove trailing zeros - seconds_string = seconds_string.rstrip("0") - except AttributeError: # int.is_integer() raises - seconds_string = "{:02}".format(seconds) + # Seconds + try: + if seconds.is_integer(): + seconds_string = "{:02}".format(int(seconds)) + else: + # 9 chars long w/ leading 0, 6 digits after decimal + seconds_string = "%09.6f" % seconds + # Remove trailing zeros + seconds_string = seconds_string.rstrip("0") + except AttributeError: # int.is_integer() raises + seconds_string = "{:02}".format(seconds) - time_str += "{}S".format(seconds_string) + time_str += "{}S".format(seconds_string) + else: + time_str = "" return "P" + date_str + time_str -def _datetime_as_isostr(dt: typing.Union[datetime, date, time, timedelta]) -> str: - """Converts a datetime.(datetime|date|time|timedelta) object into an ISO 8601 formatted string - - :param timedelta dt: The date object to convert - :rtype: str - :return: ISO8601 version of this datetime - """ - # First try datetime.datetime - if hasattr(dt, "year") and hasattr(dt, "hour"): - dt = typing.cast(datetime, dt) - # astimezone() fails for naive times in Python 2.7, so make make sure dt is aware (tzinfo is set) - if not dt.tzinfo: - iso_formatted = dt.replace(tzinfo=TZ_UTC).isoformat() - else: - iso_formatted = dt.astimezone(TZ_UTC).isoformat() - # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt) - return iso_formatted.replace("+00:00", "Z") - # Next try datetime.date or datetime.time - try: - dt = typing.cast(typing.Union[date, time], dt) - return dt.isoformat() - # Last, try datetime.timedelta - except AttributeError: - dt = typing.cast(timedelta, dt) - return _timedelta_as_isostr(dt) - - -def _serialize_bytes(o) -> str: - return base64.b64encode(o).decode() +def _serialize_bytes(o, format: typing.Optional[str] = None) -> str: + encoded = base64.b64encode(o).decode() + if format == "base64url": + return encoded.strip("=").replace("+", "-").replace("/", "_") + return encoded -def _serialize_datetime(o): +def _serialize_datetime(o, format: typing.Optional[str] = None): if hasattr(o, "year") and hasattr(o, "hour"): + if format == "rfc7231": + return email.utils.format_datetime(o, usegmt=True) + if format == "unix-timestamp": + return int(calendar.timegm(o.utctimetuple())) + # astimezone() fails for naive times in Python 2.7, so make make sure o is aware (tzinfo is set) if not o.tzinfo: iso_formatted = o.replace(tzinfo=TZ_UTC).isoformat() @@ -135,7 +120,7 @@ def _serialize_datetime(o): def _is_readonly(p): try: - return p._readonly # pylint: disable=protected-access + return p._visibility == ["read"] # pylint: disable=protected-access except AttributeError: return False @@ -143,24 +128,27 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, format: typing.Optional[str] = None, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + self.format = format + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [ - p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p) - ] # pylint: disable=protected-access - return {k: v for k, v in o.items() if k not in readonly_props} - if isinstance(o, (bytes, bytearray)): - return base64.b64encode(o).decode() - if isinstance(o, _Null): - return None + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) try: return super(AzureJSONEncoder, self).default(o) except TypeError: + if isinstance(o, _Null): + return None if isinstance(o, (bytes, bytearray)): - return _serialize_bytes(o) + return _serialize_bytes(o, self.format) try: # First try datetime.datetime - return _serialize_datetime(o) + return _serialize_datetime(o, self.format) except AttributeError: pass # Last, try datetime.timedelta @@ -173,6 +161,10 @@ def default(self, o): # pylint: disable=too-many-return-statements _VALID_DATE = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" + r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") +_VALID_RFC7231 = re.compile( + r"(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s\d{2}\s" + r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT" +) def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: @@ -208,6 +200,36 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: return date_obj +def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime: + """Deserialize RFC7231 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + match = _VALID_RFC7231.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + return email.utils.parsedate_to_datetime(attr) + + +def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime: + """Deserialize unix timestamp into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + return datetime.fromtimestamp(attr, TZ_UTC) + + def _deserialize_date(attr: typing.Union[str, date]) -> date: """Deserialize ISO-8601 formatted string into Date object. :param str attr: response string to be deserialized. @@ -232,13 +254,22 @@ def _deserialize_time(attr: typing.Union[str, time]) -> time: return isodate.parse_time(attr) -def deserialize_bytes(attr): +def _deserialize_bytes(attr): if isinstance(attr, (bytes, bytearray)): return attr return bytes(base64.b64decode(attr)) -def deserialize_duration(attr): +def _deserialize_bytes_base64(attr): + if isinstance(attr, (bytes, bytearray)): + return attr + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + return bytes(base64.b64decode(encoded)) + + +def _deserialize_duration(attr): if isinstance(attr, timedelta): return attr return isodate.parse_duration(attr) @@ -248,17 +279,42 @@ def deserialize_duration(attr): datetime: _deserialize_datetime, date: _deserialize_date, time: _deserialize_time, - bytes: deserialize_bytes, - timedelta: deserialize_duration, + bytes: _deserialize_bytes, + bytearray: _deserialize_bytes, + timedelta: _deserialize_duration, typing.Any: lambda x: x, } +_DESERIALIZE_MAPPING_WITHFORMAT = { + "rfc3339": _deserialize_datetime, + "rfc7231": _deserialize_datetime_rfc7231, + "unix-timestamp": _deserialize_datetime_unix_timestamp, + "base64": _deserialize_bytes, + "base64url": _deserialize_bytes_base64, +} + + +def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): + if rf and rf._format: + return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) + return _DESERIALIZE_MAPPING.get(annotation) + + +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -359,12 +415,20 @@ def _is_model(obj: typing.Any) -> bool: return getattr(obj, "_is_model", False) -def _serialize(o): +def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements + if isinstance(o, list): + return [_serialize(x, format) for x in o] + if isinstance(o, dict): + return {k: _serialize(v, format) for k, v in o.items()} + if isinstance(o, set): + return {_serialize(x, format) for x in o} + if isinstance(o, tuple): + return tuple(_serialize(x, format) for x in o) if isinstance(o, (bytes, bytearray)): - return _serialize_bytes(o) + return _serialize_bytes(o, format) try: # First try datetime.datetime - return _serialize_datetime(o) + return _serialize_datetime(o, format) except AttributeError: pass # Last, try datetime.timedelta @@ -386,7 +450,7 @@ def _get_rest_field( def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any: - return _deserialize(rf._type, value) if (rf and rf._is_model) else _serialize(value) + return _deserialize(rf._type, value) if (rf and rf._is_model) else _serialize(value, rf._format if rf else None) class Model(_MyMutableMapping): @@ -411,7 +475,11 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: # actual type errors only throw the first wrong keyword arg they see, so following that. raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( - {self._attr_to_rest_field[k]._rest_name: _serialize(v) for k, v in kwargs.items() if v is not None} + { + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) + for k, v in kwargs.items() + if v is not None + } ) super().__init__(dict_to_pass) @@ -446,31 +514,77 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements - annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 + annotation: typing.Any, + module: typing.Optional[str], + rf: typing.Optional["_RestField"] = None, ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -479,14 +593,16 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass # is it a literal? try: if sys.version_info >= (3, 8): - from typing import Literal # pylint: disable=no-name-in-module, ungrouped-imports + from typing import ( + Literal, + ) # pylint: disable=no-name-in-module, ungrouped-imports else: from typing_extensions import Literal # type: ignore # pylint: disable=ungrouped-imports @@ -495,24 +611,9 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): - if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf ) @@ -526,14 +627,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -564,7 +669,8 @@ def _deserialize_dict( if len(annotation.__args__) > 1: def _deserialize_multiple_sequence( - entry_deserializers: typing.List[typing.Optional[typing.Callable]], obj + entry_deserializers: typing.List[typing.Optional[typing.Callable]], + obj, ): if obj is None: return obj @@ -604,11 +710,12 @@ def _deserialize_default( pass return _deserialize_with_callable(deserializer_from_mapping, obj) - return functools.partial(_deserialize_default, annotation, _DESERIALIZE_MAPPING.get(annotation)) + return functools.partial(_deserialize_default, annotation, get_deserializer(annotation, rf)) def _deserialize_with_callable( - deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], value: typing.Any + deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], + value: typing.Any, ): try: if value is None: @@ -622,16 +729,21 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e -def _deserialize(deserializer: typing.Any, value: typing.Any, module: typing.Optional[str] = None) -> typing.Any: +def _deserialize( + deserializer: typing.Any, + value: typing.Any, + module: typing.Optional[str] = None, + rf: typing.Optional["_RestField"] = None, +) -> typing.Any: if isinstance(value, PipelineResponse): value = value.http_response.json() - deserializer = _get_deserialize_callable_from_annotation(deserializer, module) + deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf) return _deserialize_with_callable(deserializer, value) @@ -642,16 +754,18 @@ def __init__( name: typing.Optional[str] = None, type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin is_discriminator: bool = False, - readonly: bool = False, + visibility: typing.Optional[typing.List[str]] = None, default: typing.Any = _UNSET, + format: typing.Optional[str] = None, ): self._type = type self._rest_name_input = name self._module: typing.Optional[str] = None self._is_discriminator = is_discriminator - self._readonly = readonly + self._visibility = visibility self._is_model = False self._default = default + self._format = format @property def _rest_name(self) -> str: @@ -665,7 +779,9 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item - return _deserialize(self._type, _serialize(item)) + if self._is_model: + return item + return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: if value is None: @@ -675,9 +791,12 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) - obj.__setitem__(self._rest_name, _serialize(value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return + obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( self, annotation: typing.Any @@ -689,10 +808,11 @@ def rest_field( *, name: typing.Optional[str] = None, type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin - readonly: bool = False, + visibility: typing.Optional[typing.List[str]] = None, default: typing.Any = _UNSET, + format: typing.Optional[str] = None, ) -> typing.Any: - return _RestField(name=name, type=type, readonly=readonly, default=default) + return _RestField(name=name, type=type, visibility=visibility, default=default, format=format) def rest_discriminator( diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_operations/_operations.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_operations/_operations.py index a018715ef713..0030de052392 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_operations/_operations.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_operations/_operations.py @@ -22,15 +22,14 @@ ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from .. import models as _models from .._model_base import AzureJSONEncoder, _deserialize from .._serialization import Serializer -from .._vendor import ContentSafetyClientMixinABC, _format_url_section +from .._vendor import ContentSafetyClientMixinABC if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -88,7 +87,9 @@ def build_content_safety_analyze_image_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) -def build_content_safety_get_text_blocklist_request(blocklist_name: str, **kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long +def build_content_safety_get_text_blocklist_request( # pylint: disable=name-too-long + blocklist_name: str, **kwargs: Any +) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -101,7 +102,7 @@ def build_content_safety_get_text_blocklist_request(blocklist_name: str, **kwarg "blocklistName": _SERIALIZER.url("blocklist_name", blocklist_name, "str"), } - _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + _url: str = _url.format(**path_format_arguments) # type: ignore # Construct parameters _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") @@ -112,7 +113,9 @@ def build_content_safety_get_text_blocklist_request(blocklist_name: str, **kwarg return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) -def build_content_safety_create_or_update_text_blocklist_request(blocklist_name: str, **kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long +def build_content_safety_create_or_update_text_blocklist_request( # pylint: disable=name-too-long + blocklist_name: str, **kwargs: Any +) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -126,7 +129,7 @@ def build_content_safety_create_or_update_text_blocklist_request(blocklist_name: "blocklistName": _SERIALIZER.url("blocklist_name", blocklist_name, "str"), } - _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + _url: str = _url.format(**path_format_arguments) # type: ignore # Construct parameters _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") @@ -139,7 +142,9 @@ def build_content_safety_create_or_update_text_blocklist_request(blocklist_name: return HttpRequest(method="PATCH", url=_url, params=_params, headers=_headers, **kwargs) -def build_content_safety_delete_text_blocklist_request(blocklist_name: str, **kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long +def build_content_safety_delete_text_blocklist_request( # pylint: disable=name-too-long + blocklist_name: str, **kwargs: Any +) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2023-04-30-preview")) @@ -149,7 +154,7 @@ def build_content_safety_delete_text_blocklist_request(blocklist_name: str, **kw "blocklistName": _SERIALIZER.url("blocklist_name", blocklist_name, "str"), } - _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + _url: str = _url.format(**path_format_arguments) # type: ignore # Construct parameters _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") @@ -176,7 +181,9 @@ def build_content_safety_list_text_blocklists_request(**kwargs: Any) -> HttpRequ return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) -def build_content_safety_add_block_items_request(blocklist_name: str, **kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long +def build_content_safety_add_block_items_request( # pylint: disable=name-too-long + blocklist_name: str, **kwargs: Any +) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -190,7 +197,7 @@ def build_content_safety_add_block_items_request(blocklist_name: str, **kwargs: "blocklistName": _SERIALIZER.url("blocklist_name", blocklist_name, "str"), } - _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + _url: str = _url.format(**path_format_arguments) # type: ignore # Construct parameters _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") @@ -203,7 +210,9 @@ def build_content_safety_add_block_items_request(blocklist_name: str, **kwargs: return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) -def build_content_safety_remove_block_items_request(blocklist_name: str, **kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long +def build_content_safety_remove_block_items_request( # pylint: disable=name-too-long + blocklist_name: str, **kwargs: Any +) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -215,7 +224,7 @@ def build_content_safety_remove_block_items_request(blocklist_name: str, **kwarg "blocklistName": _SERIALIZER.url("blocklist_name", blocklist_name, "str"), } - _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + _url: str = _url.format(**path_format_arguments) # type: ignore # Construct parameters _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") @@ -243,7 +252,7 @@ def build_content_safety_get_text_blocklist_item_request( # pylint: disable=nam "blockItemId": _SERIALIZER.url("block_item_id", block_item_id, "str"), } - _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + _url: str = _url.format(**path_format_arguments) # type: ignore # Construct parameters _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") @@ -255,7 +264,12 @@ def build_content_safety_get_text_blocklist_item_request( # pylint: disable=nam def build_content_safety_list_text_blocklist_items_request( # pylint: disable=name-too-long - blocklist_name: str, *, top: Optional[int] = None, skip: Optional[int] = None, **kwargs: Any + blocklist_name: str, + *, + top: Optional[int] = None, + skip: Optional[int] = None, + maxpagesize: Optional[int] = None, + **kwargs: Any ) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -269,7 +283,7 @@ def build_content_safety_list_text_blocklist_items_request( # pylint: disable=n "blocklistName": _SERIALIZER.url("blocklist_name", blocklist_name, "str"), } - _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + _url: str = _url.format(**path_format_arguments) # type: ignore # Construct parameters _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") @@ -277,6 +291,8 @@ def build_content_safety_list_text_blocklist_items_request( # pylint: disable=n _params["top"] = _SERIALIZER.query("top", top, "int") if skip is not None: _params["skip"] = _SERIALIZER.query("skip", skip, "int") + if maxpagesize is not None: + _params["maxpagesize"] = _SERIALIZER.query("maxpagesize", maxpagesize, "int") # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") @@ -388,7 +404,7 @@ def analyze_text( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_analyze_text_request( content_type=content_type, @@ -410,6 +426,8 @@ def analyze_text( response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -526,7 +544,7 @@ def analyze_image( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_analyze_image_request( content_type=content_type, @@ -548,6 +566,8 @@ def analyze_image( response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -607,6 +627,8 @@ def get_text_blocklist(self, blocklist_name: str, **kwargs: Any) -> _models.Text response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -731,7 +753,7 @@ def create_or_update_text_blocklist( if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_create_or_update_text_blocklist_request( blocklist_name=blocklist_name, @@ -754,6 +776,8 @@ def create_or_update_text_blocklist( response = pipeline_response.http_response if response.status_code not in [200, 201]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -822,6 +846,8 @@ def delete_text_blocklist( # pylint: disable=inconsistent-return-statements response = pipeline_response.http_response if response.status_code not in [204]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -905,6 +931,8 @@ def get_next(next_link=None): response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -1023,7 +1051,7 @@ def add_block_items( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_add_block_items_request( blocklist_name=blocklist_name, @@ -1046,6 +1074,8 @@ def add_block_items( response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -1170,7 +1200,7 @@ def remove_block_items( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_remove_block_items_request( blocklist_name=blocklist_name, @@ -1193,6 +1223,8 @@ def remove_block_items( # pylint: disable=inconsistent-return-statements response = pipeline_response.http_response if response.status_code not in [204]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -1248,6 +1280,8 @@ def get_text_blocklist_item(self, blocklist_name: str, block_item_id: str, **kwa response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -1282,6 +1316,7 @@ def list_text_blocklist_items( _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} + maxpagesize = kwargs.pop("maxpagesize", None) cls: ClsType[List[_models.TextBlockItem]] = kwargs.pop("cls", None) error_map = { @@ -1299,6 +1334,7 @@ def prepare_request(next_link=None): blocklist_name=blocklist_name, top=top, skip=skip, + maxpagesize=maxpagesize, api_version=self._config.api_version, headers=_headers, params=_params, @@ -1349,6 +1385,8 @@ def get_next(next_link=None): response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_serialization.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_serialization.py index 842ae727fbbc..9f3e29b11388 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_serialization.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_serialization.py @@ -662,8 +662,9 @@ def _serialize(self, target_obj, data_type=None, **kwargs): _serialized.update(_new_attr) # type: ignore _new_attr = _new_attr[k] # type: ignore _serialized = _serialized[k] - except ValueError: - continue + except ValueError as err: + if isinstance(err, SerializationError): + raise except (AttributeError, KeyError, TypeError) as err: msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj)) @@ -729,6 +730,8 @@ def url(self, name, data, data_type, **kwargs): if kwargs.get("skip_quote") is True: output = str(output) + # https://github.com/Azure/autorest.python/issues/2063 + output = output.replace("{", quote("{")).replace("}", quote("}")) else: output = quote(str(output), safe="") except SerializationError: @@ -741,6 +744,8 @@ def query(self, name, data, data_type, **kwargs): :param data: The data to be serialized. :param str data_type: The type to be serialized from. + :keyword bool skip_quote: Whether to skip quote the serialized result. + Defaults to False. :rtype: str :raises: TypeError if serialization fails. :raises: ValueError if data is None @@ -749,10 +754,8 @@ def query(self, name, data, data_type, **kwargs): # Treat the list aside, since we don't want to encode the div separator if data_type.startswith("["): internal_data_type = data_type[1:-1] - data = [self.serialize_data(d, internal_data_type, **kwargs) if d is not None else "" for d in data] - if not kwargs.get("skip_quote", False): - data = [quote(str(d), safe="") for d in data] - return str(self.serialize_iter(data, internal_data_type, **kwargs)) + do_quote = not kwargs.get("skip_quote", False) + return str(self.serialize_iter(data, internal_data_type, do_quote=do_quote, **kwargs)) # Not a list, regular serialization output = self.serialize_data(data, data_type, **kwargs) @@ -891,6 +894,8 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs): not be None or empty. :param str div: If set, this str will be used to combine the elements in the iterable into a combined string. Default is 'None'. + :keyword bool do_quote: Whether to quote the serialized result of each iterable element. + Defaults to False. :rtype: list, str """ if isinstance(data, str): @@ -903,9 +908,14 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs): for d in data: try: serialized.append(self.serialize_data(d, iter_type, **kwargs)) - except ValueError: + except ValueError as err: + if isinstance(err, SerializationError): + raise serialized.append(None) + if kwargs.get("do_quote", False): + serialized = ["" if s is None else quote(str(s), safe="") for s in serialized] + if div: serialized = ["" if s is None else str(s) for s in serialized] serialized = div.join(serialized) @@ -950,7 +960,9 @@ def serialize_dict(self, attr, dict_type, **kwargs): for key, value in attr.items(): try: serialized[self.serialize_unicode(key)] = self.serialize_data(value, dict_type, **kwargs) - except ValueError: + except ValueError as err: + if isinstance(err, SerializationError): + raise serialized[self.serialize_unicode(key)] = None if "xml" in serialization_ctxt: @@ -1900,7 +1912,7 @@ def deserialize_date(attr): if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore raise DeserializationError("Date must have only digits and -. Received: %s" % attr) # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. - return isodate.parse_date(attr, defaultmonth=None, defaultday=None) + return isodate.parse_date(attr, defaultmonth=0, defaultday=0) @staticmethod def deserialize_time(attr): diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_vendor.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_vendor.py index 3b365ec56658..80b9ac2054b8 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_vendor.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_vendor.py @@ -6,7 +6,7 @@ # -------------------------------------------------------------------------- from abc import ABC -from typing import List, TYPE_CHECKING, cast +from typing import TYPE_CHECKING from ._configuration import ContentSafetyClientConfiguration @@ -17,18 +17,6 @@ from ._serialization import Deserializer, Serializer -def _format_url_section(template, **kwargs): - components = template.split("/") - while components: - try: - return template.format(**kwargs) - except KeyError as key: - # Need the cast, as for some reasons "split" is typed as list[str | Any] - formatted_components = cast(List[str], template.split("/")) - components = [c for c in formatted_components if "{}".format(key.args[0]) not in c] - template = "/".join(components) - - class ContentSafetyClientMixinABC(ABC): """DO NOT use this class. It is for internal typing use only.""" diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_version.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_version.py index bbcd28b4aa67..be71c81bd282 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_version.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b2" +VERSION = "1.0.0b1" diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_configuration.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_configuration.py index 975a0c7aeb73..6510d417bb09 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_configuration.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_configuration.py @@ -15,7 +15,7 @@ from .._version import VERSION -class ContentSafetyClientConfiguration(Configuration): # pylint: disable=too-many-instance-attributes +class ContentSafetyClientConfiguration(Configuration): # pylint: disable=too-many-instance-attributes,name-too-long """Configuration for ContentSafetyClient. Note that all parameters used to create this instance are saved as instance diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_operations/_operations.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_operations/_operations.py index d819ac3172a8..da7d211ee033 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_operations/_operations.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/aio/_operations/_operations.py @@ -22,8 +22,7 @@ map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -157,7 +156,7 @@ async def analyze_text( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_analyze_text_request( content_type=content_type, @@ -179,6 +178,8 @@ async def analyze_text( response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -295,7 +296,7 @@ async def analyze_image( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_analyze_image_request( content_type=content_type, @@ -317,6 +318,8 @@ async def analyze_image( response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -376,6 +379,8 @@ async def get_text_blocklist(self, blocklist_name: str, **kwargs: Any) -> _model response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -500,7 +505,7 @@ async def create_or_update_text_blocklist( if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_create_or_update_text_blocklist_request( blocklist_name=blocklist_name, @@ -523,6 +528,8 @@ async def create_or_update_text_blocklist( response = pipeline_response.http_response if response.status_code not in [200, 201]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -591,6 +598,8 @@ async def delete_text_blocklist( # pylint: disable=inconsistent-return-statemen response = pipeline_response.http_response if response.status_code not in [204]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -674,6 +683,8 @@ async def get_next(next_link=None): response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -792,7 +803,7 @@ async def add_block_items( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_add_block_items_request( blocklist_name=blocklist_name, @@ -815,6 +826,8 @@ async def add_block_items( response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -939,7 +952,7 @@ async def remove_block_items( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_content_safety_remove_block_items_request( blocklist_name=blocklist_name, @@ -962,6 +975,8 @@ async def remove_block_items( # pylint: disable=inconsistent-return-statements response = pipeline_response.http_response if response.status_code not in [204]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -1019,6 +1034,8 @@ async def get_text_blocklist_item( response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) @@ -1053,6 +1070,7 @@ def list_text_blocklist_items( _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} + maxpagesize = kwargs.pop("maxpagesize", None) cls: ClsType[List[_models.TextBlockItem]] = kwargs.pop("cls", None) error_map = { @@ -1070,6 +1088,7 @@ def prepare_request(next_link=None): blocklist_name=blocklist_name, top=top, skip=skip, + maxpagesize=maxpagesize, api_version=self._config.api_version, headers=_headers, params=_params, @@ -1120,6 +1139,8 @@ async def get_next(next_link=None): response = pipeline_response.http_response if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) diff --git a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/models/_models.py b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/models/_models.py index 8aaae1049ae8..f73a30b99f56 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/models/_models.py +++ b/sdk/contentsafety/azure-ai-contentsafety/azure/ai/contentsafety/models/_models.py @@ -316,7 +316,7 @@ class ImageData(_model_base.Model): :vartype blob_url: str """ - content: Optional[bytes] = rest_field() + content: Optional[bytes] = rest_field(format="base64") """Base64 encoding of image.""" blob_url: Optional[str] = rest_field(name="blobUrl") """The blob url of image.""" @@ -425,7 +425,7 @@ class TextBlockItem(_model_base.Model): :vartype text: str """ - block_item_id: str = rest_field(name="blockItemId") + block_item_id: str = rest_field(name="blockItemId", visibility=["read", "create", "query"]) """Block Item Id. It will be uuid. Required.""" description: Optional[str] = rest_field() """Block item description.""" @@ -500,7 +500,7 @@ class TextBlocklist(_model_base.Model): :vartype description: str """ - blocklist_name: str = rest_field(name="blocklistName") + blocklist_name: str = rest_field(name="blocklistName", visibility=["read", "create", "query"]) """Text blocklist name. Required.""" description: Optional[str] = rest_field() """Text blocklist description.""" diff --git a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image.py b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image.py index 2901b3a63d47..f3f52c8e3f81 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image.py +++ b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image.py @@ -6,9 +6,10 @@ # license information. # -------------------------------------------------------------------------- + def analyze_image(): # [START analyze_image] - + import os from azure.ai.contentsafety import ContentSafetyClient from azure.core.credentials import AzureKeyCredential @@ -51,4 +52,4 @@ def analyze_image(): if __name__ == "__main__": - analyze_image() \ No newline at end of file + analyze_image() diff --git a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image_async.py b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image_async.py index c53674fae41e..14c3ca7cc5e1 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image_async.py +++ b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_image_async.py @@ -7,6 +7,7 @@ # -------------------------------------------------------------------------- import asyncio + async def analyze_image_async(): # [START analyze_image_async] @@ -51,9 +52,11 @@ async def analyze_image_async(): # [END analyze_image_async] + async def main(): await analyze_image_async() + if __name__ == "__main__": loop = asyncio.get_event_loop() - loop.run_until_complete(main()) \ No newline at end of file + loop.run_until_complete(main()) diff --git a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text.py b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text.py index bfc0717c4fac..1d1e38c37e78 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text.py +++ b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text.py @@ -48,5 +48,6 @@ def analyze_text(): # [END analyze_text] + if __name__ == "__main__": - analyze_text() \ No newline at end of file + analyze_text() diff --git a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text_async.py b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text_async.py index 445d9bde8308..c79d18dcb03d 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text_async.py +++ b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_analyze_text_async.py @@ -7,6 +7,7 @@ # -------------------------------------------------------------------------- import asyncio + async def analyze_text_async(): # [START analyze_text_async] @@ -49,9 +50,11 @@ async def analyze_text_async(): # [END analyze_text_async] + async def main(): await analyze_text_async() + if __name__ == "__main__": loop = asyncio.get_event_loop() - loop.run_until_complete(main()) \ No newline at end of file + loop.run_until_complete(main()) diff --git a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_manage_blocklist.py b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_manage_blocklist.py index 1e8a108de823..2888cf16578c 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/samples/sample_manage_blocklist.py +++ b/sdk/contentsafety/azure-ai-contentsafety/samples/sample_manage_blocklist.py @@ -6,6 +6,7 @@ # license information. # -------------------------------------------------------------------------- + def create_or_update_text_blocklist(): # [START create_or_update_text_blocklist] @@ -24,7 +25,9 @@ def create_or_update_text_blocklist(): blocklist_description = "Test blocklist management." try: - blocklist = client.create_or_update_text_blocklist(blocklist_name=blocklist_name, resource={"description": blocklist_description}) + blocklist = client.create_or_update_text_blocklist( + blocklist_name=blocklist_name, resource={"description": blocklist_description} + ) if blocklist: print("\nBlocklist created or updated: ") print(f"Name: {blocklist.blocklist_name}, Description: {blocklist.description}") @@ -39,16 +42,14 @@ def create_or_update_text_blocklist(): # [END create_or_update_text_blocklist] + def add_block_items(): # [START add_block_items] import os from azure.ai.contentsafety import ContentSafetyClient from azure.core.credentials import AzureKeyCredential - from azure.ai.contentsafety.models import ( - TextBlockItemInfo, - AddBlockItemsOptions - ) + from azure.ai.contentsafety.models import TextBlockItemInfo, AddBlockItemsOptions from azure.core.exceptions import HttpResponseError key = os.environ["CONTENT_SAFETY_KEY"] @@ -70,7 +71,9 @@ def add_block_items(): if result and result.value: print("\nBlock items added: ") for block_item in result.value: - print(f"BlockItemId: {block_item.block_item_id}, Text: {block_item.text}, Description: {block_item.description}") + print( + f"BlockItemId: {block_item.block_item_id}, Text: {block_item.text}, Description: {block_item.description}" + ) except HttpResponseError as e: print("\nAdd block items failed: ") if e.error: @@ -82,6 +85,7 @@ def add_block_items(): # [END add_block_items] + def analyze_text_with_blocklists(): # [START analyze_text_with_blocklists] @@ -102,12 +106,16 @@ def analyze_text_with_blocklists(): try: # After you edit your blocklist, it usually takes effect in 5 minutes, please wait some time before analyzing with blocklist after editing. - analysis_result = client.analyze_text(AnalyzeTextOptions(text=input_text, blocklist_names=[blocklist_name], break_by_blocklists=False)) + analysis_result = client.analyze_text( + AnalyzeTextOptions(text=input_text, blocklist_names=[blocklist_name], break_by_blocklists=False) + ) if analysis_result and analysis_result.blocklists_match_results: print("\nBlocklist match results: ") for match_result in analysis_result.blocklists_match_results: print(f"Block item was hit in text, Offset={match_result.offset}, Length={match_result.length}.") - print(f"BlocklistName: {match_result.blocklist_name}, BlockItemId: {match_result.block_item_id}, BlockItemText: {match_result.block_item_text}") + print( + f"BlocklistName: {match_result.blocklist_name}, BlockItemId: {match_result.block_item_id}, BlockItemText: {match_result.block_item_text}" + ) except HttpResponseError as e: print("\nAnalyze text failed: ") if e.error: @@ -119,6 +127,7 @@ def analyze_text_with_blocklists(): # [END analyze_text_with_blocklists] + def list_text_blocklists(): # [START list_text_blocklists] @@ -150,6 +159,7 @@ def list_text_blocklists(): # [END list_text_blocklists] + def get_text_blocklist(): # [START get_text_blocklist] @@ -182,6 +192,7 @@ def get_text_blocklist(): # [END get_text_blocklist] + def list_block_items(): # [START list_block_items] @@ -203,7 +214,9 @@ def list_block_items(): if block_items: print("\nList block items: ") for block_item in block_items: - print(f"BlockItemId: {block_item.block_item_id}, Text: {block_item.text}, Description: {block_item.description}") + print( + f"BlockItemId: {block_item.block_item_id}, Text: {block_item.text}, Description: {block_item.description}" + ) except HttpResponseError as e: print("\nList block items failed: ") if e.error: @@ -215,6 +228,7 @@ def list_block_items(): # [END list_block_items] + def get_block_item(): # [START get_block_item] @@ -244,12 +258,11 @@ def get_block_item(): block_item_id = add_result.value[0].block_item_id # Get this blockItem by blockItemId - block_item = client.get_text_blocklist_item( - blocklist_name=blocklist_name, - block_item_id= block_item_id - ) + block_item = client.get_text_blocklist_item(blocklist_name=blocklist_name, block_item_id=block_item_id) print("\nGet blockitem: ") - print(f"BlockItemId: {block_item.block_item_id}, Text: {block_item.text}, Description: {block_item.description}") + print( + f"BlockItemId: {block_item.block_item_id}, Text: {block_item.text}, Description: {block_item.description}" + ) except HttpResponseError as e: print("\nGet block item failed: ") if e.error: @@ -261,17 +274,14 @@ def get_block_item(): # [END get_block_item] + def remove_block_items(): # [START remove_block_items] import os from azure.ai.contentsafety import ContentSafetyClient from azure.core.credentials import AzureKeyCredential - from azure.ai.contentsafety.models import ( - TextBlockItemInfo, - AddBlockItemsOptions, - RemoveBlockItemsOptions - ) + from azure.ai.contentsafety.models import TextBlockItemInfo, AddBlockItemsOptions, RemoveBlockItemsOptions from azure.core.exceptions import HttpResponseError key = os.environ["CONTENT_SAFETY_KEY"] @@ -295,8 +305,7 @@ def remove_block_items(): # Remove this blockItem by blockItemId client.remove_block_items( - blocklist_name=blocklist_name, - body=RemoveBlockItemsOptions(block_item_ids=[block_item_id]) + blocklist_name=blocklist_name, body=RemoveBlockItemsOptions(block_item_ids=[block_item_id]) ) print(f"\nRemoved blockItem: {add_result.value[0].block_item_id}") except HttpResponseError as e: @@ -310,6 +319,7 @@ def remove_block_items(): # [END remove_block_items] + def delete_blocklist(): # [START delete_blocklist] @@ -340,6 +350,7 @@ def delete_blocklist(): # [END delete_blocklist] + if __name__ == "__main__": create_or_update_text_blocklist() add_block_items() @@ -349,4 +360,4 @@ def delete_blocklist(): list_block_items() get_block_item() remove_block_items() - delete_blocklist() \ No newline at end of file + delete_blocklist() diff --git a/sdk/contentsafety/azure-ai-contentsafety/sdk_packaging.toml b/sdk/contentsafety/azure-ai-contentsafety/sdk_packaging.toml new file mode 100644 index 000000000000..e7687fdae93b --- /dev/null +++ b/sdk/contentsafety/azure-ai-contentsafety/sdk_packaging.toml @@ -0,0 +1,2 @@ +[packaging] +auto_update = false \ No newline at end of file diff --git a/sdk/contentsafety/azure-ai-contentsafety/setup.py b/sdk/contentsafety/azure-ai-contentsafety/setup.py index c57e0fd20fae..9cb85d5c6503 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/setup.py +++ b/sdk/contentsafety/azure-ai-contentsafety/setup.py @@ -6,66 +6,26 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- # coding: utf-8 - -import os -import re from setuptools import setup, find_packages PACKAGE_NAME = "azure-ai-contentsafety" -PACKAGE_PPRINT_NAME = "Azure AI Content Safety" - -# a-b-c => a/b/c -package_folder_path = PACKAGE_NAME.replace("-", "/") - -# Version extraction inspired from 'requests' -with open(os.path.join(package_folder_path, "_version.py"), "r") as fd: - version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) - -if not version: - raise RuntimeError("Cannot find version information") - - +version = "1.0.0b1" setup( name=PACKAGE_NAME, version=version, - description="Microsoft {} Client Library for Python".format(PACKAGE_PPRINT_NAME), - long_description=open("README.md", "r").read(), - long_description_content_type="text/markdown", - license="MIT License", - author="Microsoft Corporation", - author_email="azpysdkhelp@microsoft.com", - url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk", + description="azure-ai-contentsafety", + author_email="", + url="", keywords="azure, azure sdk", - classifiers=[ - "Development Status :: 4 - Beta", - "Programming Language :: Python", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "License :: OSI Approved :: MIT License", - ], - zip_safe=False, - packages=find_packages( - exclude=[ - "tests", - # Exclude packages that will be covered by PEP420 or nspkg - "azure", - "azure.ai", - ] - ), + packages=find_packages(), include_package_data=True, - package_data={ - "azure.ai.contentsafety": ["py.typed"], - }, install_requires=[ "isodate<1.0.0,>=0.6.1", - "azure-core<2.0.0,>=1.24.0", + "azure-core<2.0.0,>=1.28.0", "typing-extensions>=4.3.0; python_version<'3.8.0'", ], - python_requires=">=3.7", + long_description="""\ + Analyze harmful content. + """, ) diff --git a/sdk/contentsafety/azure-ai-contentsafety/tests/test_content_safety.py b/sdk/contentsafety/azure-ai-contentsafety/tests/test_content_safety.py index 9c6e83685afa..bb152ddc8c88 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/tests/test_content_safety.py +++ b/sdk/contentsafety/azure-ai-contentsafety/tests/test_content_safety.py @@ -33,7 +33,9 @@ def test_analyze_text(self, content_safety_endpoint, content_safety_key): client = self.create_client(content_safety_endpoint, content_safety_key) assert client is not None - text_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "..", "..", "./samples/sample_data/text.txt")) + text_path = os.path.abspath( + os.path.join(os.path.abspath(__file__), "..", "..", "./samples/sample_data/text.txt") + ) with open(text_path) as f: request = AnalyzeTextOptions(text=f.readline(), categories=[]) response = client.analyze_text(request) @@ -51,7 +53,9 @@ def test_analyze_image(self, content_safety_endpoint, content_safety_key): client = self.create_client(content_safety_endpoint, content_safety_key) assert client is not None - image_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "..", "..", "./samples/sample_data/image.jpg")) + image_path = os.path.abspath( + os.path.join(os.path.abspath(__file__), "..", "..", "./samples/sample_data/image.jpg") + ) with open(image_path, "rb") as file: request = AnalyzeImageOptions(image=ImageData(content=file.read())) response = client.analyze_image(request) diff --git a/sdk/contentsafety/azure-ai-contentsafety/tsp-location.yaml b/sdk/contentsafety/azure-ai-contentsafety/tsp-location.yaml index 03b2e7251c2c..7418bbc4980e 100644 --- a/sdk/contentsafety/azure-ai-contentsafety/tsp-location.yaml +++ b/sdk/contentsafety/azure-ai-contentsafety/tsp-location.yaml @@ -1,3 +1,5 @@ -directory: specification/cognitiveservices/ContentSafety -commit: 17c41d0c4a96294bf563b009c9c72093963b529f +commit: 924a98b9c8d71496dccd02b92fecf3902d2ea025 repo: Azure/azure-rest-api-specs +directory: specification/cognitiveservices/ContentSafety +additionalDirectories: [] +