# app.py -- Streamlit demo for token classification (abbreviation / long-form
# tagging) with a RoBERTa model on the surrey-nlp/PLOD-CW dataset.
#
# NOTE: a comment header is used instead of a module docstring so that nothing
# at the top of the script can be rendered by Streamlit.
import streamlit as st
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, RobertaTokenizerFast
import torch
import numpy as np
import matplotlib.pyplot as plt

# Label encoding (defined first so the model head size can be derived from it).
label_encoding = {"B-O": 0, "B-AC": 1, "B-LF": 2, "I-LF": 3}
label_decoding = {v: k for k, v in label_encoding.items()}

# Checkpoint used for both the tokenizer and the model.
# NOTE(review): "roberta-base" has no trained token-classification head, so the
# classifier layer is freshly initialised and predictions are not meaningful
# until this points at a fine-tuned checkpoint.
MODEL_NAME = "roberta-base"

# Number of sequences per forward pass when evaluating a whole split.
EVAL_BATCH_SIZE = 16


# Load your custom model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and token-classification model once per process.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is switched to eval mode.
    """
    tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME, add_prefix_space=True)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME, num_labels=len(label_encoding)
    )
    model.eval()
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()


# Load dataset
@st.cache_resource
def load_datasets():
    """Download (or reuse the cached copy of) PLOD-CW.

    Returns:
        A ``(train, validation, test)`` tuple of dataset splits.
    """
    dataset = load_dataset("surrey-nlp/PLOD-CW")
    return dataset["train"], dataset["validation"], dataset["test"]


# Tokenize and align labels
def tokenize_and_align_labels(dataset, label_list):
    """Tokenize pre-split sentences and spread word labels over sub-word pieces.

    Args:
        dataset: Object whose ``"tokens"`` entry is a list of word lists.
        label_list: One list of integer label ids per sentence, aligned with
            the words in ``dataset["tokens"]``.

    Returns:
        The tokenizer output with an added ``"labels"`` entry. Special tokens
        get ``-100``; every piece of a word gets that word's label.
    """
    tokenized_inputs = tokenizer(dataset["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(label_list):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        # First and continuation pieces of a word receive the same label, so
        # no bookkeeping of the previous word index is needed.
        labels.append([-100 if word_idx is None else label[word_idx] for word_idx in word_ids])
    tokenized_inputs["labels"] = labels
    return tokenized_inputs


# Convert to list of dicts
def turn_dict_to_list_of_dict(d):
    """Turn a column-oriented encoding into one ``{"input_ids", "labels"}`` dict per example."""
    return [{"input_ids": inputs, "labels": labels} for inputs, labels in zip(d["input_ids"], d["labels"])]


@st.cache_resource
def load_tokenized_split(split):
    """Tokenize one dataset split on demand and cache the result.

    Tokenization used to run for all three splits at module level, i.e. on
    every Streamlit rerun; it is now done lazily and only once per split.

    Args:
        split: One of ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        A list of ``{"input_ids", "labels"}`` dicts.

    Raises:
        KeyError: If ``split`` is not one of the three known names.
    """
    splits = dict(zip(("train", "validation", "test"), load_datasets()))
    dataset = splits[split]
    encoded_labels = [[label_encoding[tag] for tag in sample] for sample in dataset["ner_tags"]]
    return turn_dict_to_list_of_dict(tokenize_and_align_labels(dataset, encoded_labels))


# Inference function
def predict(text, model, tokenizer):
    """Predict one label per whitespace-separated word of ``text``.

    The text is split on whitespace and tokenized as pre-split words (the same
    way the dataset is tokenized); the prediction of the first sub-word piece
    of each word is used as the word's label, and special tokens are skipped.

    Args:
        text: Raw input string.
        model: Token-classification model.
        tokenizer: Fast tokenizer matching ``model``.

    Returns:
        A list of label strings. It is empty for blank input and may be
        shorter than the number of words if the input had to be truncated.
    """
    words = text.split()
    if not words:
        return []

    encoding = tokenizer(words, is_split_into_words=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoding)
    token_predictions = torch.argmax(output.logits, dim=-1)[0].tolist()

    labels = []
    previous_word_idx = None
    for word_idx, label_id in zip(encoding.word_ids(batch_index=0), token_predictions):
        # Skip special tokens and continuation pieces of an already labelled word.
        if word_idx is None or word_idx == previous_word_idx:
            continue
        labels.append(label_decoding[label_id])
        previous_word_idx = word_idx
    return labels


def predict_dataset(examples, model, tokenizer, batch_size=EVAL_BATCH_SIZE):
    """Run the model over tokenized examples in padded batches.

    Args:
        examples: List of ``{"input_ids", "labels"}`` dicts.
        model: Token-classification model.
        tokenizer: Tokenizer used to pad each batch.
        batch_size: Number of sequences per forward pass.

    Returns:
        ``(logits, labels)`` where ``logits`` is a list of
        ``(sequence_length, num_labels)`` arrays with padding removed and
        ``labels`` is the matching list of label-id lists.
    """
    all_logits = []
    all_labels = []
    for start in range(0, len(examples), batch_size):
        batch = examples[start:start + batch_size]
        encoded = tokenizer.pad(
            [{"input_ids": example["input_ids"]} for example in batch],
            padding=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(**encoded).logits
        for row, mask, example in zip(logits, encoded["attention_mask"], batch):
            # Keep only real (non-padding) positions, whichever side was padded.
            all_logits.append(row[mask.bool()].numpy())
            all_labels.append(example["labels"])
    return all_logits, all_labels


# Compute metrics
def compute_metrics(predictions, labels):
    """Compute seqeval metrics from per-token logits and gold label ids.

    Args:
        predictions: Either a ``(batch, sequence, num_labels)`` array or a
            list of ``(sequence, num_labels)`` arrays of differing lengths.
        labels: Gold label ids per sequence; positions equal to ``-100`` are
            ignored.

    Returns:
        Whatever the seqeval metric's ``compute`` returns.
    """
    # Imported lazily: ``load_metric`` is deprecated in ``datasets`` in favour
    # of the separate ``evaluate`` package, and a failing top-level import
    # would take down the whole app rather than only the evaluation feature.
    from datasets import load_metric

    metric = load_metric("seqeval")
    true_predictions = []
    true_labels = []
    for logits, label_ids in zip(predictions, labels):
        predicted_ids = np.argmax(np.asarray(logits), axis=-1)
        kept = [(int(p), int(l)) for p, l in zip(predicted_ids, label_ids) if l != -100]
        true_predictions.append([label_decoding[p] for p, _ in kept])
        true_labels.append([label_decoding[l] for _, l in kept])
    results = metric.compute(predictions=true_predictions, references=true_labels)
    return results


@st.cache_resource
def evaluate_test_set():
    """Evaluate the loaded model on the test split once and cache the metrics."""
    tokenized_test = load_tokenized_split("test")
    predictions, labels = predict_dataset(tokenized_test, model, tokenizer)
    return compute_metrics(predictions, labels)


# Visualization example (optional)
def visualize_predictions(predictions):
    """Render a histogram of predicted label ids in the Streamlit page.

    Args:
        predictions: Iterable of integer label ids.
    """
    fig, ax = plt.subplots()
    bins = range(len(label_encoding) + 1)
    ax.hist(predictions, bins=bins, align='left', rwidth=0.8)
    ax.set_xticks(range(len(label_encoding)))
    ax.set_xticklabels([label_decoding[i] for i in range(len(label_encoding))])
    ax.set_xlabel('Labels')
    ax.set_ylabel('Frequency')
    ax.set_title('Prediction Distribution')
    st.pyplot(fig)
    # Release the figure; otherwise pyplot keeps one per rerun alive.
    plt.close(fig)


# Streamlit app
st.title("Custom Model Integration with Streamlit")
st.write("This app uses a custom fine-tuned RoBERTa model for NER.")

input_text = st.text_area("Enter your text:", "")

if st.button("Predict"):
    if input_text:
        labels = predict(input_text, model, tokenizer)
        st.write("Predicted Labels:")
        st.write(labels)

if st.checkbox("Show Predictions Distribution"):
    if input_text:
        labels = predict(input_text, model, tokenizer)
        label_indices = [label_encoding[label] for label in labels]
        visualize_predictions(label_indices)

# Adding file upload functionality
uploaded_file = st.file_uploader("Choose a file", type=["txt"])

if uploaded_file is not None:
    # Undecodable bytes are replaced instead of aborting the whole upload.
    content = uploaded_file.read().decode("utf-8", errors="replace")
    # splitlines() also handles "\r\n"; blank lines carry nothing to predict.
    lines = [line for line in content.splitlines() if line.strip()]
    if st.button("Predict File Content"):
        results = [predict(line, model, tokenizer) for line in lines]
        st.write("Batch Predictions:")
        for line, result in zip(lines, results):
            st.write(f"Text: {line}")
            st.write(f"Prediction: {result}")

# Evaluate model
if st.checkbox("Evaluate Model"):
    metrics = evaluate_test_set()
    st.write("Evaluation Metrics:")
    st.write(metrics)