Skip to content

Commit

Permalink
fix: convert samples argument in Genotypes.read into a set and fi…
Browse files Browse the repository at this point in the history
…x `tr_harmonizer` bug arising when TRTools is also installed (#225)

Co-authored-by: Arya Massarat <[email protected]>
  • Loading branch information
mlamkin7 and aryarm authored Jan 12, 2024
1 parent 8e01ed4 commit 06cc273
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 71 deletions.
18 changes: 9 additions & 9 deletions haptools/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,10 @@ def simphenotype(
)
if samples_file:
with samples_file as samps_file:
samples = samps_file.read().splitlines()
samples = set(samps_file.read().splitlines())
elif samples:
# needs to be converted from tuple to list
samples = list(samples)
# needs to be converted from tuple to set
samples = set(samples)
else:
samples = None

Expand Down Expand Up @@ -657,10 +657,10 @@ def transform(
)
if samples_file:
with samples_file as samps_file:
samples = samps_file.read().splitlines()
samples = set(samps_file.read().splitlines())
elif samples:
# needs to be converted from tuple to list
samples = list(samples)
# needs to be converted from tuple to set
samples = set(samples)
else:
samples = None

Expand Down Expand Up @@ -828,10 +828,10 @@ def ld(
)
if samples_file:
with samples_file as samps_file:
samples = samps_file.read().splitlines()
samples = set(samps_file.read().splitlines())
elif samples:
# needs to be converted from tuple to list
samples = list(samples)
# needs to be converted from tuple to set
samples = set(samples)
else:
samples = None

Expand Down
112 changes: 83 additions & 29 deletions haptools/data/genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import gc
from csv import reader
from pathlib import Path
from logging import Logger
from typing import Iterator
from logging import getLogger, Logger
from collections import namedtuple, Counter

import pgenlib
import numpy as np
import numpy.typing as npt
from pysam import VariantFile
from cyvcf2 import VCF, Variant
from pysam import VariantFile, TabixFile

try:
import trtools.utils.tr_harmonizer as trh
Expand Down Expand Up @@ -77,7 +77,7 @@ def load(
cls: Genotypes,
fname: Path | str,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
) -> Genotypes:
"""
Expand All @@ -91,7 +91,7 @@ def load(
See documentation for :py:attr:`~.Data.fname`
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Expand All @@ -112,7 +112,7 @@ def load(
def read(
self,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
Expand All @@ -132,9 +132,11 @@ def read(
For this to work, the VCF must be indexed and the seqname must match!
Defaults to loading all genotypes
samples : list[str], optional
samples : set[str], optional
A subset of the samples from which to extract genotypes
Note that they are loaded in the same order as in the file
Defaults to loading genotypes from all samples
variants : set[str], optional
A set of variant IDs for which to extract genotypes
Expand Down Expand Up @@ -307,7 +309,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None):
vcf.close()

def __iter__(
self, region: str = None, samples: list[str] = None, variants: set[str] = None
self, region: str = None, samples: set[str] = None, variants: set[str] = None
) -> Iterator[namedtuple]:
"""
Read genotypes from a VCF line by line without storing anything
Expand All @@ -316,7 +318,7 @@ def __iter__(
----------
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Expand All @@ -326,6 +328,13 @@ def __iter__(
Iterator[namedtuple]
See documentation for :py:meth:`~.Genotypes._iterate`
"""
if samples is not None:
if not isinstance(samples, set):
self.log.warning(
"Samples cannot be loaded in a particular order. "
"Use subset() to reorder the samples after loading them."
)
samples = list(samples)
vcf = VCF(str(self.fname), samples=samples, lazy=True)
self.samples = tuple(vcf.samples)
# call another function to force the lines above to be run immediately
Expand Down Expand Up @@ -797,6 +806,37 @@ def write(self):
vcf.close()


class TRRecordHarmonizerRegion(trh.TRRecordHarmonizer):
    """
    A TRRecordHarmonizer that pulls its records from a separate iterator

    Iteration is driven by ``vcfiter`` (callers in this module pass the result
    of ``vcf(region)``), so the records that get harmonized are whatever that
    iterator yields.

    Parameters
    ----------
    vcffile : VCF
        The VCF object; forwarded to the parent class' constructor
    vcfiter : object
        The iterator from which records are drawn
    vcftype : {'auto', 'gangstr', 'advntr', 'hipstr', 'eh', 'popstr'}, optional
        Type of the VCF file. Default='auto'.
        If vcftype=='auto', attempts to infer the type.

    Attributes
    ----------
    vcffile : VCF
    vcfiter : VCF
        Region to grab strs from within the VCF file.
    vcftype : enum
        Type of the VCF file. Must be included in VcfTypes
    """

    def __init__(
        self,
        vcffile: VCF,
        vcfiter: object,
        vcftype: str | trh.VcfTypes = "auto",
    ):
        # the parent's constructor is only given the file and the type; the
        # iterator is tracked by this subclass
        super().__init__(vcffile, vcftype)
        self.vcfiter = vcfiter

    def __next__(self) -> trh.TRRecord:
        """Return the next record from ``vcfiter``, harmonized into a TRRecord."""
        vcf_type = self.vcftype
        # StopIteration from the underlying iterator propagates to the caller
        raw_record = next(self.vcfiter)
        return trh.HarmonizeRecord(vcf_type, raw_record)


class GenotypesTR(Genotypes):
"""
A class for processing TR genotypes from a file
Expand All @@ -823,7 +863,12 @@ class GenotypesTR(Genotypes):
{'auto', 'gangstr', 'advntr', 'hipstr', 'eh', 'popstr'}
"""

def __init__(self, fname: Path | str, log: Logger = None, vcftype: str = "auto"):
def __init__(
self,
fname: Path | str,
log: Logger = None,
vcftype: str = "auto",
):
super().__init__(fname, log)
self.vcftype = vcftype

Expand All @@ -832,7 +877,7 @@ def load(
cls: GenotypesTR,
fname: Path | str,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
vcftype: str = "auto",
) -> Genotypes:
Expand All @@ -847,7 +892,7 @@ def load(
See documentation for :py:attr:`~.Data.fname`
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Expand Down Expand Up @@ -878,8 +923,8 @@ def _vcf_iter(self, vcf: cyvcf2.VCF, region: str = None):
tr_records: trh.TRRecord
TRRecord objects yielded from TRRecordHarmonizer
"""
for record in trh.TRRecordHarmonizer(
vcffile=vcf, vcfiter=vcf(region), region=region, vcftype=self.vcftype
for record in TRRecordHarmonizerRegion(
vcffile=vcf, vcfiter=vcf(region), vcftype=self.vcftype
):
record.ID = record.record_id
record.CHROM = record.chrom
Expand Down Expand Up @@ -938,7 +983,7 @@ class GenotypesPLINK(GenotypesVCF):
----------
data : npt.NDArray
See documentation for :py:attr:`~.GenotypesVCF.data`
samples : tuple
samples : tuple[str]
See documentation for :py:attr:`~.GenotypesVCF.samples`
variants : np.array
See documentation for :py:attr:`~.GenotypesVCF.variants`
Expand All @@ -956,11 +1001,16 @@ class GenotypesPLINK(GenotypesVCF):
>>> genotypes = GenotypesPLINK.load('tests/data/simple.pgen')
"""

def __init__(self, fname: Path | str, log: Logger = None, chunk_size: int = None):
def __init__(
self,
fname: Path | str,
log: Logger = None,
chunk_size: int = None,
):
super().__init__(fname, log)
self.chunk_size = chunk_size

def read_samples(self, samples: list[str] = None):
def read_samples(self, samples: set[str] = None):
"""
Read sample IDs from a PSAM file into a list stored in
:py:attr:`~.GenotypesPLINK.samples`
Expand All @@ -969,7 +1019,7 @@ def read_samples(self, samples: list[str] = None):
Parameters
----------
samples : list[str], optional
samples : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
Returns
Expand All @@ -980,6 +1030,10 @@ def read_samples(self, samples: list[str] = None):
if len(self.samples) != 0:
self.log.warning("Sample data has already been loaded. Overriding.")
if samples is not None and not isinstance(samples, set):
self.log.warning(
"Samples cannot be loaded in a particular order. "
"Use subset() to reorder the samples after loading them."
)
samples = set(samples)
with self.hook_compressed(self.fname.with_suffix(".psam"), mode="r") as psam:
psamples = reader(psam, delimiter="\t")
Expand Down Expand Up @@ -1210,7 +1264,7 @@ def read_variants(
def read(
self,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
Expand All @@ -1222,7 +1276,7 @@ def read(
----------
region : str, optional
See documentation for :py:attr:`~.GenotypesVCF.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
variants : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
Expand Down Expand Up @@ -1366,7 +1420,7 @@ def _iterate(
pgen.close()

def __iter__(
self, region: str = None, samples: list[str] = None, variants: set[str] = None
self, region: str = None, samples: set[str] = None, variants: set[str] = None
) -> Iterator[namedtuple]:
"""
Read genotypes from a PGEN line by line without storing anything
Expand All @@ -1375,7 +1429,7 @@ def __iter__(
----------
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Expand Down Expand Up @@ -1552,7 +1606,7 @@ class GenotypesPLINKTR(GenotypesPLINK):
----------
data : npt.NDArray
See documentation for :py:attr:`~.GenotypesPLINK.data`
samples : tuple
samples : tuple[str]
See documentation for :py:attr:`~.GenotypesPLINK.samples`
variants : np.array
See documentation for :py:attr:`~.GenotypesPLINK.variants`
Expand All @@ -1562,7 +1616,6 @@ class GenotypesPLINKTR(GenotypesPLINK):
See documentation for :py:attr:`~.GenotypesPLINK.chunk_size`
vcftype: str, optional
See documentation for :py:attr:`~.GenotypesTR.vcftype`
Examples
--------
>>> genotypes = GenotypesPLINK.load('tests/data/simple.pgen')
Expand All @@ -1583,7 +1636,7 @@ def load(
cls: GenotypesPLINKTR,
fname: Path | str,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
vcftype: str = "auto",
) -> Genotypes:
Expand All @@ -1598,7 +1651,7 @@ def load(
See documentation for :py:attr:`~.Data.fname`
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Expand Down Expand Up @@ -1632,12 +1685,12 @@ def _iter_TRRecords(self, region: str = None, variants: set[str] = None):
An iterator over each line of the PVAR file
"""
vcf = VCF(self.fname.with_suffix(".pvar"))
tr_records = trh.TRRecordHarmonizer(
tr_records = TRRecordHarmonizerRegion(
vcffile=vcf,
vcfiter=vcf(region),
region=region,
vcftype=self.vcftype,
)

# filter out TRs that we didn't want
if variants is not None:
tr_records = filter(lambda rec: rec.record_id in variants, tr_records)
Expand All @@ -1646,7 +1699,7 @@ def _iter_TRRecords(self, region: str = None, variants: set[str] = None):
def read(
self,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
Expand All @@ -1658,14 +1711,15 @@ def read(
----------
region : str, optional
See documentation for :py:attr:`~.GenotypesVCF.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
variants : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
max_variants : int, optional
See documentation for :py:attr:`~.GenotypesVCF.read`
"""
super().read(region, samples, variants, max_variants)

num_variants = len(self.variants)
# initialize a jagged array of allele lengths
max_num_alleles = max(map(len, self.variants["alleles"]))
Expand Down
Loading

0 comments on commit 06cc273

Please sign in to comment.