From 06cc273fb063d73ec65071c3807c76bf63f0d448 Mon Sep 17 00:00:00 2001 From: Michael Lamkin Date: Fri, 12 Jan 2024 11:51:37 -0800 Subject: [PATCH] fix: convert `samples` argument in `Genotypes.read` into a set and fix `tr_harmonizer` bug arising when TRTools is also installed (#225) Co-authored-by: Arya Massarat <23412689+aryarm@users.noreply.github.com> --- haptools/__main__.py | 18 +++--- haptools/data/genotypes.py | 112 ++++++++++++++++++++++++--------- haptools/data/tr_harmonizer.py | 14 +---- haptools/ld.py | 5 +- haptools/sim_phenotype.py | 7 +-- haptools/transform.py | 16 ++--- tests/test_data.py | 12 ++-- 7 files changed, 113 insertions(+), 71 deletions(-) diff --git a/haptools/__main__.py b/haptools/__main__.py index 39cf3322..3d76d829 100755 --- a/haptools/__main__.py +++ b/haptools/__main__.py @@ -491,10 +491,10 @@ def simphenotype( ) if samples_file: with samples_file as samps_file: - samples = samps_file.read().splitlines() + samples = set(samps_file.read().splitlines()) elif samples: - # needs to be converted from tuple to list - samples = list(samples) + # needs to be converted from tuple to set + samples = set(samples) else: samples = None @@ -657,10 +657,10 @@ def transform( ) if samples_file: with samples_file as samps_file: - samples = samps_file.read().splitlines() + samples = set(samps_file.read().splitlines()) elif samples: - # needs to be converted from tuple to list - samples = list(samples) + # needs to be converted from tuple to set + samples = set(samples) else: samples = None @@ -828,10 +828,10 @@ def ld( ) if samples_file: with samples_file as samps_file: - samples = samps_file.read().splitlines() + samples = set(samps_file.read().splitlines()) elif samples: - # needs to be converted from tuple to list - samples = list(samples) + # needs to be converted from tuple to set + samples = set(samples) else: samples = None diff --git a/haptools/data/genotypes.py b/haptools/data/genotypes.py index caa297c5..b90bef98 100644 --- a/haptools/data/genotypes.py +++ b/haptools/data/genotypes.py @@ -3,15 +3,15 @@ import gc from csv import reader from pathlib import Path +from logging import Logger from typing import Iterator -from logging import getLogger, Logger from collections import namedtuple, Counter import pgenlib import numpy as np import numpy.typing as npt +from pysam import VariantFile from cyvcf2 import VCF, Variant -from pysam import VariantFile, TabixFile try: import trtools.utils.tr_harmonizer as trh @@ -77,7 +77,7 @@ def load( cls: Genotypes, fname: Path | str, region: str = None, - samples: list[str] = None, + samples: set[str] = None, variants: set[str] = None, ) -> Genotypes: """ @@ -91,7 +91,7 @@ def load( See documentation for :py:attr:`~.Data.fname` region : str, optional See documentation for :py:meth:`~.Genotypes.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:meth:`~.Genotypes.read` variants : set[str], optional See documentation for :py:meth:`~.Genotypes.read` @@ -112,7 +112,7 @@ def load( def read( self, region: str = None, - samples: list[str] = None, + samples: set[str] = None, variants: set[str] = None, max_variants: int = None, ): @@ -132,9 +132,11 @@ def read( For this to work, the VCF must be indexed and the seqname must match! Defaults to loading all genotypes - samples : list[str], optional + samples : set[str], optional A subset of the samples from which to extract genotypes + Note that they are loaded in the same order as in the file + Defaults to loading genotypes from all samples variants : set[str], optional A set of variant IDs for which to extract genotypes @@ -307,7 +309,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None): vcf.close() def __iter__( - self, region: str = None, samples: list[str] = None, variants: set[str] = None + self, region: str = None, samples: set[str] = None, variants: set[str] = None ) -> Iterator[namedtuple]: """ Read genotypes from a VCF line by line without storing anything @@ -316,7 +318,7 @@ def __iter__( ---------- region : str, optional See documentation for :py:meth:`~.Genotypes.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:meth:`~.Genotypes.read` variants : set[str], optional See documentation for :py:meth:`~.Genotypes.read` @@ -326,6 +328,13 @@ def __iter__( Iterator[namedtuple] See documentation for :py:meth:`~.Genotypes._iterate` """ + if samples is not None: + if not isinstance(samples, set): + self.log.warning( + "Samples cannot be loaded in a particular order. " + "Use subset() to reorder the samples after loading them." + ) + samples = list(samples) vcf = VCF(str(self.fname), samples=samples, lazy=True) self.samples = tuple(vcf.samples) # call another function to force the lines above to be run immediately @@ -797,6 +806,37 @@ def write(self): vcf.close() +class TRRecordHarmonizerRegion(trh.TRRecordHarmonizer): + """ + Parameters + ---------- + vcffile : VCF + vcftype : {'auto', 'gangstr', 'advntr', 'hipstr', 'eh', 'popstr'}, optional + Type of the VCF file. Default='auto'. + If vcftype=='auto', attempts to infer the type. + Attributes + ---------- + vcffile : VCF + vcfiter : VCF + Region to grab strs from within the VCF file. + vcftype : enum + Type of the VCF file. Must be included in VcfTypes + """ + + def __init__( + self, + vcffile: VCF, + vcfiter: object, + vcftype: str | trh.VcfTypes = "auto", + ): + super().__init__(vcffile, vcftype) + self.vcfiter = vcfiter + + def __next__(self) -> trh.TRRecord: + """Iterate over TRRecord produced from the underlying vcf.""" + return trh.HarmonizeRecord(self.vcftype, next(self.vcfiter)) + + class GenotypesTR(Genotypes): """ A class for processing TR genotypes from a file @@ -823,7 +863,12 @@ class GenotypesTR(Genotypes): {'auto', 'gangstr', 'advntr', 'hipstr', 'eh', 'popstr'} """ - def __init__(self, fname: Path | str, log: Logger = None, vcftype: str = "auto"): + def __init__( + self, + fname: Path | str, + log: Logger = None, + vcftype: str = "auto", + ): super().__init__(fname, log) self.vcftype = vcftype @@ -832,7 +877,7 @@ def load( cls: GenotypesTR, fname: Path | str, region: str = None, - samples: list[str] = None, + samples: set[str] = None, variants: set[str] = None, vcftype: str = "auto", ) -> Genotypes: @@ -847,7 +892,7 @@ def load( See documentation for :py:attr:`~.Data.fname` region : str, optional See documentation for :py:meth:`~.Genotypes.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:meth:`~.Genotypes.read` variants : set[str], optional See documentation for :py:meth:`~.Genotypes.read` @@ -878,8 +923,8 @@ def _vcf_iter(self, vcf: cyvcf2.VCF, region: str = None): tr_records: trh.TRRecord TRRecord objects yielded from TRRecordHarmonizer """ - for record in trh.TRRecordHarmonizer( - vcffile=vcf, vcfiter=vcf(region), region=region, vcftype=self.vcftype + for record in TRRecordHarmonizerRegion( + vcffile=vcf, vcfiter=vcf(region), vcftype=self.vcftype ): record.ID = record.record_id record.CHROM = record.chrom @@ -938,7 +983,7 @@ class GenotypesPLINK(GenotypesVCF): ---------- data : npt.NDArray See documentation for :py:attr:`~.GenotypesVCF.data` - samples : tuple + samples : tuple[str] See documentation for :py:attr:`~.GenotypesVCF.samples` variants : np.array See documentation for :py:attr:`~.GenotypesVCF.variants` @@ -956,11 +1001,16 @@ class GenotypesPLINK(GenotypesVCF): >>> genotypes = GenotypesPLINK.load('tests/data/simple.pgen') """ - def __init__(self, fname: Path | str, log: Logger = None, chunk_size: int = None): + def __init__( + self, + fname: Path | str, + log: Logger = None, + chunk_size: int = None, + ): super().__init__(fname, log) self.chunk_size = chunk_size - def read_samples(self, samples: list[str] = None): + def read_samples(self, samples: set[str] = None): """ Read sample IDs from a PSAM file into a list stored in :py:attr:`~.GenotypesPLINK.samples` @@ -969,7 +1019,7 @@ def read_samples(self, samples: list[str] = None): Parameters ---------- - samples : list[str], optional + samples : set[str], optional See documentation for :py:attr:`~.GenotypesVCF.read` Returns @@ -980,6 +1030,10 @@ def read_samples(self, samples: list[str] = None): if len(self.samples) != 0: self.log.warning("Sample data has already been loaded. Overriding.") if samples is not None and not isinstance(samples, set): + self.log.warning( + "Samples cannot be loaded in a particular order. " + "Use subset() to reorder the samples after loading them." + ) samples = set(samples) with self.hook_compressed(self.fname.with_suffix(".psam"), mode="r") as psam: psamples = reader(psam, delimiter="\t") @@ -1210,7 +1264,7 @@ def read_variants( def read( self, region: str = None, - samples: list[str] = None, + samples: set[str] = None, variants: set[str] = None, max_variants: int = None, ): @@ -1222,7 +1276,7 @@ def read( ---------- region : str, optional See documentation for :py:attr:`~.GenotypesVCF.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:attr:`~.GenotypesVCF.read` variants : set[str], optional See documentation for :py:attr:`~.GenotypesVCF.read` @@ -1366,7 +1420,7 @@ def _iterate( pgen.close() def __iter__( - self, region: str = None, samples: list[str] = None, variants: set[str] = None + self, region: str = None, samples: set[str] = None, variants: set[str] = None ) -> Iterator[namedtuple]: """ Read genotypes from a PGEN line by line without storing anything @@ -1375,7 +1429,7 @@ def __iter__( ---------- region : str, optional See documentation for :py:meth:`~.Genotypes.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:meth:`~.Genotypes.read` variants : set[str], optional See documentation for :py:meth:`~.Genotypes.read` @@ -1552,7 +1606,7 @@ class GenotypesPLINKTR(GenotypesPLINK): ---------- data : npt.NDArray See documentation for :py:attr:`~.GenotypesPLINK.data` - samples : tuple + samples : tuple[str] See documentation for :py:attr:`~.GenotypesPLINK.samples` variants : np.array See documentation for :py:attr:`~.GenotypesPLINK.variants` @@ -1562,7 +1616,6 @@ class GenotypesPLINKTR(GenotypesPLINK): See documentation for :py:attr:`~.GenotypesPLINK.chunk_size` vcftype: str, optional See documentation for :py:attr:`~.GenotypesTR.vcftype` - Examples -------- >>> genotypes = GenotypesPLINK.load('tests/data/simple.pgen') @@ -1583,7 +1636,7 @@ def load( cls: GenotypesPLINKTR, fname: Path | str, region: str = None, - samples: list[str] = None, + samples: set[str] = None, variants: set[str] = None, vcftype: str = "auto", ) -> Genotypes: @@ -1598,7 +1651,7 @@ def load( See documentation for :py:attr:`~.Data.fname` region : str, optional See documentation for :py:meth:`~.Genotypes.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:meth:`~.Genotypes.read` variants : set[str], optional See documentation for :py:meth:`~.Genotypes.read` @@ -1632,12 +1685,12 @@ def _iter_TRRecords(self, region: str = None, variants: set[str] = None): An iterator over each line of the PVAR file """ vcf = VCF(self.fname.with_suffix(".pvar")) - tr_records = trh.TRRecordHarmonizer( + tr_records = TRRecordHarmonizerRegion( vcffile=vcf, vcfiter=vcf(region), - region=region, vcftype=self.vcftype, ) + # filter out TRs that we didn't want if variants is not None: tr_records = filter(lambda rec: rec.record_id in variants, tr_records) @@ -1646,7 +1699,7 @@ def _iter_TRRecords(self, region: str = None, variants: set[str] = None): def read( self, region: str = None, - samples: list[str] = None, + samples: set[str] = None, variants: set[str] = None, max_variants: int = None, ): @@ -1658,7 +1711,7 @@ def read( ---------- region : str, optional See documentation for :py:attr:`~.GenotypesVCF.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:attr:`~.GenotypesVCF.read` variants : set[str], optional See documentation for :py:attr:`~.GenotypesVCF.read` @@ -1666,6 +1719,7 @@ def read( See documentation for :py:attr:`~.GenotypesVCF.read` """ super().read(region, samples, variants, max_variants) + num_variants = len(self.variants) # initialize a jagged array of allele lengths max_num_alleles = max(map(len, self.variants["alleles"])) diff --git a/haptools/data/tr_harmonizer.py b/haptools/data/tr_harmonizer.py index ee03fff2..9813d901 100644 --- a/haptools/data/tr_harmonizer.py +++ b/haptools/data/tr_harmonizer.py @@ -1633,8 +1633,6 @@ class TRRecordHarmonizer: Attributes ---------- vcffile : cyvcf2.VCF instance - region : str - Region to grab strs from within the VCF file. vcftype : enum Type of the VCF file. Must be included in VcfTypes Raises @@ -1644,17 +1642,9 @@ class TRRecordHarmonizer: See :py:meth:`InferVCFType` for more details. """ - def __init__( - self, - vcffile: cyvcf2.VCF, - vcfiter: object, - region: str, - vcftype: Union[str, VcfTypes] = "auto", - ): + def __init__(self, vcffile: cyvcf2.VCF, vcftype: Union[str, VcfTypes] = "auto"): self.vcffile = vcffile - self.vcfiter = vcfiter self.vcftype = InferVCFType(vcffile, vcftype) - self.region = region def MayHaveImpureRepeats(self) -> bool: """ @@ -1725,7 +1715,7 @@ def __iter__(self) -> Iterator[TRRecord]: def __next__(self) -> TRRecord: """Iterate over TRRecord produced from the underlying vcf.""" - return HarmonizeRecord(self.vcftype, next(self.vcfiter)) + return HarmonizeRecord(self.vcftype, next(self.vcffile)) # TODO check all users of this class for new options diff --git a/haptools/ld.py b/haptools/ld.py index be28ca14..47bed516 100644 --- a/haptools/ld.py +++ b/haptools/ld.py @@ -1,4 +1,5 @@ from __future__ import annotations +import logging from pathlib import Path from dataclasses import dataclass, field @@ -50,7 +51,7 @@ def calc_ld( genotypes: Path, haplotypes: Path, region: str = None, - samples: list[str] = None, + samples: set[str] = None, ids: tuple[str] = None, chunk_size: int = None, discard_missing: bool = False, @@ -72,7 +73,7 @@ def calc_ld( region : str, optional See documentation for :py:meth:`~.data.Genotypes.read` and :py:meth:`~.data.Haplotypes.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:meth:`~.data.Genotypes.read` ids: set[str], optional A subset of haplotype IDs to obtain from the .hap file. All others diff --git a/haptools/sim_phenotype.py b/haptools/sim_phenotype.py index df6909ec..7ae45998 100644 --- a/haptools/sim_phenotype.py +++ b/haptools/sim_phenotype.py @@ -303,7 +303,7 @@ def simulate_pt( prevalence: float = None, normalize: bool = True, region: str = None, - samples: list[str] = None, + samples: set[str] = None, haplotype_ids: set[str] = None, chunk_size: int = None, repeats: Path = None, @@ -352,13 +352,10 @@ def simulate_pt( match! Defaults to loading all haplotypes - sample : tuple[str], optional + samples : set[str], optional A subset of the samples from which to extract genotypes Defaults to loading genotypes from all samples - samples_file : Path, optional - A single column txt file containing a list of the samples (one per line) to - subset from the genotypes file haplotype_ids: set[str], optional A list of haplotype IDs to obtain from the .hap file. All others are ignored. diff --git a/haptools/transform.py b/haptools/transform.py index df70f6cb..11738730 100644 --- a/haptools/transform.py +++ b/haptools/transform.py @@ -5,8 +5,8 @@ from dataclasses import dataclass, field import numpy as np +from cyvcf2 import VCF import numpy.typing as npt -from cyvcf2 import VCF, Variant from pysam import VariantFile from . import data @@ -28,7 +28,7 @@ class HaplotypeAncestry(data.Haplotype): default=(data.Extra("ancestry", "s", "Local ancestry"),), ) - def transform(self, genotypes: data.GenotypesVCF) -> npt.NDArray[bool]: + def transform(self, genotypes: data.GenotypesVCF) -> npt.NDArray: """ Transform a genotypes matrix via the current haplotype and its ancestral population @@ -80,7 +80,7 @@ def __init__( fname: Path | str, haplotype: type[HaplotypeAncestry] = HaplotypeAncestry, variant: type[data.Variant] = data.Variant, - log: Logger = None, + log: logging.Logger = None, ): """ Contrasting with the base Haplotypes class: this class uses HaplotypeAncestry @@ -171,11 +171,11 @@ class GenotypesAncestry(data.GenotypesVCF): ancestry : np.array The ancestral population of each allele in each sample of :py:attr:`~.GenotypesAncestry.data` - log: Logger + log: logging.Logger See documentation for :py:attr:`~.Genotypes.log` """ - def __init__(self, fname: Path | str, log: Logger = None): + def __init__(self, fname: Path | str, log: logging.Logger = None): super().__init__(fname, log) self.ancestry = None self.valid_labels = None @@ -227,7 +227,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None): def read( self, region: str = None, - samples: list[str] = None, + samples: set[str] = None, variants: set[str] = None, max_variants: int = None, ): @@ -532,7 +532,7 @@ def transform_haps( genotypes: Path, haplotypes: Path, region: str = None, - samples: list[str] = None, + samples: set[str] = None, haplotype_ids: set[str] = None, chunk_size: int = None, discard_missing: bool = False, @@ -552,7 +552,7 @@ def transform_haps( region : str, optional See documentation for :py:meth:`~.data.Genotypes.read` and :py:meth:`~.data.Haplotypes.read` - samples : list[str], optional + samples : set[str], optional See documentation for :py:meth:`~.data.Genotypes.read` haplotype_ids: set[str], optional A set of haplotype IDs to obtain from the .hap file. All others are ignored. diff --git a/tests/test_data.py b/tests/test_data.py index 28ce228a..dd385bd2 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -20,7 +20,6 @@ Breakpoints, GenotypesTR, GenotypesVCF, - GenotypesTR, GenotypesPLINK, GenotypesPLINKTR, ) @@ -195,7 +194,8 @@ def test_load_genotypes_subset(self): gts = Genotypes(DATADIR / "simple.vcf.gz") samples = ["HG00097", "HG00100"] - gts.read(region="1:10115-10117", samples=samples) + samples_set = set(samples) + gts.read(region="1:10115-10117", samples=samples_set) np.testing.assert_allclose(gts.data, expected) assert gts.samples == tuple(samples) @@ -203,9 +203,8 @@ def test_load_genotypes_subset(self): expected = expected[:, [1]] gts = Genotypes(DATADIR / "simple.vcf.gz") - samples = ["HG00097", "HG00100"] variants = {"1:10117:C:A"} - gts.read(region="1:10115-10117", samples=samples, variants=variants) + gts.read(region="1:10115-10117", samples=samples_set, variants=variants) np.testing.assert_allclose(gts.data, expected) assert gts.samples == tuple(samples) @@ -501,7 +500,8 @@ def test_load_genotypes_subset(self): gts = GenotypesPLINK(DATADIR / "simple.pgen") samples = [expected.samples[1], expected.samples[3]] - gts.read(region="1:10115-10117", samples=samples) + samples_set = set(samples) + gts.read(region="1:10115-10117", samples=samples_set) gts.check_phase() np.testing.assert_allclose(gts.data, expected_data) assert gts.samples == tuple(samples) @@ -511,7 +511,7 @@ def test_load_genotypes_subset(self): gts = GenotypesPLINK(DATADIR / "simple.pgen") variants = {"1:10117:C:A"} - gts.read(region="1:10115-10117", samples=samples, variants=variants) + gts.read(region="1:10115-10117", samples=samples_set, variants=variants) gts.check_phase() np.testing.assert_allclose(gts.data, expected_data) assert gts.samples == tuple(samples)