From 2a3e658ee4107232cc3a9e94c1b6ba68848793d0 Mon Sep 17 00:00:00 2001
From: "Hafiz, Aqib (PG/T - Comp Sci & Elec Eng)" <ah02821@surrey.ac.uk>
Date: Thu, 23 May 2024 17:19:22 +0000
Subject: [PATCH] Upload train_model.py and remove .ipynb

---
 train_model.py | 176 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 176 insertions(+)
 create mode 100644 train_model.py

diff --git a/train_model.py b/train_model.py
new file mode 100644
index 0000000..2d539eb
--- /dev/null
+++ b/train_model.py
@@ -0,0 +1,176 @@
+# Install dependencies before running this script, e.g.:
+#   pip install datasets transformers spacy torch spacy-transformers "transformers[torch]" seqeval
+from datasets import load_dataset, load_metric  # load_metric requires datasets<3.0; newer versions moved metrics to the `evaluate` library
+from transformers import (
+    AutoTokenizer,
+    AutoModelForTokenClassification,
+    TrainingArguments,
+    Trainer,
+    EarlyStoppingCallback,
+    DataCollatorForTokenClassification,
+)
+import numpy as np
+from seqeval.metrics import classification_report, f1_score, accuracy_score
+import matplotlib.pyplot as plt
+import seaborn as sns
+import torch
+
+# Load the dataset, falling back to a second dataset if the first is unavailable
+def load_dataset_with_fallback(dataset_name, fallback_name):
+    try:
+        dataset = load_dataset(dataset_name, download_mode="force_redownload")
+    except Exception as e:
+        print(f"Error loading dataset '{dataset_name}': {e}. Using '{fallback_name}' as a fallback.")
+        dataset = load_dataset(fallback_name, download_mode="force_redownload")
+    return dataset
+
+# Load the dataset with a fallback. Note: the rest of the script assumes the
+# PLOD-CW tag set (B-O, B-AC, B-LF, I-LF); conll2003 uses a different tag set.
+dataset = load_dataset_with_fallback("surrey-nlp/PLOD-CW", "conll2003")
+
+# Load the DistilBERT tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+model = AutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=4)
+
+# Extract the train, validation, and test splits
+short_dataset = dataset["train"]
+val_dataset = dataset["validation"]
+test_dataset = dataset["test"]
+
+# Tokenize the training data
+tokenized_input = tokenizer(short_dataset["tokens"], is_split_into_words=True)
+
+# Print the sub-tokens of a single example sentence
+for token in tokenized_input["input_ids"]:
+    print(tokenizer.convert_ids_to_tokens(token))
+    break
+
+# Define the label encoding; label_names maps IDs back to tag strings for seqeval
+label_encoding = {"B-O": 0, "B-AC": 1, "B-LF": 2, "I-LF": 3}
+label_names = ["B-O", "B-AC", "B-LF", "I-LF"]
+
+# Create the integer label lists for the train, validation, and test splits
+label_list = [[label_encoding.get(tag, 0) for tag in sample] for sample in short_dataset["ner_tags"]]
+val_label_list = [[label_encoding.get(tag, 0) for tag in sample] for sample in val_dataset["ner_tags"]]
+test_label_list = [[label_encoding.get(tag, 0) for tag in sample] for sample in test_dataset["ner_tags"]]
+
+def tokenize_and_align_labels(examples, label_lists):
+    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
+
+    labels = []
+    for i, label in enumerate(label_lists):
+        word_ids = tokenized_inputs.word_ids(batch_index=i)
+        label_ids = []
+        for word_idx in word_ids:
+            if word_idx is None:
+                # Special tokens ([CLS]/[SEP]) are masked out of the loss
+                label_ids.append(-100)
+            else:
+                # Every sub-token inherits its word's label (continuation
+                # sub-tokens are labeled rather than masked)
+                label_ids.append(label[word_idx])
+        labels.append(label_ids)
+
+    tokenized_inputs["labels"] = labels
+    return tokenized_inputs
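+
+# Illustrative sanity check of the alignment logic above, using a made-up
+# sentence and tag IDs (hypothetical, not taken from the dataset): special
+# tokens come out as -100 and each sub-token inherits its word's label.
+example_alignment = tokenize_and_align_labels(
+    {"tokens": [["messenger", "RNA", "(", "mRNA", ")"]]},
+    [[2, 3, 0, 1, 0]],
+)
+print(example_alignment["labels"][0])  # e.g. [-100, 2, 3, 0, 1, 0, 0, -100]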
+
+# Tokenize and align labels for the train, validation, and test splits
+tokenized_datasets = tokenize_and_align_labels(short_dataset, label_list)
+tokenized_val_datasets = tokenize_and_align_labels(val_dataset, val_label_list)
+tokenized_test_datasets = tokenize_and_align_labels(test_dataset, test_label_list)
+
+# Convert a dictionary of lists into a list of dictionaries for the Trainer
+def turn_dict_to_list_of_dict(d):
+    new_list = []
+    for labels, inputs in zip(d["labels"], d["input_ids"]):
+        entry = {"input_ids": inputs, "labels": labels}
+        new_list.append(entry)
+    return new_list
+
+# Convert the tokenized datasets
+tokenised_train = turn_dict_to_list_of_dict(tokenized_datasets)
+tokenised_val = turn_dict_to_list_of_dict(tokenized_val_datasets)
+tokenised_test = turn_dict_to_list_of_dict(tokenized_test_datasets)
+
+# Load the data collator, which pads each batch dynamically
+data_collator = DataCollatorForTokenClassification(tokenizer)
+
+# Load the seqeval metric
+metric = load_metric("seqeval")
+
+def compute_metrics(p):
+    predictions, labels = p
+    predictions = np.argmax(predictions, axis=2)
+
+    # Map label IDs back to tag strings, dropping the masked (-100) positions
+    true_predictions = [
+        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+    true_labels = [
+        [label_names[l] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+
+    results = metric.compute(predictions=true_predictions, references=true_labels)
+    return {
+        "precision": results["overall_precision"],
+        "recall": results["overall_recall"],
+        "f1": results["overall_f1"],
+        "accuracy": results["overall_accuracy"],
+    }
+
+# Define training arguments
+model_name = "distilbert-base-uncased"
+epochs = 6
+batch_size = 4
+learning_rate = 2e-5
+
+# Note: eval_steps/save_steps are large relative to a small training set, so
+# the intermediate evaluations (and hence early stopping) may never trigger.
+args = TrainingArguments(
+    f"{model_name}-finetuned-NER",
+    evaluation_strategy="steps",
+    eval_steps=7000,
+    save_total_limit=3,
+    learning_rate=learning_rate,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    num_train_epochs=epochs,
+    weight_decay=0.001,
+    save_steps=35000,
+    metric_for_best_model="f1",
+    load_best_model_at_end=True,
+)
+
+# Create the Trainer
+trainer = Trainer(
+    model,
+    args,
+    train_dataset=tokenised_train,
+    eval_dataset=tokenised_val,
+    data_collator=data_collator,
+    tokenizer=tokenizer,
+    compute_metrics=compute_metrics,
+    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
+)
+
+# Train the model
+trainer.train()
+
+# Run prediction on the test data
+predictions, labels, _ = trainer.predict(tokenised_test)
+predictions = np.argmax(predictions, axis=2)
+
+# Remove the predictions for the special tokens and map IDs back to tag strings
+true_predictions = [
+    [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
+    for prediction, label in zip(predictions, labels)
+]
+true_labels = [
+    [label_names[l] for (p, l) in zip(prediction, label) if l != -100]
+    for prediction, label in zip(predictions, labels)
+]
+
+# Compute and print the metrics on the test results
+results = metric.compute(predictions=true_predictions, references=true_labels)
+print(results)
--
GitLab