Skip to content

Commit

Permalink
removing internal_only params
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjaychelliah committed Nov 29, 2023
1 parent 4c75dd6 commit 6f4b601
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
4 changes: 2 additions & 2 deletions clarifai/constants/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
TRAINABLE_MODEL_TYPES = [
'visual-classifier', 'visual-detector', 'visual-segmenter', 'visual-anomaly-heatmap',
'visual-embedder', 'clusterer', 'text-classifier', 'embedding-classifier', 'text-to-text'
'visual-classifier', 'visual-detector', 'visual-segmenter', 'visual-embedder', 'clusterer',
'text-classifier', 'embedding-classifier', 'text-to-text'
]
14 changes: 9 additions & 5 deletions clarifai/utils/model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def response_to_model_params(response: MultiModelTypeResponse,
"""Converts the response from the API to a dictionary of model params for the given model type id."""
dict_response = MessageToDict(response)
params = {}
params["dataset_id"] = ""
params["dataset_version_id"] = ""
if model_type_id != "clusterer":
params["dataset_id"] = ""
params["dataset_version_id"] = ""
if model_type_id not in ["clusterer", "text-to-text"]:
params["concepts"] = []
params["train_params"] = dict()
Expand All @@ -38,7 +39,7 @@ def response_to_model_params(response: MultiModelTypeResponse,
#removing the fields which are not required
if (_path[0] in ["'eval_info'"]) or (_path[1] in ["dataset", "data"]) or (_path[-1] in [
"dataset_id", "dataset_version_id"
]):
]) or ("internalOnly" in modeltypefield.keys()):
continue
#checking the template model type fields
if _path[-1] != "template":
Expand Down Expand Up @@ -66,6 +67,8 @@ def response_to_model_params(response: MultiModelTypeResponse,
params['train_params']["template"] = modeltypeenum['id']
#iterate through the template fields
for modeltypeenumfield in modeltypeenum['modelTypeFields']:
if "internalOnly" in modeltypeenumfield.keys():
continue
try:
params["train_params"][modeltypeenumfield['path'].split('.')[
-1]] = modeltypeenumfield['defaultValue']
Expand Down Expand Up @@ -103,8 +106,9 @@ def params_parser(params_dict: dict) -> Dict[str, Any]:
del params_dict['train_params']['base_embed_model']

train_dict["train_info"]['params'].update(params_dict["train_params"])
train_dict["train_info"]['params']['dataset_id'] = params_dict['dataset_id']
train_dict["train_info"]['params']['dataset_version_id'] = params_dict['dataset_version_id']
if 'dataset_id' in params_dict.keys():
train_dict["train_info"]['params']['dataset_id'] = params_dict['dataset_id']
train_dict["train_info"]['params']['dataset_version_id'] = params_dict['dataset_version_id']
train_dict['train_info'] = resources_pb2.TrainInfo(**train_dict['train_info'])

if 'concepts' in params_dict.keys():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_model_templates(self):
model_types = self.app.list_trainable_model_types()
templates = self.visual_classifier_model.list_training_templates()
assert self.visual_classifier_model.model_type_id == 'visual-classifier' #create model test
assert len(model_types) == 9 #list trainable model types test
assert len(model_types) == 8 #list trainable model types test
assert len(templates) >= 11 #list training templates test

def test_model_params(self):
Expand Down Expand Up @@ -69,8 +69,8 @@ def test_model_train(self, caplog):
concepts = [concept.id for concept in self.app.list_concepts()]
self.text_classifier_model.get_params(
template='HF_GPTNeo_125m_lora', save_to='tests/assets/model_params.yaml')
param_info = self.text_classifier_model.get_param_info(param='num_gpus')
assert param_info['param'] == 'num_gpus' #test get param info
param_info = self.text_classifier_model.get_param_info(param='tokenizer_config')
assert param_info['param'] == 'tokenizer_config' #test get param info
assert len(concepts) == 2 #test data upload for training
self.text_classifier_model.update_params(
dataset_id=CREATE_DATASET_ID,
Expand Down

0 comments on commit 6f4b601

Please sign in to comment.