Skip to content

Commit

Permalink
fix search.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dohyun-s committed Feb 20, 2023
1 parent b975e5f commit 7ad01b8
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions colabfold/mmseqs/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,55 @@
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Union
import os
import os, pandas

from colabfold.batch import get_queries, msa_to_str, get_queries_pairwise
from colabfold.inputs import get_queries, msa_to_str, parse_fasta
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

def get_queries_pairwise(
input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]:
"""Reads a directory of fasta files, a single fasta file or a csv file and returns a tuple
of job name, sequence and the optional a3m lines"""
input_path = Path(input_path)
if not input_path.exists():
raise OSError(f"{input_path} could not be found")
if input_path.is_file():
if input_path.suffix == ".csv" or input_path.suffix == ".tsv":
sep = "\t" if input_path.suffix == ".tsv" else ","
df = pandas.read_csv(input_path, sep=sep)
assert "id" in df.columns and "sequence" in df.columns
queries = [
(str(df["id"][0])+'&'+str(seq_id), [df["sequence"][0].upper(),sequence.upper()], None)
for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)) if i!=0
]
for i in range(len(queries)):
if len(queries[i][1]) == 1:
queries[i] = (queries[i][0], queries[i][1][0], None)
elif input_path.suffix == ".a3m":
raise NotImplementedError()
elif input_path.suffix in [".fasta", ".faa", ".fa"]:
(sequences, headers) = parse_fasta(input_path.read_text())
queries = []
for i, (sequence, header) in enumerate(zip(sequences, headers)):
sequence = sequence.upper()
if sequence.count(":") == 0:
# Single sequence
if i==0:
continue
queries.append((headers[0]+'&'+header, [sequences[0],sequence], None))
else:
# Complex mode
queries.append((header, sequence.upper().split(":"), None))
else:
raise ValueError(f"Unknown file format {input_path.suffix}")
else:
raise NotImplementedError()

is_complex = True
return queries, is_complex

def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]):
params_log = " ".join(str(i) for i in params)
Expand Down Expand Up @@ -61,8 +104,9 @@ def mmseqs_search_monomer(
used_dbs.append(template_db)
if use_env:
used_dbs.append(metagenomic_db)

for db in used_dbs:
if str(db) == '.':
continue
if not dbbase.joinpath(f"{db}.dbtype").is_file():
raise FileNotFoundError(f"Database {db} does not exist")
if (
Expand Down

0 comments on commit 7ad01b8

Please sign in to comment.