Skip to content
Snippets Groups Projects
Commit a19efba9 authored by Menezes, Luke J (PG/T - Comp Sci & Elec Eng)'s avatar Menezes, Luke J (PG/T - Comp Sci & Elec Eng)
Browse files

Upload New File

parent 5f26cfb2
No related branches found
No related tags found
No related merge requests found
Pipeline #72745 canceled
%% Cell type:code id:3f7f7977-2d7f-45d9-a5ce-860e9f12b167 tags:
``` python
from flask import Flask, jsonify, request, render_template
import json
import os
import logging
import pandas as pd
import datasets, evaluate
from transformers import pipeline
import torch
from datetime import datetime
from functools import partial
import numpy as np
import seaborn as sns
import string
import nltk
import re
import time
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback
```
%% Cell type:code id:59de7380-6213-4d2c-9c21-d2e51a242c98 tags:
``` python
from datasets import load_dataset
# Load the PLOD-filtered corpus and keep only its 'train' split.
dataset = load_dataset("surrey-nlp/PLOD-filtered")["train"]
```
%% Cell type:code id:e381b9eb-06ea-4333-adea-3aae055d6a56 tags:
``` python
# Length, in tokens, of the longest example in the dataset (shown as cell output).
max(len(example['tokens']) for example in dataset)
```
%% Cell type:code id:c9fccb03-9080-4d79-9323-2a7e29e31c47 tags:
``` python
def filter_long(data, *, max_tokens=500):
    """Return True when an example is short enough to keep.

    Args:
        data: Mapping with a 'tokens' entry that supports ``len()``.
        max_tokens: Inclusive upper bound on the number of tokens.
            Keyword-only so the function can still be passed directly as a
            single-argument predicate (the default of 500 then applies).

    Returns:
        True if the example has at most ``max_tokens`` tokens, else False.
    """
    return len(data['tokens']) <= max_tokens
# Keep only the examples for which filter_long returns True.
dataset = dataset.filter(filter_long)
```
%% Cell type:code id:39c416a9-bdde-44c5-a88e-962ec39e6c51 tags:
``` python
from datasets import load_dataset
import random
def print_random_tokens(dataset, sample_size=1000):
    """Print the 'tokens' of randomly chosen, distinct rows of ``dataset``.

    Args:
        dataset: Object supporting ``len()`` and ``select(indices)``, where
            the selection can be indexed with ``'tokens'`` to obtain one
            token list per selected row.
        sample_size: Maximum number of rows to print. Capped at the number
            of rows available.
    """
    num_rows = len(dataset)
    # random.sample raises ValueError when asked for more items than the
    # population holds (or for a negative count), so clamp the request.
    count = max(0, min(sample_size, num_rows))
    random_indices = random.sample(range(num_rows), count)
    # Print each selected row's token list on its own line.
    for tokens in dataset.select(random_indices)['tokens']:
        print(tokens)
# Load the dataset
# dataset = load_dataset("surrey-nlp/PLOD-unfiltered")
# train_dataset = dataset["train"]
# # Example usage of the function
# print_random_tokens(train_dataset)
```
%% Cell type:code id:15fb65b5-33dd-43d4-92a7-de043b4334a9 tags:
``` python
# Flask application object; the @app.route decorators in this file register their endpoints on it.
app = Flask(__name__)
```
%% Cell type:code id:64969fe0-4820-457a-ac94-b737ee928727 tags:
``` python
@app.route('/use-pretrained', methods=['GET'])
def use_pretrained():
    """Endpoint to load the fine-tuned model and tokenizer into module globals.

    Sets the globals ``loaded_tokenizer`` and ``loaded_model`` from the
    "SciBERT-finetuned-NER" checkpoint.

    Returns:
        A JSON success message, or a JSON error message with HTTP status 500
        when loading fails.
    """
    global loaded_tokenizer, loaded_model
    model_name = "SciBERT-finetuned-NER"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
    except Exception as e:
        # Request boundary: record the traceback and report the failure to
        # the client instead of letting the exception escape.
        logging.exception("Failed to load %s", model_name)
        return jsonify(error=str(e)), 500
    # Publish both objects only after both loaded, so a failure halfway
    # through never leaves a new tokenizer paired with an old/missing model.
    loaded_tokenizer, loaded_model = tokenizer, model
    return jsonify(success="Pre-trained model loaded successfully")
```
%% Cell type:code id:3358edd7-ced6-4a2e-b838-ac354c9c0809 tags:
``` python
@app.route('/predict', methods=['POST'])
## A model and tokenizer must be loaded before this (e.g. via GET /use-pretrained)
## run from command line with: curl -s -H "Content-Type: application/json" -X POST -d '{"input": "<text>"}' localhost:8080/predict
## examples:
## curl -s -H "Content-Type: application/json" -X POST -d '{"input": "For this purpose the Gothenburg Young Persons Empowerment Scale (GYPES) was developed."}' localhost:8080/predict
## curl -s -H "Content-Type: application/json" -X POST -d '{"input": "Recent work by us and others suggest that the host’s heat shock protein 90 (Hsp90) chaperone can modulate the evolutionary paths traversed by viruses [18, 19]."}' localhost:8080/predict
def predict():
    """Endpoint to tag one input text with the currently loaded model.

    Expects a JSON object body with an 'input' field. The input is split with
    ``split_string``, tagged with ``predict_tags``, and the result is passed
    to ``save_results`` before being returned.

    Returns:
        JSON ``{"predictions": "<str(predictions)>"}`` on success;
        a JSON error with status 400 when the body is not a JSON object or
        'input' is missing/empty; a JSON error with status 503 when no model
        has been loaded yet.
    """
    # silent=True yields None for a missing/invalid JSON body instead of raising.
    payload = request.get_json(silent=True)
    inputs = payload.get('input') if isinstance(payload, dict) else None
    if not inputs:
        return jsonify(error="Request body must be a JSON object with a non-empty 'input' field"), 400

    # The model globals only exist once a loading endpoint has run; report
    # that clearly rather than failing with a NameError.
    if 'loaded_tokenizer' not in globals() or 'loaded_model' not in globals():
        return jsonify(error="No model loaded; call /use-pretrained first"), 503

    converted_inputs = split_string(inputs)
    predictions = predict_tags(converted_inputs, loaded_tokenizer, loaded_model, label_encoding)
    # The second element of each prediction is the tag handed to save_results.
    ner_tags = [prediction[1] for prediction in predictions]
    save_results(converted_inputs, ner_tags)
    return jsonify(predictions = str(predictions))
```
%% Cell type:code id:92347951-af52-4c1f-b81f-f89cd3272c3a tags:
``` python
@app.route('/test-model', methods=['GET'])
def test_model():
    """Endpoint to run the loaded model on a random sample of PLOD-unfiltered.

    Loads the 'train' split of "surrey-nlp/PLOD-unfiltered", drops examples
    with more than 400 tokens, samples up to 20000 of the remaining examples
    and predicts a tag for every tokenizer token of each one.

    Returns:
        A JSON list with one ``{'text': ..., 'predictions': ...}`` object per
        sampled example, where 'predictions' pairs each tokenizer token with
        its predicted tag name; or a JSON error with status 503 when no model
        has been loaded yet.
    """
    start_time = time.time()

    # The model globals only exist once a loading endpoint has run.
    if 'loaded_tokenizer' not in globals() or 'loaded_model' not in globals():
        return jsonify(error="No model loaded; call /use-pretrained first"), 503

    dataset = load_dataset("surrey-nlp/PLOD-unfiltered", split='train')

    def filter_long(data):
        return len(data['tokens']) <= 400

    dataset = dataset.filter(filter_long)

    # random.sample raises ValueError when asked for more rows than remain
    # after filtering, so cap the request at the dataset size.
    sample_size = min(20000, len(dataset))
    sample_indices = random.sample(range(len(dataset)), sample_size)
    sampled_data = dataset.select(sample_indices)

    # Loop-invariant lookup from label id to tag name, taken from the
    # dataset schema.
    # NOTE(review): assumes the model's label ids follow the same order as
    # the dataset's 'ner_tags' labels -- confirm against the model config.
    int2str = dataset.features['ner_tags'].feature.int2str

    results = []
    print("in test_model")
    for item in sampled_data:
        # Join tokens to form a single string as the model expects a sequence
        input_text = " ".join(item['tokens'])
        # truncation=True asks the tokenizer to cut inputs down to its
        # configured maximum length, so an example that expands into too
        # many sub-word pieces is shortened instead of overflowing the model.
        inputs = loaded_tokenizer(input_text, return_tensors="pt", truncation=True)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = loaded_model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1)
        # A single text was tokenized, so row 0 is the whole batch.
        tokens = loaded_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
        predicted_tags = [int2str(p) for p in predictions[0].tolist()]
        # Combine tokens and their predicted tags
        token_predictions = list(zip(tokens, predicted_tags))
        results.append({'text': input_text, 'predictions': token_predictions})

    total_time = time.time() - start_time
    print("Total time taken: " + str(total_time))
    return jsonify(results)
```
%% Cell type:code id:e00a3561-f701-4482-866e-68d26cc8d3d8 tags:
``` python
if __name__ == '__main__':
    # Script entry point: start the Flask server bound to the loopback
    # address 127.0.0.1 on port 8080, with debug mode switched off.
    app.run(debug=False, port=8080, host='127.0.0.1')
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment