diff --git a/peptdeep/hla/hla_class1.py b/peptdeep/hla/hla_class1.py index 780b95b8..467a8321 100644 --- a/peptdeep/hla/hla_class1.py +++ b/peptdeep/hla/hla_class1.py @@ -20,16 +20,6 @@ nonspecific_digest_cat_proteins, ) - -_model_zip_name = global_settings['local_hla_model_zip_name'] -_model_url = global_settings['hla_model_url'] -_model_zip = os.path.join( - pretrain_dir, _model_zip_name -) - -if not os.path.exists(_model_zip): - download_models(url=_model_url, target_path=_model_zip) - class HLA_Class_I_LSTM(torch.nn.Module): """ HLA-I-binding peptide prediction model using LSTM. @@ -132,6 +122,13 @@ class HLA1_Binding_Classifier(ModelInterface): """ Class to predict HLA-binding probabilities of peptides. """ + + _model_zip_name = global_settings['local_hla_model_zip_name'] + _model_url = global_settings['hla_model_url'] + _model_zip = os.path.join( + pretrain_dir, _model_zip_name + ) + def __init__(self, dropout:float=0.1, model_class:type=HLA_Class_I_LSTM, # model defined above @@ -398,11 +395,16 @@ def predict_from_proteins(self, ) return peptide_df + def _download_pretrained_hla_model(self): + download_models(url=self._model_url, target_path=self._model_zip) + def load_pretrained_hla_model(self): """ Load pretrained `HLA1_IEDB.pt` model. """ + if not os.path.exists(self._model_zip): + self._download_pretrained_hla_model() self.load( - model_file=_model_zip, + model_file=self._model_zip, model_path_in_zip="HLA1_IEDB.pt" ) diff --git a/peptdeep/hla/hla_utils.py b/peptdeep/hla/hla_utils.py index 208f71f6..8a876719 100644 --- a/peptdeep/hla/hla_utils.py +++ b/peptdeep/hla/hla_utils.py @@ -11,14 +11,14 @@ from alphabase.protein.fasta import load_all_proteins def load_prot_df( - protein_data:Union[str,list,dict], + protein_data:Union[str,list,tuple,set,dict], )->pd.DataFrame: """ Load protein dataframe from input protein_data. Parameters ---------- - protein_data : Union[str,list,dict] + protein_data : Union[str,list,tuple,set,dict] str: fasta file list (tuple, or set): a list of fasta files dict: protein dict @@ -27,17 +27,23 @@ def load_prot_df( ------- pd.DataFrame protein dataframe + + Raises + ------ + TypeError + protein_data type is not one of str, list, tuple, set, or dict. """ if isinstance(protein_data, str): protein_dict = load_all_proteins([protein_data]) elif isinstance(protein_data, (list,tuple,set)): protein_dict = load_all_proteins(protein_data) - elif isinstance(protein_data, str): - protein_dict = load_all_proteins([protein_data]) elif isinstance(protein_data, dict): protein_dict = protein_data else: - return pd.DataFrame() + raise TypeError( + "`protein_data` must be str, list, tuple, set or dict, " + f"`{type(protein_data)}` is given." + ) prot_df = pd.DataFrame().from_dict(protein_dict, orient='index') prot_df['nAA'] = prot_df.sequence.str.len() return prot_df diff --git a/peptdeep/pretrained_models.py b/peptdeep/pretrained_models.py index 03e36677..4dc1380d 100644 --- a/peptdeep/pretrained_models.py +++ b/peptdeep/pretrained_models.py @@ -83,11 +83,11 @@ def download_models( ---------- url : str, optional Remote or local path. - Defaults to `peptdeep.pretrained_models.model_url` + Defaults to :data:`peptdeep.pretrained_models.model_url` target_path : str, optional Target file path after download. - Defaults to `peptdeep.pretrained_models.model_zip` + Defaults to :data:`peptdeep.pretrained_models.model_zip` overwrite : bool, optional overwirte old model files. diff --git a/requirements/requirements_hla.txt b/requirements/requirements_hla.txt index 5658a041..f6012f1f 100644 --- a/requirements/requirements_hla.txt +++ b/requirements/requirements_hla.txt @@ -1 +1 @@ -pydivsufsort +pydivsufsort # used by alphabase.protein.lcp_digest