Skip to content

Commit

Permalink
simplify the transform method of the Haplotype and Haplotypes classes
Browse files Browse the repository at this point in the history
  • Loading branch information
aryarm committed Jun 16, 2022
1 parent a0ae5e1 commit f2a3017
Showing 1 changed file with 24 additions and 40 deletions.
64 changes: 24 additions & 40 deletions haptools/data/haplotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _fmt(self):
@property
# TODO: use @cached_property in py3.8
def varIDs(self):
return {var.id for var in self.variants}
return tuple(var.id for var in self.variants)

@classmethod
def from_hap_spec(
Expand Down Expand Up @@ -334,9 +334,7 @@ def extras_head(cls) -> tuple:
"""
return tuple(extra.to_hap_spec("H") for extra in cls._extras)

def transform(
self, genotypes: GenotypesRefAlt, samples: list[str] = None
) -> npt.NDArray[bool]:
def transform(self, genotypes: GenotypesRefAlt) -> npt.NDArray[bool]:
"""
Transform a genotypes matrix via the current haplotype
Expand All @@ -350,8 +348,6 @@ def transform(
If the genotypes have not been loaded into the Genotypes object yet, this
method will call Genotypes.read(), while loading only the needed variants
samples : list[str], optional
See documentation for :py:attr:`~.Genotypes.read`
Returns
-------
Expand All @@ -360,34 +356,27 @@ def transform(
denotes the presence of the haplotype in one chromosome of a sample
"""
var_IDs = self.varIDs
# check: have the genotypes been loaded yet?
# if not, we can load just the variants we need
if genotypes.unset():
start = min(var.start for var in self.variants)
end = max(var.end for var in self.variants)
region = f"{self.chrom}:{start}-{end}"
genotypes.read(region=region, samples=samples, variants=var_IDs)
genotypes.check_biallelic(discard_also=True)
genotypes.check_phase()
# create a dict where the variants are keyed by ID
var_dict = {
var["id"]: var["ref"] for var in genotypes.variants if var["id"] in var_IDs
}
var_idxs = [
idx for idx, var in enumerate(genotypes.variants) if var["id"] in var_IDs
]
missing_IDs = var_IDs - var_dict.keys()
if len(missing_IDs):
gts = genotypes.subset(variants=var_IDs)
# check: were any of the variants absent from the genotypes?
if len(gts.variants) < len(var_IDs):
missing_IDs = set(var_IDs) - set(gts.variants["id"])
raise ValueError(
f"Variants {missing_IDs} are present in haplotype '{self.id}' but "
"absent in the provided genotypes"
)
# create a np array denoting the alleles that we want
alleles = [int(var.allele != var_dict[var.id]) for var in self.variants]
allele_arr = np.array([[[al] for al in alleles]]) # shape: (1, n, 1)
# note: the excessive use of square-brackets gives us shape (1, n, 1)
allele_arr = np.array(
[
[
[int(var.allele != gts.variants[i]["ref"])]
for i, var in enumerate(self.variants)
]
]
)
# look for the presence of each allele in each chromosomal strand
# and then just AND them together
return np.all(allele_arr == genotypes.data[:, var_idxs], axis=1)
return np.all(allele_arr == gts.data, axis=1)


class Haplotypes(Data):
Expand Down Expand Up @@ -751,9 +740,7 @@ def write(self):
def transform(
self,
genotypes: GenotypesRefAlt,
hap_gts: GenotypesRefAlt,
samples: list[str] = None,
low_memory: bool = False,
hap_gts: GenotypesRefAlt = None,
) -> GenotypesRefAlt:
"""
Transform a genotypes matrix via the current haplotype
Expand All @@ -765,35 +752,32 @@ def transform(
----------
genotypes : GenotypesRefAlt
The genotypes to transform using the current haplotype
If the genotypes have not been loaded into the Genotypes object yet, this
method will call Genotypes.read(), while loading only the needed variants
hap_gts : GenotypesRefAlt, optional
An empty GenotypesRefAlt object into which the haplotype genotypes should
be stored
If not provided, a new GenotypesRefAlt object is created and returned
samples : list[str], optional
See documentation for :py:attr:`~.Genotypes.read`
low_memory : bool, optional
If True, each haplotype's genotypes will be loaded one at a time.
Returns
-------
GenotypesRefAlt
A Genotypes object composed of haplotypes instead of regular variants.
"""
# Initialize GenotypesRefAlt return value
if hap_gts is None:
hap_gts = GenotypesRefAlt(fname=None, log=self.log)
hap_gts.samples = genotypes.samples
hap_gts.variants = np.array(
[(hap.id, hap.chrom, hap.start, 0, "A", "T") for hap in self.data.values()],
dtype=hap_gts.variants.dtype,
)
# Obtain and merge the haplotype genotypes
self.log.info(
f"Transforming a set of genotypes from {len(genotypes.variants)} total "
f"variants with a list of {len(self.data)} haplotypes"
)
hap_gts.data = np.concatenate(
tuple(
hap.transform(genotypes, samples)[:, np.newaxis]
for hap in self.data.values()
hap.transform(genotypes)[:, np.newaxis] for hap in self.data.values()
),
axis=1,
).astype(np.uint8)
).astype(genotypes.data.dtype)
return hap_gts

0 comments on commit f2a3017

Please sign in to comment.