Commit 231cccb
Pushing change for lazy logging
shubham-s-agarwal committed Sep 4, 2024
1 parent c83780e commit 231cccb
Showing 4 changed files with 7 additions and 8 deletions.
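For context, "lazy logging" means handing format arguments to the logging call itself instead of building the message eagerly at the call site; the framework then interpolates only if the record is actually emitted. A minimal sketch of the two styles (logger, n_config and n_data are illustrative names, not from this commit); the changed lines below follow the second form:

import logging

logger = logging.getLogger(__name__)
n_config, n_data = 3, 5  # hypothetical values

# Eager: the message string is built even if WARNING were filtered out.
logger.warning("classes: %d vs %d" % (n_config, n_data))

# Lazy: interpolation is deferred to the logging framework.
logger.warning("classes: %d vs %d", n_config, n_data)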
3 changes: 1 addition & 2 deletions medcat/meta_cat.py
@@ -269,8 +269,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
     # Make sure the config number of classes is the same as the one found in the data
     if len(category_value2id) != self.config.model['nclasses']:
         logger.warning(
-            "The number of classes set in the config is not the same as the one found in the data: {} vs {}".format(
-                self.config.model['nclasses'], len(category_value2id)))
+            "The number of classes set in the config is not the same as the one found in the data: %d vs %d", self.config.model['nclasses'], len(category_value2id))
         logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
         self.config.model['nclasses'] = len(category_value2id)

6 changes: 3 additions & 3 deletions medcat/utils/meta_cat/data_utils.py
@@ -194,7 +194,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
     for k in keys_ls:
         category_value2id_[k] = len(category_value2id_)

-    logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping:", category_value2id_)
+    logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping: %s", category_value2id_)
     category_value2id = category_value2id_

     for c in category_values:
@@ -211,7 +211,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
     if data[i][2] in category_value2id.values():
         label_data_[data[i][2]] = label_data_[data[i][2]] + 1

-    logger.info(f"Original label_data: {label_data_}")
+    logger.info("Original label_data: %s", label_data_)
     # Undersampling data
     if category_undersample is None or category_undersample == '':
         min_label = min(label_data_.values())
@@ -234,7 +234,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
     for i in range(len(data_undersampled)):
         if data_undersampled[i][2] in category_value2id.values():
             label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
-    logger.info(f"Updated label_data: {label_data}")
+    logger.info("Updated label_data: %s", label_data)

     return data, data_undersampled, category_value2id
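A note on the first hunk in this file: the removed warning passed category_value2id_ without a matching placeholder. Because the logging module treats a single non-empty dict argument as a mapping for %(key)s-style formatting, the dict was silently dropped from the message. A small illustration, assuming a hypothetical mapping value:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
mapping = {"Status": 0, "Other": 1}  # hypothetical label encoding

# No placeholder: the lone dict is absorbed as a formatting mapping
# and never appears in the output.
logger.warning("Final label encoding mapping:", mapping)

# With %s and a lazy argument, the mapping is rendered correctly.
logger.warning("Final label encoding mapping: %s", mapping)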

4 changes: 2 additions & 2 deletions medcat/utils/meta_cat/ml_utils.py
@@ -201,7 +201,7 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
     y_ = [x[2] for x in train_data]
     class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_), y=y_)
     config.train['class_weights'] = class_weights.tolist()
-    logger.info(f"Class weights computed: {class_weights}")
+    logger.info("Class weights computed: %s", class_weights)

     class_weights = torch.FloatTensor(class_weights).to(device)
     if config.train['loss_funct'] == 'cross_entropy':
@@ -259,7 +259,7 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):

     # Total number of training steps
     total_steps = int((len(data_) / batch_size_) * epochs)
-    logger.info('Total steps for optimizer: {}'.format(total_steps))
+    logger.info('Total steps for optimizer: %d', total_steps)

     # Set up the learning rate scheduler
     scheduler_ = get_linear_schedule_with_warmup(optimizer_,
2 changes: 1 addition & 1 deletion medcat/utils/meta_cat/models.py
@@ -91,7 +91,7 @@ def __init__(self, config):
         super(BertForMetaAnnotation, self).__init__()
         _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers'])
         if config.model['input_size'] != _bertconfig.hidden_size:
-            logger.warning(f"\nInput size for {config.model.model_variant} model should be {_bertconfig.hidden_size}, provided input size is {config.model['input_size']} Input size changed to {_bertconfig.hidden_size}")
+            logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d", config.model.model_variant, _bertconfig.hidden_size, config.model['input_size'], _bertconfig.hidden_size)

         bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
         self.config = config
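One caveat about the pattern applied throughout this commit: laziness covers the string formatting, not the evaluation of the argument expressions themselves. A quick self-contained check, not part of the commit, that __str__ only runs for records that are actually emitted:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Expensive:
    calls = 0
    def __str__(self):
        Expensive.calls += 1
        return "expensive repr"

obj = Expensive()
logger.debug("value: %s", obj)  # DEBUG is filtered: __str__ never runs
logger.info("value: %s", obj)   # INFO is emitted: formatted exactly once
print(Expensive.calls)          # -> 1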
