Skip to content

Commit

Permalink
Pushing changed tests and removing empty change
Browse files Browse the repository at this point in the history
  • Loading branch information
shubham-s-agarwal committed May 8, 2024
1 parent 563c3d4 commit decfbfb
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 17 deletions.
1 change: 0 additions & 1 deletion medcat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def save(self, save_path: str) -> None:
save_path(str): Where to save the created json file
"""
# We want to save the dict here, not the whole class

json_string = jsonpickle.encode(
{field: getattr(self, field) for field in self.fields()})

Expand Down
9 changes: 6 additions & 3 deletions medcat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
The module
"""
config = self.config
from medcat.utils.meta_cat.models import LSTM
from medcat.utils.meta_cat.models import BertForMetaAnnotation
if config.model['model_name'] == 'lstm':
model: Union[LSTM, BertForMetaAnnotation] = LSTM(embeddings, config)
from medcat.utils.meta_cat.models import LSTM
model: nn.Module = LSTM(embeddings, config)
logger.info("LSTM model used for classification")

elif config.model['model_name'] == 'bert':
from medcat.utils.meta_cat.models import BertForMetaAnnotation
model = BertForMetaAnnotation(config)

if not config.model.model_freeze_layers:
Expand Down Expand Up @@ -289,6 +289,9 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
data = full_data
if self.config.model.phase_number == 1:
data = data_undersampled
if not t_config['auto_save_model']:
logger.info("For phase 1, model state has to be saved. Saving model...")
t_config['auto_save_model'] = True

report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)

Expand Down
1 change: 0 additions & 1 deletion medcat/utils/meta_cat/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def print_report(epoch: int, running_loss: List, all_logits: List, y: Any, name:
name (str): The name of the report. Defaults to Train.
"""
if all_logits:
# print(classification_report(y, np.argmax(np.concatenate(all_logits, axis=0), axis=1)))
logger.info('Epoch: %d %s %s', epoch, "*" * 50, name)
logger.info(classification_report(y, np.argmax(np.concatenate(all_logits, axis=0), axis=1)))

Expand Down
24 changes: 12 additions & 12 deletions tests/test_meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ def test_predict_spangroup(self):

n_meta_cat.config.general.span_group = None

def test_z_bert_meta_cat(self):

class MetaCATBertTest(MetaCATTests):
@classmethod
def setUpClass(cls) -> None:
tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained('prajjwal1/bert-tiny'))
config = ConfigMetaCAT()
config.general['category_name'] = 'Status'
Expand All @@ -102,20 +105,17 @@ def test_z_bert_meta_cat(self):
config.train['batch_size'] = 64
config.model['model_name'] = 'bert'

self.meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
self.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
os.makedirs(cls.tmp_dir, exist_ok=True)

def test_two_phase(self):
self.meta_cat.config.model['phase_number'] = 1
self.test_train()
self.meta_cat.config.model['phase_number'] = 2
self.test_train()
self.test_save_load()
self.test_predict_spangroup()

def _test_two_phase():
self.meta_cat.config.model['phase_number'] = 1
self.test_train()
self.meta_cat.config.model['phase_number'] = 2
self.test_train()

_test_two_phase()
self.meta_cat.config.model['phase_number'] = 0


if __name__ == '__main__':
Expand Down

0 comments on commit decfbfb

Please sign in to comment.