Skip to content

Commit

Permalink
xlm-roberta-large
Browse files Browse the repository at this point in the history
  • Loading branch information
Ismat-Samadov committed Nov 5, 2024
1 parent b4d80f1 commit 206e574
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 74 deletions.
208 changes: 138 additions & 70 deletions models/XLM-RoBERTa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,31 @@
},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"import numpy as np\n",
"import torch\n",
"import ast\n",
"from datasets import load_dataset\n",
"# Standard library imports\n",
"import os # Provides functions for interacting with the operating system\n",
"import warnings # Used to handle or suppress warnings\n",
"import numpy as np # Essential for numerical operations and array manipulation\n",
"import torch # PyTorch library for tensor computations and model handling\n",
"import ast # Used for safe evaluation of strings to Python objects (e.g., parsing tokens)\n",
"\n",
"# Hugging Face and Transformers imports\n",
"from datasets import load_dataset # Loads datasets for model training and evaluation\n",
"from transformers import (\n",
" AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer,\n",
" AutoModelForTokenClassification, get_linear_schedule_with_warmup, EarlyStoppingCallback\n",
" AutoTokenizer, # Initializes a tokenizer from a pre-trained model\n",
" DataCollatorForTokenClassification, # Handles padding and formatting of token classification data\n",
" TrainingArguments, # Defines training parameters like batch size and learning rate\n",
" Trainer, # High-level API for managing training and evaluation\n",
" AutoModelForTokenClassification, # Loads a pre-trained model for token classification tasks\n",
" get_linear_schedule_with_warmup, # Learning rate scheduler for gradual warm-up and linear decay\n",
" EarlyStoppingCallback # Callback to stop training if validation performance plateaus\n",
")\n",
"from huggingface_hub import login\n",
"from seqeval.metrics import precision_score, recall_score, f1_score, classification_report\n"
"\n",
"# Hugging Face Hub\n",
"from huggingface_hub import login # Allows logging in to Hugging Face Hub to upload models\n",
"\n",
"# seqeval metrics for NER evaluation\n",
"from seqeval.metrics import precision_score, recall_score, f1_score, classification_report\n",
"# Provides precision, recall, F1-score, and classification report for evaluating NER model performance\n"
]
},
{
Expand Down Expand Up @@ -160,7 +173,7 @@
],
"source": [
"# Log in to Hugging Face Hub\n",
"login(token=\"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\")\n"
"login(token=\"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\")\n"
]
},
{
Expand All @@ -171,8 +184,10 @@
},
"outputs": [],
"source": [
"# Disable unwanted logs and warnings\n",
"# Disable WandB (Weights & Biases) logging to avoid unwanted log outputs during training\n",
"os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
"\n",
"# Suppress warning messages to keep output clean, especially during training and evaluation\n",
"warnings.filterwarnings(\"ignore\")\n"
]
},
Expand Down Expand Up @@ -337,22 +352,24 @@
}
],
"source": [
"# Load the dataset\n",
"# Load the Azerbaijani NER dataset from Hugging Face\n",
"dataset = load_dataset(\"LocalDoc/azerbaijani-ner-dataset\")\n",
"print(dataset)\n",
"print(dataset) # Display dataset structure (e.g., train/validation splits)\n",
"\n",
"# Preprocessing function for tokens and NER tags\n",
"# Preprocessing function to format tokens and NER tags correctly\n",
"def preprocess_example(example):\n",
" try:\n",
" # Convert string of tokens to a list and parse NER tags to integers\n",
" example[\"tokens\"] = ast.literal_eval(example[\"tokens\"])\n",
" example[\"ner_tags\"] = list(map(int, ast.literal_eval(example[\"ner_tags\"])))\n",
" except (ValueError, SyntaxError) as e:\n",
" # Skip and log malformed examples, ensuring error resilience\n",
" print(f\"Skipping malformed example: {example['index']} due to error: {e}\")\n",
" example[\"tokens\"] = []\n",
" example[\"ner_tags\"] = []\n",
" return example\n",
"\n",
"# Apply preprocessing\n",
"# Apply preprocessing to each dataset entry, ensuring consistent formatting\n",
"dataset = dataset.map(preprocess_example)\n"
]
},
Expand Down Expand Up @@ -507,32 +524,39 @@
}
],
"source": [
"# Initialize tokenizer\n",
"# Initialize the tokenizer for multilingual NER using XLM-RoBERTa\n",
"tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-base\")\n",
"\n",
"# Function to tokenize input and align labels with tokenized words\n",
"def tokenize_and_align_labels(example):\n",
" # Tokenize the sentence while preserving word boundaries for correct NER tag alignment\n",
" tokenized_inputs = tokenizer(\n",
" example[\"tokens\"],\n",
" truncation=True,\n",
" is_split_into_words=True,\n",
" padding=\"max_length\",\n",
" max_length=128,\n",
" example[\"tokens\"], # List of words (tokens) in the sentence\n",
" truncation=True, # Truncate sentences longer than max_length\n",
" is_split_into_words=True, # Specify that input is a list of words\n",
" padding=\"max_length\", # Pad to maximum sequence length\n",
" max_length=128, # Set the maximum sequence length to 128 tokens\n",
" )\n",
" labels = []\n",
" word_ids = tokenized_inputs.word_ids()\n",
" previous_word_idx = None\n",
"\n",
" labels = [] # List to store aligned NER labels\n",
" word_ids = tokenized_inputs.word_ids() # Get word IDs for each token\n",
" previous_word_idx = None # Initialize previous word index for tracking\n",
"\n",
" # Loop through word indices to align NER tags with subword tokens\n",
" for word_idx in word_ids:\n",
" if word_idx is None:\n",
" labels.append(-100)\n",
" labels.append(-100) # Set padding token labels to -100 (ignored in loss)\n",
" elif word_idx != previous_word_idx:\n",
" # Assign the label from example's NER tags if word index matches\n",
" labels.append(example[\"ner_tags\"][word_idx] if word_idx < len(example[\"ner_tags\"]) else -100)\n",
" else:\n",
" labels.append(-100)\n",
" previous_word_idx = word_idx\n",
" tokenized_inputs[\"labels\"] = labels\n",
" labels.append(-100) # Label subword tokens with -100 to avoid redundant labels\n",
" previous_word_idx = word_idx # Update previous word index\n",
"\n",
" tokenized_inputs[\"labels\"] = labels # Add labels to tokenized inputs\n",
" return tokenized_inputs\n",
"\n",
"# Apply tokenization and alignment\n",
"# Apply tokenization and label alignment function to the dataset\n",
"tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=False)\n"
]
},
Expand Down Expand Up @@ -575,9 +599,9 @@
}
],
"source": [
"# Create a 90-10 split for training and validation\n",
"# Create a 90-10 split of the dataset for training and validation\n",
"tokenized_datasets = tokenized_datasets[\"train\"].train_test_split(test_size=0.1)\n",
"print(tokenized_datasets)\n"
"print(tokenized_datasets) # Output structure of split datasets"
]
},
{
Expand All @@ -588,18 +612,33 @@
},
"outputs": [],
"source": [
"# Define a list of entity labels for NER tagging with B- (beginning) and I- (inside) markers\n",
"label_list = [\n",
" \"O\", \"B-PERSON\", \"I-PERSON\", \"B-LOCATION\", \"I-LOCATION\",\n",
" \"B-ORGANISATION\", \"I-ORGANISATION\", \"B-DATE\", \"I-DATE\",\n",
" \"B-TIME\", \"I-TIME\", \"B-MONEY\", \"I-MONEY\", \"B-PERCENTAGE\",\n",
" \"I-PERCENTAGE\", \"B-FACILITY\", \"I-FACILITY\", \"B-PRODUCT\",\n",
" \"I-PRODUCT\", \"B-EVENT\", \"I-EVENT\", \"B-ART\", \"I-ART\",\n",
" \"B-LAW\", \"I-LAW\", \"B-LANGUAGE\", \"I-LANGUAGE\", \"B-GPE\",\n",
" \"I-GPE\", \"B-NORP\", \"I-NORP\", \"B-ORDINAL\", \"I-ORDINAL\",\n",
" \"B-CARDINAL\", \"I-CARDINAL\", \"B-DISEASE\", \"I-DISEASE\",\n",
" \"B-CONTACT\", \"I-CONTACT\", \"B-ADAGE\", \"I-ADAGE\",\n",
" \"B-QUANTITY\", \"I-QUANTITY\", \"B-MISCELLANEOUS\", \"I-MISCELLANEOUS\",\n",
" \"B-POSITION\", \"I-POSITION\", \"B-PROJECT\", \"I-PROJECT\"\n",
" \"O\", # Outside of a named entity\n",
" \"B-PERSON\", \"I-PERSON\", # Person name (e.g., \"John\" in \"John Doe\")\n",
" \"B-LOCATION\", \"I-LOCATION\", # Geographical location (e.g., \"Paris\")\n",
" \"B-ORGANISATION\", \"I-ORGANISATION\", # Organization name (e.g., \"UNICEF\")\n",
" \"B-DATE\", \"I-DATE\", # Date entity (e.g., \"2024-11-05\")\n",
" \"B-TIME\", \"I-TIME\", # Time (e.g., \"12:00 PM\")\n",
" \"B-MONEY\", \"I-MONEY\", # Monetary values (e.g., \"$20\")\n",
" \"B-PERCENTAGE\", \"I-PERCENTAGE\", # Percentage values (e.g., \"20%\")\n",
" \"B-FACILITY\", \"I-FACILITY\", # Physical facilities (e.g., \"Airport\")\n",
" \"B-PRODUCT\", \"I-PRODUCT\", # Product names (e.g., \"iPhone\")\n",
" \"B-EVENT\", \"I-EVENT\", # Named events (e.g., \"Olympics\")\n",
" \"B-ART\", \"I-ART\", # Works of art (e.g., \"Mona Lisa\")\n",
" \"B-LAW\", \"I-LAW\", # Laws and legal documents (e.g., \"Article 50\")\n",
" \"B-LANGUAGE\", \"I-LANGUAGE\", # Languages (e.g., \"Azerbaijani\")\n",
" \"B-GPE\", \"I-GPE\", # Geopolitical entities (e.g., \"Europe\")\n",
" \"B-NORP\", \"I-NORP\", # Nationalities, religious groups, political groups\n",
" \"B-ORDINAL\", \"I-ORDINAL\", # Ordinal indicators (e.g., \"first\", \"second\")\n",
" \"B-CARDINAL\", \"I-CARDINAL\", # Cardinal numbers (e.g., \"three\")\n",
" \"B-DISEASE\", \"I-DISEASE\", # Diseases (e.g., \"COVID-19\")\n",
" \"B-CONTACT\", \"I-CONTACT\", # Contact info (e.g., email or phone number)\n",
" \"B-ADAGE\", \"I-ADAGE\", # Common sayings or adages\n",
" \"B-QUANTITY\", \"I-QUANTITY\", # Quantities (e.g., \"5 km\")\n",
" \"B-MISCELLANEOUS\", \"I-MISCELLANEOUS\", # Miscellaneous entities not fitting other categories\n",
" \"B-POSITION\", \"I-POSITION\", # Job titles or positions (e.g., \"CEO\")\n",
" \"B-PROJECT\", \"I-PROJECT\" # Project names (e.g., \"Project Apollo\")\n",
"]\n"
]
},
Expand Down Expand Up @@ -662,13 +701,13 @@
}
],
"source": [
"# Initialize the data collator\n",
"# Initialize a data collator to handle padding and formatting for token classification\n",
"data_collator = DataCollatorForTokenClassification(tokenizer)\n",
"\n",
"# Load the model\n",
"# Load a pre-trained model for token classification, adapted for NER tasks\n",
"model = AutoModelForTokenClassification.from_pretrained(\n",
" \"xlm-roberta-base\",\n",
" num_labels=len(label_list)\n",
" \"xlm-roberta-large\", # Base model (multilingual XLM-RoBERTa) for NER\n",
" num_labels=len(label_list) # Set the number of output labels to match NER categories\n",
")\n"
]
},
Expand All @@ -680,18 +719,35 @@
},
"outputs": [],
"source": [
"# Define a function to compute evaluation metrics for the model's predictions\n",
"def compute_metrics(p):\n",
" predictions, labels = p\n",
" predictions, labels = p # Unpack predictions and true labels from the input\n",
"\n",
" # Convert logits to predicted label indices by taking the argmax along the last axis\n",
" predictions = np.argmax(predictions, axis=2)\n",
"\n",
" # Filter out special padding labels (-100) and convert indices to label names\n",
" true_labels = [[label_list[l] for l in label if l != -100] for label in labels]\n",
" true_predictions = [\n",
" [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
" for prediction, label in zip(predictions, labels)\n",
" ]\n",
"\n",
" # Print a detailed classification report for each label category\n",
" print(classification_report(true_labels, true_predictions))\n",
"\n",
" # Calculate and return key evaluation metrics\n",
" return {\n",
" # Precision measures the accuracy of predicted positive instances\n",
" # Important in NER to ensure entity predictions are correct and reduce false positives.\n",
" \"precision\": precision_score(true_labels, true_predictions),\n",
"\n",
" # Recall measures the model's ability to capture all relevant entities\n",
" # Essential in NER to ensure the model captures all entities, reducing false negatives.\n",
" \"recall\": recall_score(true_labels, true_predictions),\n",
"\n",
" # F1-score is the harmonic mean of precision and recall, balancing both metrics\n",
" # Useful in NER for providing an overall performance measure, especially when precision and recall are both important.\n",
" \"f1\": f1_score(true_labels, true_predictions),\n",
" }\n"
]
Expand All @@ -704,21 +760,22 @@
},
"outputs": [],
"source": [
"# Set up training arguments for model training, defining essential training configurations\n",
"training_args = TrainingArguments(\n",
" output_dir=\"./results\",\n",
" evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n",
" save_strategy=\"epoch\", # Save the model at the end of each epoch\n",
" learning_rate=1e-5,\n",
" per_device_train_batch_size=64,\n",
" per_device_eval_batch_size=64,\n",
" num_train_epochs=8,\n",
" weight_decay=0.01,\n",
" fp16=True,\n",
" logging_dir='./logs',\n",
" save_total_limit=2,\n",
" load_best_model_at_end=True, # Load the best model at the end of training\n",
" metric_for_best_model=\"f1\",\n",
" report_to=\"none\"\n",
" output_dir=\"./results\", # Directory to save model checkpoints and final outputs\n",
" evaluation_strategy=\"epoch\", # Evaluate model on the validation set at the end of each epoch\n",
" save_strategy=\"epoch\", # Save model checkpoints at the end of each epoch\n",
" learning_rate=2e-5, # Set a low learning rate to ensure stable training for fine-tuning\n",
" per_device_train_batch_size=128, # Number of examples per batch during training, balancing speed and memory\n",
" per_device_eval_batch_size=128, # Number of examples per batch during evaluation\n",
" num_train_epochs=12, # Number of full training passes over the dataset\n",
" weight_decay=0.005, # Regularization term to prevent overfitting by penalizing large weights\n",
" fp16=True, # Use 16-bit floating point for faster and memory-efficient training\n",
" logging_dir='./logs', # Directory to store training logs\n",
" save_total_limit=2, # Keep only the 2 latest model checkpoints to save storage space\n",
" load_best_model_at_end=True, # Load the best model based on metrics at the end of training\n",
" metric_for_best_model=\"f1\", # Use F1-score to determine the best model checkpoint\n",
" report_to=\"none\" # Disable reporting to external services (useful in local runs)\n",
")\n"
]
},
Expand All @@ -730,15 +787,16 @@
},
"outputs": [],
"source": [
"# Initialize the Trainer class to manage the training loop with all necessary components\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_datasets[\"train\"],\n",
" eval_dataset=tokenized_datasets[\"test\"],\n",
" tokenizer=tokenizer,\n",
" data_collator=data_collator,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]\n",
" model=model, # The pre-trained model to be fine-tuned\n",
" args=training_args, # Training configuration parameters defined in TrainingArguments\n",
" train_dataset=tokenized_datasets[\"train\"], # Tokenized training dataset\n",
" eval_dataset=tokenized_datasets[\"test\"], # Tokenized validation dataset\n",
" tokenizer=tokenizer, # Tokenizer used for processing input text\n",
" data_collator=data_collator, # Data collator for padding and batching during training\n",
" compute_metrics=compute_metrics, # Function to calculate evaluation metrics like precision, recall, F1\n",
" callbacks=[EarlyStoppingCallback(early_stopping_patience=5)] # Stop training early if validation metrics don't improve for 2 epochs\n",
")\n"
]
},
Expand Down Expand Up @@ -1036,8 +1094,13 @@
}
],
"source": [
"# Begin the training process and capture the training metrics\n",
"training_metrics = trainer.train()\n",
"\n",
"# Evaluate the model on the validation set after training\n",
"eval_results = trainer.evaluate()\n",
"\n",
"# Print evaluation results, including precision, recall, and F1-score\n",
"print(eval_results)\n"
]
},
Expand Down Expand Up @@ -1078,8 +1141,13 @@
}
],
"source": [
"# Define the directory where the trained model and tokenizer will be saved\n",
"save_directory = \"./XLM-RoBERTa\"\n",
"\n",
"# Save the trained model to the specified directory\n",
"model.save_pretrained(save_directory)\n",
"\n",
"# Save the tokenizer to the same directory for compatibility with the model\n",
"tokenizer.save_pretrained(save_directory)\n"
]
},
Expand Down
4 changes: 2 additions & 2 deletions models/push_to_HF.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
login(token=hf_token)

# Define your repository ID
repo_id = "IsmatS/mbert-az-ner"
repo_id = "IsmatS/xlm_roberta_large_az_ner"

# Initialize HfApi and upload the model folder
api = HfApi()
api.upload_folder(folder_path="./mbert-azerbaijani-ner", path_in_repo="", repo_id=repo_id)
api.upload_folder(folder_path="./xlm-roberta-large", path_in_repo="", repo_id=repo_id)
4 changes: 2 additions & 2 deletions models/xlm_roberta_large.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2450,7 +2450,7 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d97ab405-ad70-4c27-a1aa-50690f429ea3"
"outputId": "d8184694-0ab9-44e4-9b4e-859cd2ea6188"
},
"execution_count": null,
"outputs": [
Expand All @@ -2466,7 +2466,7 @@
]
},
"metadata": {},
"execution_count": 18
"execution_count": 19
}
]
},
Expand Down

0 comments on commit 206e574

Please sign in to comment.