Skip to content

Commit

Permalink
Refactor all *.from_db() routines to use from_db_json()
Browse files Browse the repository at this point in the history
This wrapper calls json.loads() but also handles None (returning None),
which enables the code at many call sites to be simplified.

Removed some callers' `if isinstance(field, str): ...` code, which has
the effect of newly disallowing field values that are already dicts.
However we've verified that all *.from_db() calls have raw database
outputs as their arguments, so such fields will be always be strings
and IMHO giving a dict to from_db_json() is really a logic error that
should be detected.

In SequencingGroupInternal.from_db() added `pop(..., None)` so that
a missing meta field is now accepted. The previous code suggests that
having pop() produce KeyError here was unintended.

The expected argument types for from_db_json() are listed in the
definition, but we don't list its return type. The best we could
say in general is `object` but most call sites expect `dict[str, str]`
(or occasionally `list[str]`) due to the shape of their expected JSON.
Specifying `object` would lead to mypy errors at these call sites.
  • Loading branch information
jmarshall committed Jun 13, 2024
1 parent 71dddbe commit 9eecc05
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 66 deletions.
15 changes: 7 additions & 8 deletions db/python/layers/web.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# pylint: disable=too-many-locals, too-many-instance-attributes
import asyncio
import itertools
import json
import re
from collections import defaultdict
from datetime import date
Expand All @@ -15,7 +14,7 @@
from db.python.tables.base import DbBase
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sequencing_group import SequencingGroupTable
from db.python.utils import escape_like_term
from db.python.utils import escape_like_term, from_db_json
from models.models import (
AssayInternal,
FamilySimpleInternal,
Expand Down Expand Up @@ -114,7 +113,7 @@ def _project_summary_process_assay_rows_by_sample_id(
AssayInternal(
id=seq['id'],
type=seq['type'],
meta=json.loads(seq['meta']),
meta=from_db_json(seq['meta']),
sample_id=seq['sample_id'],
)
for seq in assay_rows
Expand Down Expand Up @@ -153,7 +152,7 @@ def _project_summary_process_sequencing_group_rows_by_sample_id(
sg_id_to_sample_id[sg_id] = row['sample_id']
sg_by_id[sg_id] = NestedSequencingGroupInternal(
id=sg_id,
meta=json.loads(row['meta']),
meta=from_db_json(row['meta']),
type=row['type'],
technology=row['technology'],
platform=row['platform'],
Expand Down Expand Up @@ -189,9 +188,9 @@ def _project_summary_process_sample_rows(
smodels = [
NestedSampleInternal(
id=s['id'],
external_ids=json.loads(s['external_ids']),
external_ids=from_db_json(s['external_ids']),
type=s['type'],
meta=json.loads(s['meta']) or {},
meta=from_db_json(s['meta']) or {},
created_date=str(sample_id_start_times.get(s['id'], '')),
sequencing_groups=sg_models_by_sample_id.get(s['id'], []),
non_sequencing_assays=filtered_assay_models_by_sid.get(s['id'], []),
Expand Down Expand Up @@ -450,8 +449,8 @@ async def get_project_summary(
pmodels.append(
NestedParticipantInternal(
id=p['id'],
external_ids=json.loads(p['external_ids']),
meta=json.loads(p['meta']),
external_ids=from_db_json(p['external_ids']),
meta=from_db_json(p['meta']),
families=pid_to_families.get(p['id'], []),
samples=list(smodels_by_pid.get(p['id'])),
reported_sex=p['reported_sex'],
Expand Down
8 changes: 4 additions & 4 deletions db/python/tables/participant_phenotype.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from collections import defaultdict
from typing import Any, Dict, List, Tuple

from db.python.tables.base import DbBase
from db.python.utils import from_db_json, to_db_json


class ParticipantPhenotypeTable(DbBase):
Expand Down Expand Up @@ -32,7 +32,7 @@ async def add_key_value_rows(self, rows: List[Tuple[int, str, Any]]) -> None:
{
'participant_id': r[0],
'description': r[1],
'value': json.dumps(r[2]),
'value': to_db_json(r[2]),
'audit_log_id': audit_log_id,
}
for r in rows
Expand Down Expand Up @@ -67,7 +67,7 @@ async def get_key_value_rows_for_participant_ids(
pid = row['participant_id']
key = row['description']
value = row['value']
formed_key_value_pairs[pid][key] = json.loads(value)
formed_key_value_pairs[pid][key] = from_db_json(value)

return formed_key_value_pairs

Expand All @@ -91,6 +91,6 @@ async def get_key_value_rows_for_all_participants(
pid = row['participant_id']
key = row['description']
value = row['value']
formed_key_value_pairs[pid][key] = json.loads(value)
formed_key_value_pairs[pid][key] = from_db_json(value)

return formed_key_value_pairs
6 changes: 4 additions & 2 deletions db/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,10 @@ def get_logger():
return _logger


def from_db_json(text):
"""Convert DB's JSON text to Python object"""
def from_db_json(text: str | bytes | None):
"""Convert database's JSON text to Python object"""
if text is None:
return None
return json.loads(text)


Expand Down
8 changes: 2 additions & 6 deletions models/models/analysis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import enum
import json
from datetime import date, datetime
from typing import Any

from pydantic import BaseModel

from db.python.utils import from_db_json
from models.base import SMBase
from models.enums import AnalysisStatus
from models.utils.cohort_id_format import (
Expand Down Expand Up @@ -40,10 +40,6 @@ def from_db(**kwargs):
analysis_type = kwargs.pop('type', None)
status = kwargs.pop('status', None)
timestamp_completed = kwargs.pop('timestamp_completed', None)
meta = kwargs.get('meta')

if meta and isinstance(meta, str):
meta = json.loads(meta)

if timestamp_completed and isinstance(timestamp_completed, str):
timestamp_completed = datetime.fromisoformat(timestamp_completed)
Expand All @@ -65,7 +61,7 @@ def from_db(**kwargs):
output=kwargs.pop('output', []),
timestamp_completed=timestamp_completed,
project=kwargs.get('project'),
meta=meta,
meta=from_db_json(kwargs.get('meta')),
active=bool(kwargs.get('active')),
author=kwargs.get('author'),
)
Expand Down
8 changes: 2 additions & 6 deletions models/models/analysis_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import json

from db.python.utils import from_db_json
from models.base import SMBase
from models.models.project import ProjectId

Expand Down Expand Up @@ -34,10 +34,6 @@ class AnalysisRunnerInternal(SMBase):
@staticmethod
def from_db(**kwargs):
"""Convert from db Record"""
meta = kwargs.pop('meta')
if meta:
meta = json.loads(meta)

_timestamp = kwargs.pop('timestamp')
# if _timestamp:
# _timestamp = datetime.datetime.fromisoformat(_timestamp)
Expand All @@ -58,7 +54,7 @@ def from_db(**kwargs):
hail_version=kwargs.pop('hail_version'),
batch_url=kwargs.pop('batch_url'),
submitting_user=kwargs.pop('submitting_user'),
meta=meta,
meta=from_db_json(kwargs.pop('meta')),
audit_log_id=kwargs.pop('audit_log_id'),
output_path=kwargs.pop('output_path'),
)
Expand Down
10 changes: 2 additions & 8 deletions models/models/assay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Any

from db.python.utils import from_db_json
from models.base import OpenApiGenNoneType, SMBase
from models.utils.sample_id_format import sample_id_format, sample_id_transform_to_raw

Expand All @@ -26,13 +26,7 @@ def __eq__(self, other):
def from_db(d: dict):
"""Take DB mapping object, and return SampleSequencing"""
meta = d.pop('meta', None)

if meta:
if isinstance(meta, bytes):
meta = meta.decode()
if isinstance(meta, str):
meta = json.loads(meta)
return AssayInternal(meta=meta, **d)
return AssayInternal(meta=from_db_json(meta), **d)

def to_external(self):
"""Convert to transport model"""
Expand Down
9 changes: 3 additions & 6 deletions models/models/audit_log.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import json

from db.python.utils import from_db_json
from models.base import SMBase
from models.models.project import ProjectId

Expand All @@ -24,8 +24,5 @@ class AuditLogInternal(SMBase):
@staticmethod
def from_db(d: dict):
"""Take DB mapping object, and return SampleSequencing"""
meta = {}
if 'meta' in d:
meta = json.loads(d.pop('meta'))

return AuditLogInternal(meta=meta, **d)
meta = d.pop('meta', None)
return AuditLogInternal(meta=from_db_json(meta) or {}, **d)
7 changes: 2 additions & 5 deletions models/models/cohort.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json

from db.python.utils import from_db_json
from models.base import SMBase
from models.models.project import ProjectId
from models.utils.cohort_id_format import cohort_id_format
Expand Down Expand Up @@ -92,9 +91,7 @@ def from_db(d: dict):
_id = d.pop('id', None)
name = d.pop('name', None)
description = d.pop('description', None)
criteria = d.pop('criteria', None)
if criteria and isinstance(criteria, str):
criteria = json.loads(criteria)
criteria = from_db_json(d.pop('criteria', None))
project = d.pop('project', None)

return CohortTemplateInternal(
Expand Down
7 changes: 3 additions & 4 deletions models/models/participant.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json

from db.python.utils import from_db_json
from models.base import OpenApiGenNoneType, SMBase
from models.models.family import FamilySimple, FamilySimpleInternal
from models.models.project import ProjectId
Expand Down Expand Up @@ -28,8 +27,8 @@ class ParticipantInternal(SMBase):
def from_db(cls, data: dict):
"""Convert from db keys, mainly converting JSON-encoded fields"""
for key in ['external_ids', 'meta']:
if key in data and isinstance(data[key], str):
data[key] = json.loads(data[key])
if key in data:
data[key] = from_db_json(data[key])

return ParticipantInternal(**data)

Expand Down
4 changes: 2 additions & 2 deletions models/models/project.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Optional

from db.python.utils import from_db_json
from models.base import SMBase

ProjectId = int
Expand All @@ -20,5 +20,5 @@ class Project(SMBase):
def from_db(kwargs):
"""From DB row, with db keys"""
kwargs = dict(kwargs)
kwargs['meta'] = json.loads(kwargs['meta']) if kwargs.get('meta') else {}
kwargs['meta'] = from_db_json(kwargs.get('meta')) or {}
return Project(**kwargs)
15 changes: 4 additions & 11 deletions models/models/sample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json

from db.python.utils import from_db_json
from models.base import OpenApiGenNoneType, SMBase, parse_sql_bool
from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal
from models.models.sequencing_group import (
Expand Down Expand Up @@ -30,17 +29,11 @@ def from_db(d: dict):
"""
_id = d.pop('id', None)
type_ = d.pop('type', None)
meta = d.pop('meta', None)
meta = from_db_json(d.pop('meta', None))
active = parse_sql_bool(d.pop('active', None))

if meta:
if isinstance(meta, bytes):
meta = meta.decode()
if isinstance(meta, str):
meta = json.loads(meta)

if 'external_ids' in d and isinstance(d['external_ids'], str):
d['external_ids'] = json.loads(d['external_ids'])
if 'external_ids' in d:
d['external_ids'] = from_db_json(d['external_ids'])

return SampleInternal(id=_id, type=str(type_), meta=meta, active=active, **d)

Expand Down
6 changes: 2 additions & 4 deletions models/models/sequencing_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Any

from db.python.utils import from_db_json
from models.base import OpenApiGenNoneType, SMBase
from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal
from models.utils.sample_id_format import sample_id_format, sample_id_transform_to_raw
Expand Down Expand Up @@ -47,9 +47,7 @@ class SequencingGroupInternal(SMBase):
@classmethod
def from_db(cls, **kwargs):
"""From database model"""
meta = kwargs.pop('meta')
if meta and isinstance(meta, str):
meta = json.loads(meta)
meta = from_db_json(kwargs.pop('meta', None))

_archived = kwargs.pop('archived', None)
if _archived is not None:
Expand Down

0 comments on commit 9eecc05

Please sign in to comment.