diff --git a/haptools/data/genotypes.py b/haptools/data/genotypes.py index c9fff622..890f65da 100644 --- a/haptools/data/genotypes.py +++ b/haptools/data/genotypes.py @@ -92,14 +92,21 @@ def read(self, region: str = None, samples: list[str] = None): # load all info into memory vcf = VCF(str(self.fname), samples=samples) self.samples = tuple(vcf.samples) - variants = list(vcf(region)) + self.variants = [] + self.data = [] + for variant in vcf(region): + # save meta information about each variant + self.variants.append((variant.ID, variant.CHROM, variant.POS, variant.aaf)) + # extract the genotypes to a matrix of size n x p x 3 + # the last dimension has three items: + # 1) presence of REF in strand one + # 2) presence of REF in strand two + # 3) whether the genotype is phased + self.data.append(variant.genotypes) vcf.close() - # save meta information about each variant + # convert to np array for speedy operations later on self.variants = np.array( - [ - (variant.ID, variant.CHROM, variant.POS, variant.aaf) - for variant in variants - ], + self.variants, dtype=[ ("id", "U50"), ("chrom", "U10"), @@ -107,14 +114,7 @@ def read(self, region: str = None, samples: list[str] = None): ("aaf", np.float64), ], ) - # extract the genotypes to a np matrix of size n x p x 3 - # the last dimension has three items: - # 1) presence of REF in strand one - # 2) presence of REF in strand two - # 3) whether the genotype is phased - self.data = np.array( - [variant.genotypes for variant in variants], dtype=np.uint8 - ) + self.data = np.array(self.data, dtype=np.uint8) if self.data.shape == (0, 0, 0): raise ValueError( "Failed to load genotypes. If you specified a region, check that the"