diff --git a/testing.ipynb b/testing.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..79d241f225d53014ca9f858378e91dbf728944d9 --- /dev/null +++ b/testing.ipynb @@ -0,0 +1,345 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "3f7f7977-2d7f-45d9-a5ce-860e9f12b167", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/transformers/utils/generic.py:260: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n", + "/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/transformers/utils/generic.py:260: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n" + ] + } + ], + "source": [ + "from flask import Flask, jsonify, request, render_template\n", + "import json\n", + "import os\n", + "import logging\n", + "import pandas as pd\n", + "import datasets, evaluate\n", + "from transformers import pipeline\n", + "import torch\n", + "from datetime import datetime\n", + "from functools import partial\n", + "import numpy as np\n", + "import seaborn as sns\n", + "import string\n", + "import nltk\n", + "import re\n", + "import time\n", + "\n", + "from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "59de7380-6213-4d2c-9c21-d2e51a242c98", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + " \n", + "dataset = load_dataset(\"surrey-nlp/PLOD-filtered\")\n", + "dataset=dataset['train']" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e381b9eb-06ea-4333-adea-3aae055d6a56", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1247" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max(len(example['tokens']) for example in dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c9fccb03-9080-4d79-9323-2a7e29e31c47", + "metadata": {}, + "outputs": [], + "source": [ + "def filter_long(data):\n", + " return len(data['tokens']) <= 500\n", + "dataset = dataset.filter(filter_long)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "39c416a9-bdde-44c5-a88e-962ec39e6c51", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "import random\n", + "\n", + "def print_random_tokens(dataset):\n", + " # Get the total number of rows in the dataset\n", + " num_rows = len(dataset)\n", + " \n", + " # Generate 1000 random unique indices from the dataset\n", + " random_indices = random.sample(range(num_rows), 1000)\n", + " \n", + " # Retrieve the 'tokens' from these random indices\n", + " random_tokens = dataset.select(random_indices)['tokens']\n", + " \n", + " # Print each list of tokens\n", + " for tokens in random_tokens:\n", + " print(tokens)\n", + "\n", + "# Load the dataset\n", + "# dataset = load_dataset(\"surrey-nlp/PLOD-unfiltered\")\n", + "# train_dataset = dataset[\"train\"]\n", + "\n", + "# # Example usage of the function\n", + "# print_random_tokens(train_dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "15fb65b5-33dd-43d4-92a7-de043b4334a9", + "metadata": {}, + "outputs": [], + "source": [ + "app = Flask(__name__)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "64969fe0-4820-457a-ac94-b737ee928727", + "metadata": {}, + "outputs": [], + "source": [ + "@app.route('/use-pretrained', methods=['GET'])\n", + "def use_pretrained():\n", + " \"\"\"Endpoint to load and use a pre-trained model.\"\"\"\n", + " try:\n", + " # Load the pre-trained model and tokenizer\n", + " global loaded_tokenizer, loaded_model\n", + " loaded_tokenizer = AutoTokenizer.from_pretrained(\"SciBERT-finetuned-NER\")\n", + " loaded_model = AutoModelForTokenClassification.from_pretrained(\"SciBERT-finetuned-NER\")\n", + " return jsonify(success=\"Pre-trained model loaded successfully\")\n", + " except Exception as e:\n", + " return jsonify(error=str(e)), 500" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3358edd7-ced6-4a2e-b838-ac354c9c0809", + "metadata": {}, + "outputs": [], + "source": [ + "@app.route('/predict', methods=['POST'])\n", + "## Train must be run before this\n", + "## run from command line with: curl -s -H \"Content-Type: application/json\" -X POST -d '{\"input\": }' localhost:8080/predict\n", + "## examples:\n", + "## curl -s -H \"Content-Type: application/json\" -X POST -d '{\"input\": \"For this purpose the Gothenburg Young Persons Empowerment Scale (GYPES) was developed.\"}' localhost:8080/predict\n", + "## curl -s -H \"Content-Type: application/json\" -X POST -d '{\"input\": \"Recent work by us and others suggest that the host’s heat shock protein 90 (Hsp90) chaperone can modulate the evolutionary paths traversed by viruses [18, 19].\"}' localhost:8080/predict\n", + "def predict():\n", + " inputs = request.get_json().get('input')\n", + " converted_inputs = split_string(inputs)\n", + " predictions = predict_tags(converted_inputs, loaded_tokenizer, loaded_model, label_encoding)\n", + "\n", + " ner_tags = [i[1] for i in predictions]\n", + " save_results(converted_inputs, ner_tags)\n", + " return jsonify(predictions = str(predictions))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "92347951-af52-4c1f-b81f-f89cd3272c3a", + "metadata": {}, + "outputs": [], + "source": [ + "@app.route('/test-model', methods=['GET'])\n", + "def test_model():\n", + " start_time = time.time()\n", + " \"\"\"Endpoint to test the pre-trained model on 1000 random dataset samples.\"\"\"\n", + " dataset = load_dataset(\"surrey-nlp/PLOD-unfiltered\", split='train')\n", + "\n", + " def filter_long(data):\n", + " return len(data['tokens']) <= 400\n", + " dataset = dataset.filter(filter_long)\n", + " \n", + " sample_indices = random.sample(range(len(dataset)), 20000)\n", + " sampled_data = dataset.select(sample_indices)\n", + "\n", + " results = []\n", + " print(\"in test_model\")\n", + " for item in sampled_data:\n", + " # Join tokens to form a single string as the model expects a sequence\n", + " input_text = \" \".join(item['tokens'])\n", + " # Tokenize the text\n", + " inputs = loaded_tokenizer(input_text, return_tensors=\"pt\")\n", + " # Get model predictions\n", + " with torch.no_grad():\n", + " outputs = loaded_model(**inputs)\n", + " predictions = torch.argmax(outputs.logits, dim=-1)\n", + " tokens = loaded_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze())\n", + " predicted_tags = [dataset.features['ner_tags'].feature.int2str(p) for p in predictions.squeeze().tolist()]\n", + "\n", + " # Combine tokens and their predicted tags\n", + " token_predictions = list(zip(tokens, predicted_tags))\n", + " results.append({'text': input_text, 'predictions': token_predictions})\n", + " total_time = time.time() - start_time\n", + " print(\"Total time taken: \" + str(total_time))\n", + "\n", + " return jsonify(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e00a3561-f701-4482-866e-68d26cc8d3d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " * Serving Flask app '__main__'\n", + " * Debug mode: off\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.\n", + " * Running on http://127.0.0.1:8080\n", + "Press CTRL+C to quit\n", + "/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/transformers/utils/generic.py:260: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n", + "127.0.0.1 - - [24/May/2024 11:21:35] \"GET /use-pretrained HTTP/1.1\" 200 -\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "in test_model\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-05-24 11:28:18,175] ERROR in app: Exception on /test-model [GET]\n", + "Traceback (most recent call last):\n", + " File \"/Users/meenusathyanarayanan/.local/lib/python3.11/site-packages/flask/app.py\", line 1455, in wsgi_app\n", + " response = self.full_dispatch_request()\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/.local/lib/python3.11/site-packages/flask/app.py\", line 869, in full_dispatch_request\n", + " rv = self.handle_user_exception(e)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/.local/lib/python3.11/site-packages/flask/app.py\", line 867, in full_dispatch_request\n", + " rv = self.dispatch_request()\n", + " ^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/.local/lib/python3.11/site-packages/flask/app.py\", line 852, in dispatch_request\n", + " return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/var/folders/3c/_rc81dfd35755sd95j1m_6zc0000gn/T/ipykernel_50926/1002882848.py\", line 23, in test_model\n", + " outputs = loaded_model(**inputs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n", + " return self._call_impl(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py\", line 1756, in forward\n", + " outputs = self.bert(\n", + " ^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n", + " return self._call_impl(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py\", line 1015, in forward\n", + " embedding_output = self.embeddings(\n", + " ^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n", + " return self._call_impl(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/meenusathyanarayanan/anaconda3/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py\", line 238, in forward\n", + " embeddings += position_embeddings\n", + "RuntimeError: The size of tensor a (534) must match the size of tensor b (512) at non-singleton dimension 1\n", + "127.0.0.1 - - [24/May/2024 11:28:18] \"GET /test-model HTTP/1.1\" 500 -\n" + ] + } + ], + "source": [ + "if __name__ == '__main__':\n", + "\t# Entry point for running on the local machine\n", + "\t# host is localhost; port is 8080; this file is index (.py)\n", + "\tapp.run(host='127.0.0.1', port=8080, debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c24f86ed-2e75-4cb2-a66a-0db6a18aeb05", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05340c7e-9ba7-464e-9853-e6011581cee7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}