Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 38 additions & 40 deletions src/inmanta/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

Contact: code@inmanta.com
"""
from typing import Set, Dict, List, Optional
from configparser import RawConfigParser

from inmanta.const import ResourceState
Expand All @@ -33,9 +32,13 @@
import inmanta.db.versions

from inmanta.resources import Id
from inmanta import const
from inmanta import const, util
import asyncpg

from inmanta.types import JsonType
from typing import Dict, List, Union, Set, Optional, Any, Tuple


LOGGER = logging.getLogger(__name__)

DBLIMIT = 100000
Expand All @@ -44,6 +47,11 @@
# TODO: difference between None and not set


def json_encode(value: JsonType) -> str:
# see json_encode in tornado.escape
return json.dumps(value, default=util.custom_json_encoder)


class Field(object):

def __init__(self, field_type, required=False, unique=False, reference=False, part_of_primary_key=False, **kwargs):
Expand All @@ -67,32 +75,32 @@ def get_field_type(self):

field_type = property(get_field_type)

def is_required(self):
def is_required(self) -> bool:
return self._required

required = property(is_required)

def get_default(self):
def get_default(self) -> bool:
return self._default

default = property(get_default)

def get_default_value(self):
def get_default_value(self) -> Any:
return copy.copy(self._default_value)

default_value = property(get_default_value)

def is_unique(self):
def is_unique(self) -> bool:
return self._unique

unique = property(is_unique)

def is_reference(self):
def is_reference(self) -> bool:
return self._reference

reference = property(is_reference)

def is_part_of_primary_key(self):
def is_part_of_primary_key(self) -> bool:
return self._part_of_primary_key

part_of_primary_key = property(is_part_of_primary_key)
Expand All @@ -110,7 +118,7 @@ class DataDocument(object):
def __init__(self, **kwargs):
self._data = kwargs

def to_dict(self):
def to_dict(self) -> JsonType:
"""
Return a dict representation of this object.
"""
Expand Down Expand Up @@ -140,17 +148,17 @@ class BaseDocument(object, metaclass=DocumentMeta):
_connection_pool = None

@classmethod
def table_name(cls):
def table_name(cls) -> str:
"""
Return the name of the collection
"""
return cls.__name__.lower()

def __init__(self, from_postgres=False, **kwargs):
def __init__(self, from_postgres: bool=False, **kwargs: Any) -> None:
self.__fields = self._create_dict_wrapper(from_postgres, kwargs)

@classmethod
def _create_dict(cls, from_postgres, kwargs):
def _create_dict(cls, from_postgres: bool, kwargs: Dict[str, Any]) -> JsonType:
result = {}
fields = cls._fields.copy()

Expand Down Expand Up @@ -198,11 +206,11 @@ def _create_dict(cls, from_postgres, kwargs):
return result

@classmethod
def _get_names_of_primary_key_fields(cls):
def _get_names_of_primary_key_fields(cls) -> List[str]:
fields = cls._fields.copy()
return [name for name, value in fields.items() if value.is_part_of_primary_key()]

def _get_filter_on_primary_key_fields(self, offset=1):
def _get_filter_on_primary_key_fields(self, offset: int=1) -> Tuple[str, List[Any]]:
names_primary_key_fields = self._get_names_of_primary_key_fields()
query = {field_name: self.__getattribute__(field_name) for field_name in names_primary_key_fields}
return self._get_composed_filter(offset=offset, **query)
Expand Down Expand Up @@ -268,12 +276,12 @@ def __setattr__(self, name, value):
raise AttributeError(name)

@classmethod
def _convert_field_names_to_db_column_names(cls, field_dict):
def _convert_field_names_to_db_column_names(cls, field_dict: Dict[str, str]) -> Dict[str, str]:
return field_dict

def _get_column_names_and_values(self):
column_names = []
values = []
def _get_column_names_and_values(self) -> Tuple[List[str], List[str]]:
column_names: List[str] = []
values: List[str] = []
for name, typing in self._fields.items():
if self._fields[name].reference:
continue
Expand Down Expand Up @@ -399,7 +407,7 @@ async def update_fields(self, **kwargs):
await self._execute_query(query, *values)

@classmethod
async def get_by_id(cls, doc_id: uuid.UUID):
async def get_by_id(cls, doc_id: uuid.UUID) -> Optional["BaseDocument"]:
"""
Get a specific document based on its ID

Expand All @@ -408,6 +416,7 @@ async def get_by_id(cls, doc_id: uuid.UUID):
result = await cls.get_list(id=doc_id)
if len(result) > 0:
return result[0]
return None

@classmethod
async def get_one(cls, **query):
Expand Down Expand Up @@ -449,7 +458,7 @@ async def delete_all(cls, connection=None, **query):
return record_count

@classmethod
def _get_composed_filter(cls, offset=1, col_name_prefix=None, **query):
def _get_composed_filter(cls, offset: int=1, col_name_prefix: str=None, **query: Any) -> Tuple[str, List[Any]]:
filter_statements = []
values = []
index_count = max(1, offset)
Expand All @@ -463,7 +472,7 @@ def _get_composed_filter(cls, offset=1, col_name_prefix=None, **query):
return (filter_as_string, values)

@classmethod
def _get_filter(cls, name, value, index, col_name_prefix=None):
def _get_filter(cls, name: str, value: Any, index: int, col_name_prefix: str=None) -> Tuple[str, Any]:
if value is None:
return (name + " IS NULL", None)
filter_statement = name + "=$" + str(index)
Expand All @@ -473,12 +482,12 @@ def _get_filter(cls, name, value, index, col_name_prefix=None):
return (filter_statement, value)

@classmethod
def _get_value(cls, value):
def _get_value(cls, value: Any) -> Any:
if isinstance(value, dict):
return json.dumps(cls._get_value_of_dict(value))
return json_encode(value)

if isinstance(value, DataDocument) or issubclass(value.__class__, DataDocument):
return json.dumps(cls._get_value_of_dict(value.to_dict()))
return json_encode(value)

if isinstance(value, list):
return [cls._get_value(x) for x in value]
Expand All @@ -488,19 +497,8 @@ def _get_value(cls, value):

if isinstance(value, uuid.UUID):
return str(value)
return value

@classmethod
def _get_value_of_dict(cls, dct):
result = {}
for key, value in dct.items():
if isinstance(value, datetime.datetime):
result[key] = value.strftime("%Y-%m-%dT%H:%M:%S.%f")
elif isinstance(value, dict):
result[key] = cls._get_value_of_dict(value)
else:
result[key] = cls._get_value(value)
return result
return value

async def delete(self, connection=None):
"""
Expand All @@ -525,7 +523,7 @@ async def select_query(cls, query, values, no_obj=False):
result.append(cls(from_postgres=True, **record))
return result

def to_dict(self):
def to_dict(self) -> JsonType:
"""
Return a dict representing the document
"""
Expand Down Expand Up @@ -560,7 +558,7 @@ class Project(BaseDocument):
name = Field(field_type=str, required=True, unique=True)


def convert_boolean(value):
def convert_boolean(value: Any) -> bool:
if isinstance(value, bool):
return value

Expand All @@ -569,7 +567,7 @@ def convert_boolean(value):
return RawConfigParser.BOOLEAN_STATES[value.lower()]


def convert_int(value):
def convert_int(value: Any) -> Union[int, float]:
if isinstance(value, (int, float)):
return value

Expand All @@ -581,7 +579,7 @@ def convert_int(value):
return f_value


def convert_agent_map(value):
def convert_agent_map(value: Dict[str, str]) -> Dict[str, str]:
if not isinstance(value, dict):
raise ValueError("Agent map should be a dict")

Expand Down
27 changes: 5 additions & 22 deletions src/inmanta/protocol/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import inspect
import enum
import uuid
import datetime
import logging
import json
import gzip
Expand All @@ -32,7 +31,7 @@
from urllib import parse
from typing import Any, Dict, List, Optional, Union, Tuple, Set, Callable, Generator, cast, TYPE_CHECKING # noqa: F401

from inmanta import execute, const
from inmanta import execute, const, util
from inmanta import config as inmanta_config
from inmanta.types import JsonType
from . import exceptions
Expand Down Expand Up @@ -375,27 +374,11 @@ def custom_json_encoder(o: object) -> Union[Dict, str, List]:
"""
A custom json encoder that knows how to encode other types commonly used by Inmanta
"""
if isinstance(o, uuid.UUID):
return str(o)

if isinstance(o, datetime.datetime):
return o.isoformat(timespec='microseconds')

if hasattr(o, "to_dict"):
return o.to_dict()

if isinstance(o, enum.Enum):
return o.name

if isinstance(o, Exception):
# Logs can push exceptions through RPC. Return a string representation.
return str(o)

if isinstance(o, execute.util.Unknown):
return const.UNKNOWN_STRING

LOGGER.error("Unable to serialize %s", o)
raise TypeError(repr(o) + " is not JSON serializable")
# handle common python types
return util.custom_json_encoder(o)


def attach_warnings(code: int, value: JsonType, warnings: Optional[List[str]]) -> Tuple[int, JsonType]:
Expand All @@ -406,12 +389,12 @@ def attach_warnings(code: int, value: JsonType, warnings: Optional[List[str]]) -
return code, value


def json_encode(value: Dict[str, Any]) -> str:
def json_encode(value: JsonType) -> str:
# see json_encode in tornado.escape
return json.dumps(value, default=custom_json_encoder).replace("</", "<\\/")


def gzipped_json(value: Dict[str, Any]) -> Tuple[bool, Union[bytes, str]]:
def gzipped_json(value: JsonType) -> Tuple[bool, Union[bytes, str]]:
json_string = json_encode(value)
if len(json_string) < web.GZipContentEncoding.MIN_LENGTH:
return False, json_string
Expand Down
29 changes: 28 additions & 1 deletion src/inmanta/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import logging
import socket
import warnings
import uuid
import datetime
import enum

import pkg_resources
from pkg_resources import DistributionNotFound
from tornado.ioloop import IOLoop
from typing import Callable, Dict, Union, Tuple
from typing import Callable, Dict, Union, Tuple, List

from inmanta.types import JsonType

Expand Down Expand Up @@ -155,3 +158,27 @@ def get_free_tcp_port() -> str:
_addr, port = tcp.getsockname()
tcp.close()
return str(port)


def custom_json_encoder(o: object) -> Union[Dict, str, List]:
"""
A custom json encoder that knows how to encode other types commonly used by Inmanta from standard python libraries
"""
if isinstance(o, uuid.UUID):
return str(o)

if isinstance(o, datetime.datetime):
return o.isoformat(timespec='microseconds')

if hasattr(o, "to_dict"):
return o.to_dict()

if isinstance(o, enum.Enum):
return o.name

if isinstance(o, Exception):
# Logs can push exceptions through RPC. Return a string representation.
return str(o)

LOGGER.error("Unable to serialize %s", o)
raise TypeError(repr(o) + " is not JSON serializable")
23 changes: 23 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,3 +1863,26 @@ async def test_insert_many(init_dataclasses_and_load_schema, postgresql_client):

assert len(project_names_in_result) == 2
assert sorted(["proj1", "proj2"]) == sorted(project_names_in_result)


@pytest.mark.asyncio
async def test_resources_json(init_dataclasses_and_load_schema):
project = data.Project(name="test")
await project.insert()
env = data.Environment(name="dev", project=project.id, repo_url="", repo_branch="")
await env.insert()

version = 1
cm = data.ConfigurationModel(environment=env.id, version=version, date=datetime.datetime.now(), total=1,
version_info={}, released=True, deployed=True)
await cm.insert()

res1 = data.Resource.new(environment=env.id,
resource_version_id="std::File[agent1,path=/etc/file1],v=%s" % version,
status=const.ResourceState.deployed, last_deploy=datetime.datetime.now(),
attributes={"attr": [{"a": 1, "b": "c"}]})
await res1.insert()

res = await data.Resource.get_one(environment=res1.environment, resource_version_id=res1.resource_version_id)

assert res1.attributes == res.attributes