Skip to content
This repository has been archived by the owner on Dec 21, 2022. It is now read-only.

Commit

Permalink
Simplify index handling for GenomicDataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
tshauck committed Jun 22, 2019
1 parent e233162 commit 3213a9e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 54 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- `Alphabet.encoded_end`
- `Alphabet.encoded_padding`
- Remove uniprot dataset creation.
- Simplify index handling for GenomicDataset.

## 0.6.1 (2019-06-10)

Expand Down
54 changes: 16 additions & 38 deletions gcgc/ml/pytorch_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
"""Objects and methods for dealing with PyTorch data."""

from pathlib import Path
from typing import Dict, Sequence
from typing import Sequence

from Bio import File, SeqIO
from Bio import File
from Bio import SeqIO
import torch
import torch.utils.data

Expand All @@ -15,30 +16,10 @@
from gcgc.parser.gcgc_record import GCGCRecord


class _SequenceIndexer(object):
"""A helper object that is used to index multiple files at ones."""

def __init__(self):
"""Initialize the _SequenceIndexer object."""
self._record_index = {}
self._counter = -1

def __call__(self, sid):
"""Return the record index if it's known, otherwise generate a new one."""
try:
return self._record_index[sid]
except KeyError:
self._counter = self._counter + 1
self._record_index[sid] = self._counter
return self._record_index[sid]


class GenomicDataset(torch.utils.data.Dataset):
"""GenomicDataset can be used to load sequence information into a format aminable to PyTorch."""

def __init__(
self, file_index: Dict[Path, File._IndexedSeqFileDict], parser: TorchSequenceParser
):
def __init__(self, file_index: File._SQLiteManySeqFilesDict, parser: TorchSequenceParser):
"""Initialize the GenomicDataset object."""

self._file_index = file_index
Expand All @@ -63,32 +44,29 @@ def from_paths(
cls,
path_sequence: Sequence[Path],
parser: TorchSequenceParser,
file_format="fasta",
file_format: str = "fasta",
alphabet: EncodingAlphabet = ExtendedIUPACDNAEncoding(),
index_db: str = ":memory:",
**kwargs,
) -> "GenomicDataset":
"""Initialize the GenomicDataset from a pathlib.Path sequence."""

file_index = {}
si = _SequenceIndexer()

for f in sorted(path_sequence):
file_index[f] = SeqIO.index(str(f), file_format, key_function=si, alphabet=alphabet)

file_index = SeqIO.index_db(
index_db, [str(p) for p in path_sequence], file_format, alphabet=alphabet, **kwargs
)
return cls(file_index, parser)

def __len__(self) -> int:
"""Return the length of the dataset."""

return sum(len(v) for v in self._file_index.values())
return len(self._file_index)

def __getitem__(self, i: int):
"""Get the record from the index."""

for k, v in self._file_index.items():
try:
r = GCGCRecord(path=k, seq_record=v[i])
return self._parser.parse_record(r)
except KeyError:
pass
qry = "SELECT key, file_number FROM offset_data LIMIT 1 OFFSET ?;"
key, file_number = self._file_index._con.execute(qry, (i,)).fetchone()
file_name = Path(self._file_index._filenames[file_number])

        raise RuntimeError(f"Exhausted file index while looking for {i}.")
r = GCGCRecord(path=file_name, seq_record=self._file_index[key])
return self._parser.parse_record(r)
37 changes: 21 additions & 16 deletions gcgc/tests/third_party/pytorch_utils/test_pytorch_utils.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
# (c) Copyright 2018 Trent Hauck
# All Rights Reserved

import unittest
import tempfile
import pathlib

from gcgc.alphabet.iupac import IUPACProteinEncoding
from gcgc.ml.pytorch_utils.data import GenomicDataset
from gcgc.parser import SequenceParser
from gcgc.tests.fixtures import ECOLI_PATH, P53_HUMAN
from gcgc.tests.fixtures import ECOLI_PATH
from gcgc.tests.fixtures import P53_HUMAN

SP = SequenceParser()

class TestPyTorchUtils(unittest.TestCase):
def setUp(self) -> None:
self.sp = SequenceParser()

def test_load_dataset(self):
def yielder():
yield P53_HUMAN
def test_load_dataset():
def yielder():
yield P53_HUMAN

test_dataset = GenomicDataset.from_paths(yielder(), self.sp, "fasta")
self.assertEqual(len(test_dataset), 1)
test_dataset = GenomicDataset.from_paths(yielder(), SP, "fasta")
assert len(test_dataset) == 1

def test_index_multiple_files(self):

glob = ECOLI_PATH.glob("*.fasta")
def test_index_multiple_files():

pe = IUPACProteinEncoding()
test_dataset = GenomicDataset.from_paths(glob, self.sp, "fasta", pe)
self.assertEqual(len(test_dataset), 25)
glob = ECOLI_PATH.glob("*.fasta")

pe = IUPACProteinEncoding()

with tempfile.TemporaryDirectory() as tmpdir:
db_path = pathlib.Path(tmpdir) / 'test.db'
test_dataset = GenomicDataset.from_paths(glob, SP, "fasta", pe, str(db_path))

assert len(test_dataset) == 25

test_sequences = {
0: "sp|C5A0C3|ARGE_ECOBW",
Expand All @@ -36,4 +41,4 @@ def test_index_multiple_files(self):

for idx, expected_id in test_sequences.items():
actual_record = test_dataset[idx]
self.assertEqual(actual_record["id"], expected_id)
assert actual_record["id"] == expected_id

0 comments on commit 3213a9e

Please sign in to comment.