"""Load MGS pipeline dashboard data from a GitHub repository.

The loaders below fetch the JSON files under ``dashboard/`` from a pinned
ref of the pipeline repository and wrap them in typed containers.
"""

import json
import urllib.request
from collections import Counter
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import date
from enum import Enum
from typing import NewType, Optional

from pydantic import BaseModel

from pathogen_properties import TaxID
from tree import Tree

# Default repository coordinates used by MGSData.from_repo.
MGS_REPO_DEFAULTS = {
    "user": "naobservatory",
    "repo": "mgs-pipeline",
    "ref": "data-2023-07-21",
}

BioProject = NewType("BioProject", str)
Sample = NewType("Sample", str)

# BioProject accessions grouped under a short label per study.
target_bioprojects = {
    "crits_christoph": [BioProject("PRJNA661613")],
    "rothman": [BioProject("PRJNA729801")],
    "spurbeck": [BioProject("PRJNA924011")],
    "brinch": [BioProject("PRJEB13832"), BioProject("PRJEB34633")],
}


@dataclass
class GitHubRepo:
    """Coordinates of a GitHub repository at a fixed ref."""

    user: str
    repo: str
    ref: str

    def get_file(self, path: str) -> str:
        """Download ``path`` from raw.githubusercontent.com and return its text.

        Args:
            path: Path of the file relative to the repository root.

        Returns:
            The file contents decoded as UTF-8.

        Raises:
            ValueError: If the response status is not 200.
            urllib.error.URLError: Propagated from ``urlopen`` on network or
                HTTP errors (``urlopen`` itself raises for HTTP error codes).
        """
        file_url = (
            f"https://raw.githubusercontent.com/"
            f"{self.user}/{self.repo}/{self.ref}/{path}"
        )
        with urllib.request.urlopen(file_url) as response:
            if response.status != 200:
                raise ValueError(
                    f"Failed to download {file_url}. "
                    f"Response status code: {response.status}"
                )
            # read() yields bytes; decode so the result matches the
            # declared ``str`` return type.
            return response.read().decode("utf-8")


def load_bioprojects(repo: GitHubRepo) -> dict[BioProject, list[Sample]]:
    """Load the bioproject -> samples mapping from the dashboard metadata."""
    data = json.loads(repo.get_file("dashboard/metadata_bioprojects.json"))
    return {
        BioProject(bp): [Sample(s) for s in samples]
        for bp, samples in data.items()
    }


class Enrichment(Enum):
    """Values accepted for the ``enrichment`` field of a sample."""

    VIRAL = "viral"
    PANEL = "panel"


class SampleAttributes(BaseModel):
    """Per-sample metadata as stored in ``dashboard/metadata_samples.json``."""

    country: str
    state: Optional[str] = None
    county: Optional[str] = None
    location: str
    fine_location: Optional[str] = None
    # Fixme: Not all the dates are real dates
    date: date | str
    reads: int
    enrichment: Optional[Enrichment] = None
    method: Optional[str] = None


def load_sample_attributes(repo: GitHubRepo) -> dict[Sample, SampleAttributes]:
    """Load and validate the attributes of every sample in the metadata file."""
    data = json.loads(repo.get_file("dashboard/metadata_samples.json"))
    return {
        Sample(s): SampleAttributes(**attribs) for s, attribs in data.items()
    }


# Read counts per taxid, broken down by sample.
SampleCounts = dict[TaxID, dict[Sample, int]]


def load_sample_counts(repo: GitHubRepo) -> SampleCounts:
    """Load per-taxid, per-sample read counts, converting keys to typed ids."""
    data: dict[str, dict[str, int]] = json.loads(
        repo.get_file("dashboard/human_virus_sample_counts.json")
    )
    return {
        TaxID(int(taxid)): {Sample(sample): n for sample, n in counts.items()}
        for taxid, counts in data.items()
    }


def load_tax_tree(repo: GitHubRepo) -> Tree[TaxID]:
    """Load the taxonomy tree and convert every node value to a ``TaxID``."""
    data = json.loads(repo.get_file("dashboard/human_virus_tree.json"))
    return Tree.tree_from_list(data).map(lambda x: TaxID(int(x)))


def make_count_tree(
    taxtree: Tree[TaxID], sample_counts: SampleCounts
) -> Tree[tuple[TaxID, Counter[Sample]]]:
    """Pair every taxid in ``taxtree`` with a Counter of its per-sample reads.

    Taxids that do not appear in ``sample_counts`` get an empty Counter.
    """
    return taxtree.map(
        lambda taxid: (taxid, Counter(sample_counts.get(taxid, {}))),
    )


def count_reads(
    taxtree: Tree[TaxID] | None, sample_counts: SampleCounts
) -> Counter[Sample]:
    """Sum the per-sample read counts over the elements of ``taxtree``.

    Returns an empty Counter when ``taxtree`` is None.
    """
    if taxtree is None:
        return Counter()
    count_tree = make_count_tree(taxtree, sample_counts)
    # Each element's data is a (taxid, Counter) pair; add up the Counters.
    return sum(
        (elem.data[1] for elem in count_tree),
        start=Counter(),
    )


@dataclass
class MGSData:
    """Bundle of the dashboard data sets loaded from one repository ref."""

    bioprojects: dict[BioProject, list[Sample]]
    sample_attrs: dict[Sample, SampleAttributes]
    read_counts: SampleCounts
    tax_tree: Tree[TaxID]

    @staticmethod
    def from_repo(
        user: str = MGS_REPO_DEFAULTS["user"],
        repo: str = MGS_REPO_DEFAULTS["repo"],
        ref: str = MGS_REPO_DEFAULTS["ref"],
    ) -> "MGSData":
        """Download all data sets from the given repository and ref.

        The defaults come from ``MGS_REPO_DEFAULTS``.
        """
        # Use a distinct local name instead of rebinding the ``repo``
        # string parameter to a different type.
        github_repo = GitHubRepo(user, repo, ref)
        return MGSData(
            bioprojects=load_bioprojects(github_repo),
            sample_attrs=load_sample_attributes(github_repo),
            read_counts=load_sample_counts(github_repo),
            tax_tree=load_tax_tree(github_repo),
        )

    def sample_attributes(
        self, bioproject: BioProject, enrichment: Optional[Enrichment] = None
    ) -> dict[Sample, SampleAttributes]:
        """Return the attributes of each sample in ``bioproject``.

        When ``enrichment`` is given, only samples whose ``enrichment``
        attribute equals it are returned.

        Raises:
            KeyError: If ``bioproject`` is unknown or one of its samples has
                no attributes.
        """
        samples = {
            s: self.sample_attrs[s] for s in self.bioprojects[bioproject]
        }
        if enrichment is None:
            return samples
        return {
            s: attrs
            for s, attrs in samples.items()
            if attrs.enrichment == enrichment
        }

    def total_reads(self, bioproject: BioProject) -> dict[Sample, int]:
        """Return the ``reads`` attribute of each sample in ``bioproject``."""
        return {
            s: self.sample_attrs[s].reads for s in self.bioprojects[bioproject]
        }

    def viral_reads(
        self, bioproject: BioProject, taxids: Iterable[TaxID]
    ) -> dict[Sample, int]:
        """Return, per sample in ``bioproject``, the reads summed over ``taxids``.

        For each taxid the counts are aggregated with ``count_reads`` over
        ``self.tax_tree[taxid]``. Samples without any count contribute 0.
        """
        # Materialize once: ``taxids`` may be a one-shot iterator, and it is
        # needed both to build the per-taxid counters and again per sample.
        taxid_list = list(taxids)
        viral_counts_by_taxid = {
            taxid: count_reads(self.tax_tree[taxid], self.read_counts)
            for taxid in taxid_list
        }
        return {
            # Counter lookups return 0 for samples with no recorded reads.
            s: sum(viral_counts_by_taxid[taxid][s] for taxid in taxid_list)
            for s in self.bioprojects[bioproject]
        }