From 6f4b60102ea113f5cc69ade88a8b1f1fd3f0d70b Mon Sep 17 00:00:00 2001
From: sanjaychelliah
Date: Wed, 29 Nov 2023 12:58:39 +0530
Subject: [PATCH] removing internal_only params

---
 clarifai/constants/model.py   |  4 ++--
 clarifai/utils/model_train.py | 14 +++++++++-----
 tests/test_model_train.py     |  6 +++---
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/clarifai/constants/model.py b/clarifai/constants/model.py
index e7aac625..5419f6bd 100644
--- a/clarifai/constants/model.py
+++ b/clarifai/constants/model.py
@@ -1,4 +1,4 @@
 TRAINABLE_MODEL_TYPES = [
-    'visual-classifier', 'visual-detector', 'visual-segmenter', 'visual-anomaly-heatmap',
-    'visual-embedder', 'clusterer', 'text-classifier', 'embedding-classifier', 'text-to-text'
+    'visual-classifier', 'visual-detector', 'visual-segmenter', 'visual-embedder', 'clusterer',
+    'text-classifier', 'embedding-classifier', 'text-to-text'
 ]
diff --git a/clarifai/utils/model_train.py b/clarifai/utils/model_train.py
index db71dad0..0a637e04 100644
--- a/clarifai/utils/model_train.py
+++ b/clarifai/utils/model_train.py
@@ -24,8 +24,9 @@ def response_to_model_params(response: MultiModelTypeResponse,
   """Converts the response from the API to a dictionary of model params for the given model type id."""
   dict_response = MessageToDict(response)
   params = {}
-  params["dataset_id"] = ""
-  params["dataset_version_id"] = ""
+  if model_type_id != "clusterer":
+    params["dataset_id"] = ""
+    params["dataset_version_id"] = ""
   if model_type_id not in ["clusterer", "text-to-text"]:
     params["concepts"] = []
   params["train_params"] = dict()
@@ -38,7 +39,7 @@ def response_to_model_params(response: MultiModelTypeResponse,
       #removing the fields which are not required
       if (_path[0] in ["'eval_info'"]) or (_path[1] in ["dataset", "data"]) or (_path[-1] in [
           "dataset_id", "dataset_version_id"
-      ]):
+      ]) or ("internalOnly" in modeltypefield.keys()):
         continue
       #checking the template model type fields
       if _path[-1] != "template":
@@ -66,6 +67,8 @@ def response_to_model_params(response: MultiModelTypeResponse,
           params['train_params']["template"] = modeltypeenum['id']
           #iterate through the template fields
           for modeltypeenumfield in modeltypeenum['modelTypeFields']:
+            if "internalOnly" in modeltypeenumfield.keys():
+              continue
             try:
               params["train_params"][modeltypeenumfield['path'].split('.')[
                   -1]] = modeltypeenumfield['defaultValue']
@@ -103,8 +106,9 @@ def params_parser(params_dict: dict) -> Dict[str, Any]:
       del params_dict['train_params']['base_embed_model']

     train_dict["train_info"]['params'].update(params_dict["train_params"])
-  train_dict["train_info"]['params']['dataset_id'] = params_dict['dataset_id']
-  train_dict["train_info"]['params']['dataset_version_id'] = params_dict['dataset_version_id']
+  if 'dataset_id' in params_dict.keys():
+    train_dict["train_info"]['params']['dataset_id'] = params_dict['dataset_id']
+    train_dict["train_info"]['params']['dataset_version_id'] = params_dict['dataset_version_id']
   train_dict['train_info'] = resources_pb2.TrainInfo(**train_dict['train_info'])

   if 'concepts' in params_dict.keys():
diff --git a/tests/test_model_train.py b/tests/test_model_train.py
index 561417d9..3f7c6e85 100644
--- a/tests/test_model_train.py
+++ b/tests/test_model_train.py
@@ -38,7 +38,7 @@ def test_model_templates(self):
     model_types = self.app.list_trainable_model_types()
     templates = self.visual_classifier_model.list_training_templates()
     assert self.visual_classifier_model.model_type_id == 'visual-classifier'  #create model test
-    assert len(model_types) == 9  #list trainable model types test
+    assert len(model_types) == 8  #list trainable model types test
     assert len(templates) >= 11  #list training templates test

   def test_model_params(self):
@@ -69,8 +69,8 @@ def test_model_train(self, caplog):
     concepts = [concept.id for concept in self.app.list_concepts()]
     self.text_classifier_model.get_params(
         template='HF_GPTNeo_125m_lora', save_to='tests/assets/model_params.yaml')
-    param_info = self.text_classifier_model.get_param_info(param='num_gpus')
-    assert param_info['param'] == 'num_gpus'  #test get param info
+    param_info = self.text_classifier_model.get_param_info(param='tokenizer_config')
+    assert param_info['param'] == 'tokenizer_config'  #test get param info
     assert len(concepts) == 2  #test data upload for training
     self.text_classifier_model.update_params(
         dataset_id=CREATE_DATASET_ID,