Skip to content

Commit

Permalink
feat: Added ability to read tandem repeats with GenotypesTR (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlamkin7 authored Apr 7, 2023
1 parent c5594d9 commit 6257264
Show file tree
Hide file tree
Showing 6 changed files with 2,065 additions and 1 deletion.
15 changes: 15 additions & 0 deletions docs/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,21 @@ All of the other methods in the :class:`Genotypes` class are inherited, but the
.. _api-data-genotypesplink:

GenotypesTR
++++++++++++
The :class:`GenotypesTR` class *extends* :class:`Genotypes` class. The :class:`GenotypesTR` class follows the same structure of :class:`GenotypesVCF`, but can now load repeat count of tandem repeats as the alleles.

All of the other methods in the :class:`Genotypes` class are inherited, but the :class:`GenotypesTR` class' ``load()`` function is unique to loading tandem repeat variants.

.. code-block:: python
genotypes = data.GenotypesTR.load('tests/data/simple_tr.vcf')
# make the first sample have 4 and 7 repeats for the alleles of the fourth variant
genotypes.data[0, 3] = (4, 7)
genotypes.write()
.. _api-data-genotypestr:

GenotypesPLINK
++++++++++++++
The :class:`GenotypesPLINK` class offers experimental support for reading and writing PLINK2 PGEN, PVAR, and PSAM files. We are able to read genotypes from PLINK2 PGEN files in a fraction of the time of VCFs. Reading from VCFs is :math:`O(n*p)`, while reading from PGEN files is approximately :math:`O(1)`.
Expand Down
2 changes: 1 addition & 1 deletion haptools/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .covariates import Covariates
from .breakpoints import Breakpoints, HapBlock
from .haplotypes import Extra, Variant, Haplotype, Haplotypes
from .genotypes import Genotypes, GenotypesVCF, GenotypesPLINK
from .genotypes import Genotypes, GenotypesVCF, GenotypesTR, GenotypesPLINK
271 changes: 271 additions & 0 deletions haptools/data/genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy.typing as npt
from cyvcf2 import VCF, Variant
from pysam import VariantFile, TabixFile
from . import tr_harmonizer as trh

from .data import Data

Expand Down Expand Up @@ -695,6 +696,276 @@ def write(self):
vcf.close()


class GenotypesTR(Genotypes):
"""
A class for processing TR genotypes from a file
Unlike the base Genotypes class, this class genotypes will be repeat number
in the variants array
Attributes
----------
data : np.array
See documentation for :py:attr:`~.Genotypes.data`
fname : Path | str
See documentation for :py:attr:`~.Genotypes.fname`
samples : tuple[str]
See documentation for :py:attr:`~.Genotypes.samples`
variants : np.array
Variant-level meta information:
1. ID
2. CHROM
3. POS
4. [REF, ALT1, ALT2, ...]
log: Logger
See documentation for :py:attr:`~.Genotypes.log`
"""

def __init__(self, fname: Path | str, log: Logger = None):
super().__init__(fname, log)
dtype = {k: v[0] for k, v in self.variants.dtype.fields.items()}
self.variants = np.array([], dtype=list(dtype.items()) + [("alleles", object)])

def _variant_arr(self, record: Variant):
"""
See documentation for :py:meth:`~.Genotypes._variant_arr`
"""
return np.array(
(
record.record_id,
record.chrom,
record.pos,
(record.ref_allele, *record.alt_alleles),
),
dtype=self.variants.dtype,
)

@classmethod
def load(
cls: GenotypesTR,
fname: Path | str,
region: str = None,
samples: list[str] = None,
variants: set[str] = None,
) -> Genotypes:
"""
Load STR genotypes from a VCF file
Read the file contents, check the genotype phase, and create the MAC matrix
Parameters
----------
fname
See documentation for :py:attr:`~.Data.fname`
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Returns
-------
Genotypes
A Genotypes object with the data loaded into its properties
"""
genotypes = cls(fname)
genotypes.read(region, samples, variants)
genotypes.check_phase()
return genotypes

def read(
self,
region: str = None,
samples: list[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
"""
Read genotypes from a VCF into a numpy matrix stored in :py:attr:`~.Genotypes.data`
Raises
------
ValueError
If the genotypes array is empty
Parameters
----------
region : str, optional
The region from which to extract genotypes; ex: 'chr1:1234-34566' or 'chr7'
For this to work, the VCF must be indexed and the seqname must match!
Defaults to loading all genotypes
samples : list[str], optional
A subset of the samples from which to extract genotypes
Defaults to loading genotypes from all samples
variants : set[str], optional
A set of variant IDs for which to extract genotypes
All other variants will be ignored. This may be useful if you're running
out of memory.
max_variants : int, optional
The maximum mumber of variants to load from the file. Setting this value
helps preallocate the arrays, making the process faster and less memory
intensive. You should use this option if your processes are frequently
"Killed" from memory overuse.
If you don't know how many variants there are, set this to a large number
greater than what you would except. The np array will be resized
appropriately. You can also use the bcftools "counts" plugin to obtain the
number of expected sites within a region.
Note that this value is ignored if the variants argument is provided.
"""
super().read()
records = self.__iter__(region=region, samples=samples, variants=variants)
if variants is not None:
max_variants = len(variants)
# check whether we can preallocate memory instead of making copies
if max_variants is None:
self.log.warning(
"The max_variants parameter was not specified. We have no choice but to"
" append to an ever-growing array, which can lead to memory overuse!"
)
variants_arr = []
data_arr = []
for rec in records:
variants_arr.append(rec.variants)
data_arr.append(rec.data)
self.log.info(f"Copying {len(variants_arr)} variants into np arrays.")
# convert to np array for speedy operations later on
self.variants = np.array(variants_arr, dtype=self.variants.dtype)
self.data = np.array(data_arr, dtype=np.uint8)
else:
# preallocate arrays! this will save us lots of memory and speed b/c
# appends can sometimes make copies
self.variants = np.empty((max_variants,), dtype=self.variants.dtype)
# in order to check_phase() later, we must store the phase info, as well
self.data = np.empty(
(max_variants, len(self.samples), (2 + (not self._prephased))),
dtype=np.uint8,
)
num_seen = 0
for rec in records:
if num_seen >= max_variants:
break
self.variants[num_seen] = rec.variants
self.data[num_seen] = rec.data
num_seen += 1
if max_variants > num_seen:
self.log.info(
f"Removing {max_variants-num_seen} unneeded variant records that "
"were preallocated b/c max_variants was specified."
)
self.variants = self.variants[:num_seen]
self.data = self.data[:num_seen]
if 0 in self.data.shape:
self.log.warning(
"Failed to load genotypes. If you specified a region, check that the"
" contig name matches! For example, double-check the 'chr' prefix."
)
# transpose the GT matrix so that samples are rows and variants are columns
self.log.info(f"Transposing genotype matrix of size {self.data.shape}.")
self.data = self.data.transpose((1, 0, 2))

def __iter__(
self, region: str = None, samples: list[str] = None, variants: set[str] = None
) -> Iterator[namedtuple]:
"""
Read genotypes from a VCF line by line without storing anything
Parameters
----------
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Returns
-------
Iterator[namedtuple]
See documentation for :py:meth:`~.Genotypes._iterate`
"""
vcf = VCF(str(self.fname), samples=samples, lazy=True)
self.samples = tuple(vcf.samples)
# call another function to force the lines above to be run immediately
# see https://stackoverflow.com/a/36726497
return self._iterate(vcf, region, variants)

def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None):
"""
A generator over the lines of a VCF
This is a helper function for :py:meth:`~.Genotypes.__iter__`
Parameters
----------
vcf: VCF
The cyvcf2.VCF object from which to fetch variant records
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
Yields
------
Iterator[namedtuple]
An iterator over each line in the file, where each line is encoded as a
namedtuple containing each of the class properties
"""
self.log.info(f"Loading genotypes from {len(self.samples)} samples")
Record = namedtuple("Record", "data variants")
# iterable used to collect records
vcfiter = vcf(region)
tr_records = trh.TRRecordHarmonizer(vcffile=vcf, vcfiter=vcfiter, region=region)
num_seen = 0
# iterate over each line in the VCF
# note, this can take a lot of time if there are many samples
for variant in tr_records:
if variants is not None and variant.record_id not in variants:
if num_seen >= len(variants):
# exit early if we've already found all the variants
break
continue
# save meta information about each variant
variant_arr = self._variant_arr(variant)
# extract the genotypes to a matrix of size n x 3
# the last dimension has three items:
# 1) presence of REF in strand one
# 2) presence of REF in strand two
# 3) whether the genotype is phased (if self._prephased is False)
# Check
try:
data = np.array(variant.vcfrecord.genotypes, dtype=np.uint8)

except ValueError:
self.log.warning(
"The current variant in the VCF contains genotypes that do not have"
" 2 alleles. "
+ "This will result in a significant slowdown due to iterating"
" over "
+ "all GTs and fixing the shape issue. Please update the VCF by "
+ "adding another allele to each GT with only one allele to fix the"
" slowdown."
)
data = []
for gt_sample in variant.vcfrecord.genotypes:
if len(gt_sample) == 2:
new_gt_sample = [gt_sample[0], -1, gt_sample[1]]
else:
new_gt_sample = gt_sample
data.append(new_gt_sample)
data = np.array(data, dtype=np.uint8)

data = data[:, : (2 + (not self._prephased))]
yield Record(data, variant_arr)
num_seen += 1
vcf.close()


class GenotypesPLINK(GenotypesVCF):
"""
A class for processing genotypes from a PLINK ``.pgen`` file
Expand Down
Loading

0 comments on commit 6257264

Please sign in to comment.