diff --git a/DE.ipynb b/DE.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..25d2989b67f38f47b3294a7370764e9be8acd634
--- /dev/null
+++ b/DE.ipynb
@@ -0,0 +1,328 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Base DE code: https://pablormier.github.io/2017/09/05/a-tutorial-on-differential-evolution-with-python/\n",
+    "\n",
+    "Adaptive mutation strategies: https://link.springer.com/article/10.1007/s00500-014-1349-y#Sec3\n",
+    "\n",
+    "ResNet9: https://github.com/Moddy2024/ResNet-9/blob/main/README.md"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Setup\n",
+    "\n",
+    "import torch\n",
+    "from torch import nn\n",
+    "import torch.nn.functional as F\n",
+    "from torch.utils.data import DataLoader, random_split\n",
+    "import torchvision\n",
+    "from torchvision import datasets\n",
+    "from torchvision import transforms\n",
+    "from torchvision.models.feature_extraction import create_feature_extractor\n",
+    "from sklearn.metrics import accuracy_score\n",
+    "import os\n",
+    "import copy\n",
+    "import json\n",
+    "import time\n",
+    "import matplotlib.pyplot as plt\n",
+    "from sklearn.metrics import confusion_matrix\n",
+    "import pandas as pd\n",
+    "import seaborn as sn\n",
+    "import numpy as np\n",
+    "import random\n",
+    "from typing import List, Dict, Tuple\n",
+    "\n",
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "print(f\"Using device: {device}\")\n",
+    "\n",
+    "# Import models\n",
+    "from ResNet9 import ResNet9, ResNet_last_layer\n",
+    "\n",
+    "# Load pretrained model\n",
+    "ResNet = ResNet9(3, 10).to(device)\n",
+    "ResNet.load_state_dict(torch.load('trained_weights.pkl', weights_only=True, map_location=torch.device(device)))\n",
+    "\n",
+    "# Get dimensions from final layer\n",
+    "fc_layer = ResNet.dim_reduce[11]\n",
+    "num_weights = fc_layer.weight.numel()\n",
+    "num_biases = fc_layer.bias.numel()\n",
+    "\n",
+    "# Training data transforms (augmentation + normalisation)\n",
+    "mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]\n",
+    "transform = transforms.Compose([\n",
+    "    transforms.RandomHorizontalFlip(p=0.5),\n",
+    "    transforms.RandomRotation(20),\n",
+    "    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),\n",
+    "    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.2),\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize(mean, std),\n",
+    "])\n",
+    "\n",
+    "# Load and split data (the test set is only normalised, not augmented)\n",
+    "train_data = datasets.CIFAR10(root=\"data\", train=True, download=True, transform=transform)\n",
+    "test_data = datasets.CIFAR10(root=\"data\", train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]))\n",
+    "\n",
+    "\n",
+    "val_size = 5000\n",
+    "train_size = len(train_data) - val_size\n",
+    "train_dataset, val_dataset = random_split(train_data, [train_size, val_size])\n",
+    "\n",
+    "# Setup last layer as model\n",
+    "last_layer = ResNet_last_layer(10).to(device)\n",
+    "last_layer.eval()\n",
+    "FE = create_feature_extractor(ResNet, return_nodes=[\"dim_reduce.10\"])\n",
+    "\n",
+    "# Extract features once\n",
+    "def extract_features(dataset, batch_size=1000):\n",
+    "    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
+    "    features, labels = [], []\n",
+    "    with torch.inference_mode():\n",
+    "        for images, batch_labels in dataloader:\n",
+    "            images = images.to(device)\n",
+    "            batch_features = FE(images)[\"dim_reduce.10\"].cpu()\n",
+    "            features.append(batch_features)\n",
+    "            labels.extend(batch_labels)\n",
+    "    return torch.cat(features), torch.tensor(labels)\n",
+    "\n",
+    "\n",
+    "print(\"Extracting features...\")\n",
+    "train_features, train_labels = extract_features(train_dataset)\n",
+    "val_features, val_labels = extract_features(val_dataset)\n",
+    "test_features, test_labels = extract_features(test_data)\n",
+    "print(\"Feature extraction complete.\")\n",
+    "\n",
+    "def evaluate_model(ind, features, labels, n_eval=50):\n",
+    "    \"\"\"Evaluate a candidate final layer on a random batch of n_eval pre-extracted features\"\"\"\n",
+    "    weights = torch.tensor(np.array(ind[:num_weights]).reshape(10, -1),\n",
+    "                           dtype=torch.float32).to(device)\n",
+    "    biases = torch.tensor(np.array(ind[num_weights:]),\n",
+    "                          dtype=torch.float32).to(device)\n",
+    "\n",
+    "    # Copy the candidate parameters into the detached last layer\n",
+    "    with torch.no_grad():\n",
+    "        last_layer.classifier.weight.copy_(weights)\n",
+    "        last_layer.classifier.bias.copy_(biases)\n",
+    "\n",
+    "    idx = random.sample(range(len(labels)), n_eval)\n",
+    "    batch_features = features[idx].to(device)\n",
+    "    batch_labels = labels[idx].to(device)\n",
+    "\n",
+    "    pred = last_layer(batch_features).argmax(dim=1)\n",
+    "    accuracy = float(accuracy_score(batch_labels.cpu(), pred.cpu()))\n",
+    "\n",
+    "    # Fitness is the negated accuracy, so lower is better\n",
+    "    return -accuracy\n",
+    "\n",
+    "def evaluate_individual(ind, n_eval=50):\n",
+    "    \"\"\"Fitness evaluation on the training features\"\"\"\n",
+    "    return evaluate_model(ind, train_features, train_labels, n_eval)\n",
+    "\n",
+    "def evaluate_solution(individual, dataset='val', n_eval=50):\n",
+    "    \"\"\"Evaluate on the validation or test features\"\"\"\n",
+    "    features = val_features if dataset == 'val' else test_features\n",
+    "    labels = val_labels if dataset == 'val' else test_labels\n",
+    "    return evaluate_model(individual, features, labels, n_eval)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Stats report\n",
+    "\n",
+    "def save_plot_stats(iteration, time, best_accuracy, population, stats_file, pop_file, iteration_snapshots=10):\n",
+    "    if iteration == 0:\n",
+    "        if not os.path.exists(stats_file):\n",
+    "            with open(stats_file, \"w\") as f:\n",
+    "                pass\n",
+    "\n",
+    "    stats = {\n",
+    "        \"iteration\": iteration,\n",
+    "        \"time\": time,\n",
+    "        \"best_accuracy\": float(best_accuracy),\n",
+    "    }\n",
+    "\n",
+    "    # Append stats to the JSON-lines file\n",
+    "    with open(stats_file, \"a\") as f:\n",
+    "        f.write(json.dumps(stats) + \"\\n\")\n",
+    "\n",
+    "    # Save a population snapshot every 'iteration_snapshots' iterations\n",
+    "    if iteration == 0:\n",
+    "        pop_data = {str(iteration): population}\n",
+    "        np.savez(pop_file, **pop_data)\n",
+    "    elif iteration % iteration_snapshots == 0:\n",
+    "        pop_data = dict(np.load(pop_file))\n",
+    "        pop_data[str(iteration)] = population\n",
+    "        np.savez(pop_file, **pop_data)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### DE"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def DE(evaluate_individual, evaluate_solution, bounds, n_eval, save_stats=True, mut=0.8, crossp=0.7, popsize=20, its=100):\n",
+    "    start = time.time()\n",
+    "    stats_file = f\"Test_stats/stats_DE_pop{popsize}_it{its}_F{mut}_CR{crossp}.json\"\n",
+    "    pop_file = f\"Test_stats/popSnap_DE_pop{popsize}_it{its}_F{mut}_CR{crossp}.npz\"\n",
+    "\n",
+    "    best_val_accuracy = 0  # fitness values are negated accuracies, so lower is better\n",
+    "    best_test_accuracy = 0\n",
+    "\n",
+    "    dimensions = len(bounds)\n",
+    "    pop = np.random.rand(popsize, dimensions)\n",
+    "    min_b, max_b = np.asarray(bounds).T\n",
+    "    diff = np.fabs(min_b - max_b)\n",
+    "    pop_denorm = min_b + pop * diff\n",
+    "\n",
+    "    fitnesses = []\n",
+    "    for ind in pop_denorm:\n",
+    "        f = evaluate_individual(ind, n_eval)\n",
+    "        fitnesses.append(f)\n",
+    "\n",
+    "    fitnesses = np.asarray(fitnesses)\n",
+    "\n",
+    "    best_idx = np.argmin(fitnesses)\n",
+    "    best = pop_denorm[best_idx]\n",
+    "\n",
+    "    for i in range(its):\n",
+    "        for j in range(popsize):\n",
+    "\n",
+    "            # rand/1 mutation\n",
+    "            idxs = [idx for idx in range(popsize) if idx != j]\n",
+    "            a, b, c = pop[np.random.choice(idxs, 3, replace=False)]\n",
+    "            mutant = np.clip(a + mut * (b - c), 0, 1)  # add the difference vector scaled by F ('mut') and clip to the normalised range\n",
+    "\n",
+    "            # binomial crossover\n",
+    "            cross_points = np.random.rand(dimensions) < crossp  # each dimension is taken from the mutant with probability 'crossp'\n",
+    "            if not np.any(cross_points):\n",
+    "                cross_points[np.random.randint(0, dimensions)] = True  # guarantee at least one dimension comes from the mutant\n",
+    "            trial = np.where(cross_points, mutant, pop[j])\n",
+    "            trial_denorm = min_b + trial * diff\n",
+    "\n",
+    "            f = evaluate_individual(trial_denorm, n_eval)\n",
+    "\n",
+    "            if f < fitnesses[j]:\n",
+    "                fitnesses[j] = f\n",
+    "                pop[j] = trial\n",
+    "                if f < fitnesses[best_idx]:\n",
+    "                    best_idx = j\n",
+    "                    best = trial_denorm\n",
+    "\n",
+    "\n",
+    "        # Check validation and test performance (on denormalised individuals)\n",
+    "        for ind in (min_b + pop * diff):\n",
+    "            val_accuracy = evaluate_solution(ind, 'val')\n",
+    "            if val_accuracy < best_val_accuracy:\n",
+    "                test_accuracy = evaluate_solution(ind, 'test')\n",
+    "                best_val_accuracy = val_accuracy\n",
+    "                best_test_accuracy = test_accuracy\n",
+    "                best_ind = ind\n",
+    "\n",
+    "\n",
+    "        elapsed_time = time.time() - start\n",
+    "\n",
+    "        if i % 5 == 0:\n",
+    "            print(f\"Iteration [{i}/{its}] || Best Test accuracy = [{-best_test_accuracy:.3f}] || Best Validation Accuracy = {-best_val_accuracy:.3f} || Time: {elapsed_time:.3f}\")\n",
+    "\n",
+    "        if save_stats:\n",
+    "            pop_denorm = min_b + pop * diff\n",
+    "            save_plot_stats(i, elapsed_time, -best_test_accuracy, pop_denorm, stats_file, pop_file, iteration_snapshots=10)\n",
+    "\n",
+    "    return pop, best_ind, -best_val_accuracy, -best_test_accuracy"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Note: play around with these based on results\n",
+    "\n",
+    "n_eval = 50  # number of feature vectors sampled per fitness evaluation\n",
+    "bounds = [(-0.1, 0.1)] * (num_weights + num_biases)\n",
+    "F = 0.5\n",
+    "CR = 0.6\n",
+    "popsize = 500\n",
+    "iterations = 500"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = DE(evaluate_individual, evaluate_solution, bounds=bounds, n_eval=n_eval, save_stats=True, mut=F, crossp=CR, popsize=popsize, its=iterations)\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.20"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
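As a possible follow-up for the empty trailing cells (matplotlib and pandas are imported but unused), here is a minimal sketch of a plotting cell that reads the JSON-lines stats file written by `save_plot_stats`. The file name below is an assumption, derived from the `popsize`, `iterations`, `F`, and `CR` values set in the Results section; adjust it to whichever run you want to inspect.

```python
# Hypothetical follow-up cell: plot the best test accuracy recorded by save_plot_stats.
# The stats path assumes the run configured above (popsize=500, its=500, F=0.5, CR=0.6).
import json

import matplotlib.pyplot as plt

stats_path = "Test_stats/stats_DE_pop500_it500_F0.5_CR0.6.json"

logged_iterations, best_test_accs = [], []
with open(stats_path) as f:
    for line in f:  # save_plot_stats appends one JSON object per line
        record = json.loads(line)
        logged_iterations.append(record["iteration"])
        best_test_accs.append(record["best_accuracy"])

plt.plot(logged_iterations, best_test_accs)
plt.xlabel("DE iteration")
plt.ylabel("Best test accuracy")
plt.title("DE on the ResNet9 final layer (CIFAR-10)")
plt.show()
```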