diff --git a/notebooks/regression/Perceptron.ipynb b/notebooks/regression/Perceptron.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9871c62c4c00089980e2625b94c49ee18e257533 --- /dev/null +++ b/notebooks/regression/Perceptron.ipynb @@ -0,0 +1,1673 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 42, + "id": "4e148684-d4ca-468c-b41c-197d6ca3ceac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>date</th>\n", + " <th>targeted_productivity</th>\n", + " <th>smv</th>\n", + " <th>over_time</th>\n", + " <th>incentive</th>\n", + " <th>no_of_style_change</th>\n", + " <th>no_of_workers</th>\n", + " <th>actual_productivity</th>\n", + " <th>overtime_bin</th>\n", + " <th>wip_log</th>\n", + " <th>idle_men_ratio</th>\n", + " <th>idle_ratio</th>\n", + " <th>day_num</th>\n", + " <th>day_sin</th>\n", + " <th>day_cos</th>\n", + " <th>department_encoded</th>\n", + " <th>team_encoded</th>\n", + " <th>quarter_encoded</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>2015-01-01</td>\n", + " <td>0.80</td>\n", + " <td>26.16</td>\n", + " <td>7080</td>\n", + " <td>98</td>\n", + " <td>0</td>\n", + " <td>59.0</td>\n", + " <td>0.940725</td>\n", + " <td>5001-10000</td>\n", + " <td>7.011214</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3</td>\n", + " <td>0.433884</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.674148</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>2015-01-01</td>\n", + " <td>0.75</td>\n", + " <td>3.94</td>\n", + " <td>960</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>8.0</td>\n", + " <td>0.886500</td>\n", + " <td>501-1000</td>\n", + " <td>6.946976</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3</td>\n", + " <td>0.433884</td>\n", + " <td>-0.900969</td>\n", + " <td>0.752951</td>\n", + " <td>0.821054</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>2015-01-01</td>\n", + " <td>0.80</td>\n", + " <td>11.41</td>\n", + " <td>3660</td>\n", + " <td>50</td>\n", + " <td>0</td>\n", + " <td>30.5</td>\n", + " <td>0.800570</td>\n", + " <td>2001-5000</td>\n", + " <td>6.876265</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3</td>\n", + " <td>0.433884</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.681985</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>2015-01-01</td>\n", + " <td>0.80</td>\n", + " <td>11.41</td>\n", + " <td>3660</td>\n", + " <td>50</td>\n", + " <td>0</td>\n", + " <td>30.5</td>\n", + " <td>0.800570</td>\n", + " <td>2001-5000</td>\n", + " <td>6.876265</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3</td>\n", + " <td>0.433884</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.779055</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>2015-01-01</td>\n", + " <td>0.80</td>\n", + " <td>25.90</td>\n", + " <td>1920</td>\n", + " <td>50</td>\n", + " <td>0</td>\n", + " <td>56.0</td>\n", + " <td>0.800382</td>\n", + " <td>1001-2000</td>\n", + " <td>7.065613</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3</td>\n", + " <td>0.433884</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.685385</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " date targeted_productivity smv over_time incentive \\\n", + "0 2015-01-01 0.80 26.16 7080 98 \n", + "1 2015-01-01 0.75 3.94 960 0 \n", + "2 2015-01-01 0.80 11.41 3660 50 \n", + "3 2015-01-01 0.80 11.41 3660 50 \n", + "4 2015-01-01 0.80 25.90 1920 50 \n", + "\n", + " no_of_style_change no_of_workers actual_productivity overtime_bin \\\n", + "0 0 59.0 0.940725 5001-10000 \n", + "1 0 8.0 0.886500 501-1000 \n", + "2 0 30.5 0.800570 2001-5000 \n", + "3 0 30.5 0.800570 2001-5000 \n", + "4 0 56.0 0.800382 1001-2000 \n", + "\n", + " wip_log idle_men_ratio idle_ratio day_num day_sin day_cos \\\n", + "0 7.011214 0.0 0.0 3 0.433884 -0.900969 \n", + "1 6.946976 0.0 0.0 3 0.433884 -0.900969 \n", + "2 6.876265 0.0 0.0 3 0.433884 -0.900969 \n", + "3 6.876265 0.0 0.0 3 0.433884 -0.900969 \n", + "4 7.065613 0.0 0.0 3 0.433884 -0.900969 \n", + "\n", + " department_encoded team_encoded quarter_encoded \n", + "0 0.722013 0.674148 0.759686 \n", + "1 0.752951 0.821054 0.759686 \n", + "2 0.722013 0.681985 0.759686 \n", + "3 0.722013 0.779055 0.759686 \n", + "4 0.722013 0.685385 0.759686 " + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "# Replace 'your_file.csv' with the actual filename\n", + "df = pd.read_csv('svm_neuralnet_ready.csv')\n", + "\n", + "# Show the first few rows to confirm it's loaded\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "28346333-5262-4509-bf97-ccee64d91ede", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 640x480 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "sns.histplot(df['actual_productivity'], kde=True)\n", + "plt.title(\"Target Distribution: actual_productivity\")\n", + "plt.xlabel(\"actual_productivity\")\n", + "plt.ylabel(\"Count\")\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "6b71d343-dac2-4fd2-b307-e6af69a86df4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "date 0\n", + "targeted_productivity 0\n", + "smv 0\n", + "over_time 0\n", + "incentive 0\n", + "no_of_style_change 0\n", + "no_of_workers 0\n", + "actual_productivity 0\n", + "overtime_bin 0\n", + "wip_log 0\n", + "idle_men_ratio 0\n", + "idle_ratio 0\n", + "day_num 0\n", + "day_sin 0\n", + "day_cos 0\n", + "department_encoded 0\n", + "team_encoded 0\n", + "quarter_encoded 0\n", + "dtype: int64\n", + "date object\n", + "targeted_productivity float64\n", + "smv float64\n", + "over_time int64\n", + "incentive int64\n", + "no_of_style_change int64\n", + "no_of_workers float64\n", + "actual_productivity float64\n", + "overtime_bin object\n", + "wip_log float64\n", + "idle_men_ratio float64\n", + "idle_ratio float64\n", + "day_num int64\n", + "day_sin float64\n", + "day_cos float64\n", + "department_encoded float64\n", + "team_encoded float64\n", + "quarter_encoded float64\n", + "dtype: object\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>date</th>\n", + " <th>targeted_productivity</th>\n", + " <th>smv</th>\n", + " <th>over_time</th>\n", + " <th>incentive</th>\n", + " <th>no_of_style_change</th>\n", + " <th>no_of_workers</th>\n", + " <th>actual_productivity</th>\n", + " <th>overtime_bin</th>\n", + " <th>wip_log</th>\n", + " <th>idle_men_ratio</th>\n", + " <th>idle_ratio</th>\n", + " <th>day_num</th>\n", + " <th>day_sin</th>\n", + " <th>day_cos</th>\n", + " <th>department_encoded</th>\n", + " <th>team_encoded</th>\n", + " <th>quarter_encoded</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>967</th>\n", + " <td>2015-02-28</td>\n", + " <td>0.8</td>\n", + " <td>15.26</td>\n", + " <td>1700</td>\n", + " <td>62</td>\n", + " <td>0</td>\n", + " <td>34.0</td>\n", + " <td>0.800261</td>\n", + " <td>1001-2000</td>\n", + " <td>7.085901</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>5</td>\n", + " <td>-0.974928</td>\n", + " <td>-0.222521</td>\n", + " <td>0.722013</td>\n", + " <td>0.779055</td>\n", + " <td>0.709067</td>\n", + " </tr>\n", + " <tr>\n", + " <th>646</th>\n", + " <td>2015-02-07</td>\n", + " <td>0.8</td>\n", + " <td>3.94</td>\n", + " <td>960</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>8.0</td>\n", + " <td>0.771583</td>\n", + " <td>501-1000</td>\n", + " <td>6.946976</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>5</td>\n", + " <td>-0.974928</td>\n", + " <td>-0.222521</td>\n", + " <td>0.752951</td>\n", + " <td>0.734462</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>38</th>\n", + " <td>2015-01-03</td>\n", + " <td>0.8</td>\n", + " <td>2.90</td>\n", + " <td>960</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>8.0</td>\n", + " <td>0.628333</td>\n", + " <td>501-1000</td>\n", + " <td>6.946976</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>5</td>\n", + " <td>-0.974928</td>\n", + " <td>-0.222521</td>\n", + " <td>0.752951</td>\n", + " <td>0.674148</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>563</th>\n", + " <td>2015-02-02</td>\n", + " <td>0.8</td>\n", + " <td>22.52</td>\n", + " <td>7020</td>\n", + " <td>88</td>\n", + " <td>0</td>\n", + " <td>58.5</td>\n", + " <td>0.900158</td>\n", + " <td>5001-10000</td>\n", + " <td>9.970492</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " <td>1.000000</td>\n", + " <td>0.722013</td>\n", + " <td>0.770855</td>\n", + " <td>0.759686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>878</th>\n", + " <td>2015-02-22</td>\n", + " <td>0.7</td>\n", + " <td>30.10</td>\n", + " <td>8160</td>\n", + " <td>33</td>\n", + " <td>2</td>\n", + " <td>58.0</td>\n", + " <td>0.626578</td>\n", + " <td>5001-10000</td>\n", + " <td>6.948897</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>6</td>\n", + " <td>-0.781831</td>\n", + " <td>0.623490</td>\n", + " <td>0.722013</td>\n", + " <td>0.803880</td>\n", + " <td>0.709067</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " date targeted_productivity smv over_time incentive \\\n", + "967 2015-02-28 0.8 15.26 1700 62 \n", + "646 2015-02-07 0.8 3.94 960 0 \n", + "38 2015-01-03 0.8 2.90 960 0 \n", + "563 2015-02-02 0.8 22.52 7020 88 \n", + "878 2015-02-22 0.7 30.10 8160 33 \n", + "\n", + " no_of_style_change no_of_workers actual_productivity overtime_bin \\\n", + "967 0 34.0 0.800261 1001-2000 \n", + "646 0 8.0 0.771583 501-1000 \n", + "38 0 8.0 0.628333 501-1000 \n", + "563 0 58.5 0.900158 5001-10000 \n", + "878 2 58.0 0.626578 5001-10000 \n", + "\n", + " wip_log idle_men_ratio idle_ratio day_num day_sin day_cos \\\n", + "967 7.085901 0.0 0.0 5 -0.974928 -0.222521 \n", + "646 6.946976 0.0 0.0 5 -0.974928 -0.222521 \n", + "38 6.946976 0.0 0.0 5 -0.974928 -0.222521 \n", + "563 9.970492 0.0 0.0 0 0.000000 1.000000 \n", + "878 6.948897 0.0 0.0 6 -0.781831 0.623490 \n", + "\n", + " department_encoded team_encoded quarter_encoded \n", + "967 0.722013 0.779055 0.709067 \n", + "646 0.752951 0.734462 0.759686 \n", + "38 0.752951 0.674148 0.759686 \n", + "563 0.722013 0.770855 0.759686 \n", + "878 0.722013 0.803880 0.709067 " + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check null values\n", + "print(df.isnull().sum())\n", + "\n", + "# Check datatypes\n", + "print(df.dtypes)\n", + "\n", + "# View a sample\n", + "df.sample(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "a4541b66-9311-4c6b-9e7b-e240ca27510e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>targeted_productivity</th>\n", + " <th>smv</th>\n", + " <th>over_time</th>\n", + " <th>incentive</th>\n", + " <th>no_of_style_change</th>\n", + " <th>no_of_workers</th>\n", + " <th>actual_productivity</th>\n", + " <th>wip_log</th>\n", + " <th>idle_men_ratio</th>\n", + " <th>idle_ratio</th>\n", + " <th>...</th>\n", + " <th>day_cos</th>\n", + " <th>department_encoded</th>\n", + " <th>team_encoded</th>\n", + " <th>quarter_encoded</th>\n", + " <th>overtime_bin_10001-20000</th>\n", + " <th>overtime_bin_1001-2000</th>\n", + " <th>overtime_bin_20001+</th>\n", + " <th>overtime_bin_2001-5000</th>\n", + " <th>overtime_bin_5001-10000</th>\n", + " <th>overtime_bin_501-1000</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0.80</td>\n", + " <td>26.16</td>\n", + " <td>7080</td>\n", + " <td>98</td>\n", + " <td>0</td>\n", + " <td>59.0</td>\n", + " <td>0.940725</td>\n", + " <td>7.011214</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.674148</td>\n", + " <td>0.759686</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0.75</td>\n", + " <td>3.94</td>\n", + " <td>960</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>8.0</td>\n", + " <td>0.886500</td>\n", + " <td>6.946976</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>-0.900969</td>\n", + " <td>0.752951</td>\n", + " <td>0.821054</td>\n", + " <td>0.759686</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0.80</td>\n", + " <td>11.41</td>\n", + " <td>3660</td>\n", + " <td>50</td>\n", + " <td>0</td>\n", + " <td>30.5</td>\n", + " <td>0.800570</td>\n", + " <td>6.876265</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.681985</td>\n", + " <td>0.759686</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0.80</td>\n", + " <td>11.41</td>\n", + " <td>3660</td>\n", + " <td>50</td>\n", + " <td>0</td>\n", + " <td>30.5</td>\n", + " <td>0.800570</td>\n", + " <td>6.876265</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.779055</td>\n", + " <td>0.759686</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0.80</td>\n", + " <td>25.90</td>\n", + " <td>1920</td>\n", + " <td>50</td>\n", + " <td>0</td>\n", + " <td>56.0</td>\n", + " <td>0.800382</td>\n", + " <td>7.065613</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>-0.900969</td>\n", + " <td>0.722013</td>\n", + " <td>0.685385</td>\n", + " <td>0.759686</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 22 columns</p>\n", + "</div>" + ], + "text/plain": [ + " targeted_productivity smv over_time incentive no_of_style_change \\\n", + "0 0.80 26.16 7080 98 0 \n", + "1 0.75 3.94 960 0 0 \n", + "2 0.80 11.41 3660 50 0 \n", + "3 0.80 11.41 3660 50 0 \n", + "4 0.80 25.90 1920 50 0 \n", + "\n", + " no_of_workers actual_productivity wip_log idle_men_ratio idle_ratio \\\n", + "0 59.0 0.940725 7.011214 0.0 0.0 \n", + "1 8.0 0.886500 6.946976 0.0 0.0 \n", + "2 30.5 0.800570 6.876265 0.0 0.0 \n", + "3 30.5 0.800570 6.876265 0.0 0.0 \n", + "4 56.0 0.800382 7.065613 0.0 0.0 \n", + "\n", + " ... day_cos department_encoded team_encoded quarter_encoded \\\n", + "0 ... -0.900969 0.722013 0.674148 0.759686 \n", + "1 ... -0.900969 0.752951 0.821054 0.759686 \n", + "2 ... -0.900969 0.722013 0.681985 0.759686 \n", + "3 ... -0.900969 0.722013 0.779055 0.759686 \n", + "4 ... -0.900969 0.722013 0.685385 0.759686 \n", + "\n", + " overtime_bin_10001-20000 overtime_bin_1001-2000 overtime_bin_20001+ \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 1 0 \n", + "\n", + " overtime_bin_2001-5000 overtime_bin_5001-10000 overtime_bin_501-1000 \n", + "0 0 1 0 \n", + "1 0 0 1 \n", + "2 1 0 0 \n", + "3 1 0 0 \n", + "4 0 0 0 \n", + "\n", + "[5 rows x 22 columns]" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Drop date column\n", + "df = df.drop(columns=['date'])\n", + "\n", + "# One-hot encode overtime_bin\n", + "df = pd.get_dummies(df, columns=['overtime_bin'], drop_first=True, dtype=int)\n", + "\n", + "# Confirm\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "1a6e2e30-9f01-4671-afa0-9a5958bb50d9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Feature correlations with target:\n", + "targeted_productivity 0.421594\n", + "team_encoded 0.294280\n", + "wip_log 0.172146\n", + "overtime_bin_1001-2000 0.158944\n", + "quarter_encoded 0.131581\n", + "department_encoded 0.087624\n", + "incentive 0.076538\n", + "overtime_bin_10001-20000 0.025611\n", + "day_cos 0.014380\n", + "overtime_bin_2001-5000 0.001174\n", + "day_num 0.000030\n", + "day_sin -0.018568\n", + "over_time -0.054206\n", + "no_of_workers -0.057991\n", + "overtime_bin_20001+ -0.063851\n", + "overtime_bin_5001-10000 -0.081499\n", + "idle_ratio -0.082272\n", + "smv -0.122089\n", + "overtime_bin_501-1000 -0.135879\n", + "idle_men_ratio -0.184774\n", + "no_of_style_change -0.207366\n", + "Name: actual_productivity, dtype: float64\n" + ] + } + ], + "source": [ + "correlation = df.corr(numeric_only=True)\n", + "target_corr = correlation['actual_productivity'].drop('actual_productivity')\n", + "target_corr = target_corr.sort_values(ascending=False)\n", + "\n", + "print(\"Feature correlations with target:\")\n", + "print(target_corr)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "fdb8ad09-b984-4fb6-8ecf-597dd11e6617", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>Feature</th>\n", + " <th>F-score</th>\n", + " <th>p-value</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>targeted_productivity</td>\n", + " <td>258.314084</td>\n", + " <td>8.997899e-53</td>\n", + " </tr>\n", + " <tr>\n", + " <th>13</th>\n", + " <td>team_encoded</td>\n", + " <td>113.299745</td>\n", + " <td>2.428823e-25</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>no_of_style_change</td>\n", + " <td>53.694481</td>\n", + " <td>4.299784e-13</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>idle_men_ratio</td>\n", + " <td>42.241233</td>\n", + " <td>1.182042e-10</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>wip_log</td>\n", + " <td>36.494402</td>\n", + " <td>2.041955e-09</td>\n", + " </tr>\n", + " <tr>\n", + " <th>16</th>\n", + " <td>overtime_bin_1001-2000</td>\n", + " <td>30.972045</td>\n", + " <td>3.229131e-08</td>\n", + " </tr>\n", + " <tr>\n", + " <th>20</th>\n", + " <td>overtime_bin_501-1000</td>\n", + " <td>22.478314</td>\n", + " <td>2.380888e-06</td>\n", + " </tr>\n", + " <tr>\n", + " <th>14</th>\n", + " <td>quarter_encoded</td>\n", + " <td>21.054217</td>\n", + " <td>4.935447e-06</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>smv</td>\n", + " <td>18.081815</td>\n", + " <td>2.281130e-05</td>\n", + " </tr>\n", + " <tr>\n", + " <th>12</th>\n", + " <td>department_encoded</td>\n", + " <td>9.246175</td>\n", + " <td>2.411260e-03</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>idle_ratio</td>\n", + " <td>8.143783</td>\n", + " <td>4.395462e-03</td>\n", + " </tr>\n", + " <tr>\n", + " <th>19</th>\n", + " <td>overtime_bin_5001-10000</td>\n", + " <td>7.990303</td>\n", + " <td>4.781241e-03</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>incentive</td>\n", + " <td>7.041570</td>\n", + " <td>8.069572e-03</td>\n", + " </tr>\n", + " <tr>\n", + " <th>17</th>\n", + " <td>overtime_bin_20001+</td>\n", + " <td>4.891926</td>\n", + " <td>2.717121e-02</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>no_of_workers</td>\n", + " <td>4.032236</td>\n", + " <td>4.486346e-02</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>over_time</td>\n", + " <td>3.521583</td>\n", + " <td>6.081808e-02</td>\n", + " </tr>\n", + " <tr>\n", + " <th>15</th>\n", + " <td>overtime_bin_10001-20000</td>\n", + " <td>0.784346</td>\n", + " <td>3.759934e-01</td>\n", + " </tr>\n", + " <tr>\n", + " <th>10</th>\n", + " <td>day_sin</td>\n", + " <td>0.412164</td>\n", + " <td>5.209965e-01</td>\n", + " </tr>\n", + " <tr>\n", + " <th>11</th>\n", + " <td>day_cos</td>\n", + " <td>0.247143</td>\n", + " <td>6.191856e-01</td>\n", + " </tr>\n", + " <tr>\n", + " <th>18</th>\n", + " <td>overtime_bin_2001-5000</td>\n", + " <td>0.001647</td>\n", + " <td>9.676358e-01</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>day_num</td>\n", + " <td>0.000001</td>\n", + " <td>9.991721e-01</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " Feature F-score p-value\n", + "0 targeted_productivity 258.314084 8.997899e-53\n", + "13 team_encoded 113.299745 2.428823e-25\n", + "4 no_of_style_change 53.694481 4.299784e-13\n", + "7 idle_men_ratio 42.241233 1.182042e-10\n", + "6 wip_log 36.494402 2.041955e-09\n", + "16 overtime_bin_1001-2000 30.972045 3.229131e-08\n", + "20 overtime_bin_501-1000 22.478314 2.380888e-06\n", + "14 quarter_encoded 21.054217 4.935447e-06\n", + "1 smv 18.081815 2.281130e-05\n", + "12 department_encoded 9.246175 2.411260e-03\n", + "8 idle_ratio 8.143783 4.395462e-03\n", + "19 overtime_bin_5001-10000 7.990303 4.781241e-03\n", + "3 incentive 7.041570 8.069572e-03\n", + "17 overtime_bin_20001+ 4.891926 2.717121e-02\n", + "5 no_of_workers 4.032236 4.486346e-02\n", + "2 over_time 3.521583 6.081808e-02\n", + "15 overtime_bin_10001-20000 0.784346 3.759934e-01\n", + "10 day_sin 0.412164 5.209965e-01\n", + "11 day_cos 0.247143 6.191856e-01\n", + "18 overtime_bin_2001-5000 0.001647 9.676358e-01\n", + "9 day_num 0.000001 9.991721e-01" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.feature_selection import SelectKBest, f_regression\n", + "\n", + "# Features & target\n", + "X = df.drop(columns=['actual_productivity'])\n", + "y = df['actual_productivity']\n", + "\n", + "# Apply SelectKBest\n", + "selector = SelectKBest(score_func=f_regression, k='all') # You can set k=10 later\n", + "selector.fit(X, y)\n", + "\n", + "# Results\n", + "scores = selector.scores_\n", + "pvalues = selector.pvalues_\n", + "\n", + "# Combine into DataFrame\n", + "kbest_result = pd.DataFrame({\n", + " 'Feature': X.columns,\n", + " 'F-score': scores,\n", + " 'p-value': pvalues\n", + "}).sort_values('F-score', ascending=False)\n", + "\n", + "kbest_result" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "a7a3c710-4172-4876-8186-f42841a96973", + "metadata": {}, + "outputs": [], + "source": [ + "selected_features = [\n", + " 'targeted_productivity',\n", + " 'team_encoded',\n", + " 'no_of_style_change',\n", + " 'idle_men_ratio',\n", + " 'wip_log',\n", + " 'overtime_bin_1001-2000',\n", + " 'overtime_bin_501-1000',\n", + " 'quarter_encoded',\n", + " 'smv',\n", + " 'idle_ratio' # Optional but decent\n", + "]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "0191eeb1-6857-494b-b316-2b56abe8770c", + "metadata": {}, + "outputs": [], + "source": [ + "# Final dataset\n", + "X = df[selected_features]\n", + "y = df['actual_productivity']" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "42785415-f444-43ad-bd5a-853e012fd5d9", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Split the dataset (80% train, 20% test)\n", + "X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=0.2, random_state=42\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "93f61889-3599-40b0-9cab-02b71508980e", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "# Initialize and fit scaler\n", + "scaler = StandardScaler()\n", + "X_train_scaled = scaler.fit_transform(X_train)\n", + "X_test_scaled = scaler.transform(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "245e1345-c170-42c6-86ca-a74644991b87", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<style>#sk-container-id-2 {\n", + " /* Definition of color scheme common for light and dark mode */\n", + " --sklearn-color-text: #000;\n", + " --sklearn-color-text-muted: #666;\n", + " --sklearn-color-line: gray;\n", + " /* Definition of color scheme for unfitted estimators */\n", + " --sklearn-color-unfitted-level-0: #fff5e6;\n", + " --sklearn-color-unfitted-level-1: #f6e4d2;\n", + " --sklearn-color-unfitted-level-2: #ffe0b3;\n", + " --sklearn-color-unfitted-level-3: chocolate;\n", + " /* Definition of color scheme for fitted estimators */\n", + " --sklearn-color-fitted-level-0: #f0f8ff;\n", + " --sklearn-color-fitted-level-1: #d4ebff;\n", + " --sklearn-color-fitted-level-2: #b3dbfd;\n", + " --sklearn-color-fitted-level-3: cornflowerblue;\n", + "\n", + " /* Specific color for light theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-icon: #696969;\n", + "\n", + " @media (prefers-color-scheme: dark) {\n", + " /* Redefinition of color scheme for dark theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-icon: #878787;\n", + " }\n", + "}\n", + "\n", + "#sk-container-id-2 {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "#sk-container-id-2 pre {\n", + " padding: 0;\n", + "}\n", + "\n", + "#sk-container-id-2 input.sk-hidden--visually {\n", + " border: 0;\n", + " clip: rect(1px 1px 1px 1px);\n", + " clip: rect(1px, 1px, 1px, 1px);\n", + " height: 1px;\n", + " margin: -1px;\n", + " overflow: hidden;\n", + " padding: 0;\n", + " position: absolute;\n", + " width: 1px;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-dashed-wrapped {\n", + " border: 1px dashed var(--sklearn-color-line);\n", + " margin: 0 0.4em 0.5em 0.4em;\n", + " box-sizing: border-box;\n", + " padding-bottom: 0.4em;\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-container {\n", + " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", + " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", + " so we also need the `!important` here to be able to override the\n", + " default hidden behavior on the sphinx rendered scikit-learn.org.\n", + " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", + " display: inline-block !important;\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-text-repr-fallback {\n", + " display: none;\n", + "}\n", + "\n", + "div.sk-parallel-item,\n", + "div.sk-serial,\n", + "div.sk-item {\n", + " /* draw centered vertical line to link estimators */\n", + " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", + " background-size: 2px 100%;\n", + " background-repeat: no-repeat;\n", + " background-position: center center;\n", + "}\n", + "\n", + "/* Parallel-specific style estimator block */\n", + "\n", + "#sk-container-id-2 div.sk-parallel-item::after {\n", + " content: \"\";\n", + " width: 100%;\n", + " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", + " flex-grow: 1;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-parallel {\n", + " display: flex;\n", + " align-items: stretch;\n", + " justify-content: center;\n", + " background-color: var(--sklearn-color-background);\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-parallel-item {\n", + " display: flex;\n", + " flex-direction: column;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-parallel-item:first-child::after {\n", + " align-self: flex-end;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-parallel-item:last-child::after {\n", + " align-self: flex-start;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-parallel-item:only-child::after {\n", + " width: 0;\n", + "}\n", + "\n", + "/* Serial-specific style estimator block */\n", + "\n", + "#sk-container-id-2 div.sk-serial {\n", + " display: flex;\n", + " flex-direction: column;\n", + " align-items: center;\n", + " background-color: var(--sklearn-color-background);\n", + " padding-right: 1em;\n", + " padding-left: 1em;\n", + "}\n", + "\n", + "\n", + "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", + "clickable and can be expanded/collapsed.\n", + "- Pipeline and ColumnTransformer use this feature and define the default style\n", + "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", + "*/\n", + "\n", + "/* Pipeline and ColumnTransformer style (default) */\n", + "\n", + "#sk-container-id-2 div.sk-toggleable {\n", + " /* Default theme specific background. It is overwritten whether we have a\n", + " specific estimator or a Pipeline/ColumnTransformer */\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "/* Toggleable label */\n", + "#sk-container-id-2 label.sk-toggleable__label {\n", + " cursor: pointer;\n", + " display: flex;\n", + " width: 100%;\n", + " margin-bottom: 0;\n", + " padding: 0.5em;\n", + " box-sizing: border-box;\n", + " text-align: center;\n", + " align-items: start;\n", + " justify-content: space-between;\n", + " gap: 0.5em;\n", + "}\n", + "\n", + "#sk-container-id-2 label.sk-toggleable__label .caption {\n", + " font-size: 0.6rem;\n", + " font-weight: lighter;\n", + " color: var(--sklearn-color-text-muted);\n", + "}\n", + "\n", + "#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n", + " /* Arrow on the left of the label */\n", + " content: \"▸\";\n", + " float: left;\n", + " margin-right: 0.25em;\n", + " color: var(--sklearn-color-icon);\n", + "}\n", + "\n", + "#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "/* Toggleable content - dropdown */\n", + "\n", + "#sk-container-id-2 div.sk-toggleable__content {\n", + " max-height: 0;\n", + " max-width: 0;\n", + " overflow: hidden;\n", + " text-align: left;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-toggleable__content.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-toggleable__content pre {\n", + " margin: 0.2em;\n", + " border-radius: 0.25em;\n", + " color: var(--sklearn-color-text);\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", + " /* Expand drop-down */\n", + " max-height: 200px;\n", + " max-width: 100%;\n", + " overflow: auto;\n", + "}\n", + "\n", + "#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", + " content: \"▾\";\n", + "}\n", + "\n", + "/* Pipeline/ColumnTransformer-specific style */\n", + "\n", + "#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator-specific style */\n", + "\n", + "/* Colorize estimator box */\n", + "#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n", + "#sk-container-id-2 div.sk-label label {\n", + " /* The background is the default theme color */\n", + " color: var(--sklearn-color-text-on-default-background);\n", + "}\n", + "\n", + "/* On hover, darken the color of the background */\n", + "#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "/* Label box, darken color on hover, fitted */\n", + "#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator label */\n", + "\n", + "#sk-container-id-2 div.sk-label label {\n", + " font-family: monospace;\n", + " font-weight: bold;\n", + " display: inline-block;\n", + " line-height: 1.2em;\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-label-container {\n", + " text-align: center;\n", + "}\n", + "\n", + "/* Estimator-specific */\n", + "#sk-container-id-2 div.sk-estimator {\n", + " font-family: monospace;\n", + " border: 1px dotted var(--sklearn-color-border-box);\n", + " border-radius: 0.25em;\n", + " box-sizing: border-box;\n", + " margin-bottom: 0.5em;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-estimator.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "/* on hover */\n", + "#sk-container-id-2 div.sk-estimator:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-2 div.sk-estimator.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n", + "\n", + "/* Common style for \"i\" and \"?\" */\n", + "\n", + ".sk-estimator-doc-link,\n", + "a:link.sk-estimator-doc-link,\n", + "a:visited.sk-estimator-doc-link {\n", + " float: right;\n", + " font-size: smaller;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1em;\n", + " height: 1em;\n", + " width: 1em;\n", + " text-decoration: none !important;\n", + " margin-left: 0.5em;\n", + " text-align: center;\n", + " /* unfitted */\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted,\n", + "a:link.sk-estimator-doc-link.fitted,\n", + "a:visited.sk-estimator-doc-link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "/* Span, style for the box shown on hovering the info icon */\n", + ".sk-estimator-doc-link span {\n", + " display: none;\n", + " z-index: 9999;\n", + " position: relative;\n", + " font-weight: normal;\n", + " right: .2ex;\n", + " padding: .5ex;\n", + " margin: .5ex;\n", + " width: min-content;\n", + " min-width: 20ex;\n", + " max-width: 50ex;\n", + " color: var(--sklearn-color-text);\n", + " box-shadow: 2pt 2pt 4pt #999;\n", + " /* unfitted */\n", + " background: var(--sklearn-color-unfitted-level-0);\n", + " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted span {\n", + " /* fitted */\n", + " background: var(--sklearn-color-fitted-level-0);\n", + " border: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link:hover span {\n", + " display: block;\n", + "}\n", + "\n", + "/* \"?\"-specific style due to the `<a>` HTML tag */\n", + "\n", + "#sk-container-id-2 a.estimator_doc_link {\n", + " float: right;\n", + " font-size: 1rem;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1rem;\n", + " height: 1rem;\n", + " width: 1rem;\n", + " text-decoration: none;\n", + " /* unfitted */\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + "}\n", + "\n", + "#sk-container-id-2 a.estimator_doc_link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "#sk-container-id-2 a.estimator_doc_link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>SGDRegressor(alpha=1e-05, eta0=0.001, learning_rate='adaptive',\n", + " penalty='elasticnet', random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>SGDRegressor</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.linear_model.SGDRegressor.html\">?<span>Documentation for SGDRegressor</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>SGDRegressor(alpha=1e-05, eta0=0.001, learning_rate='adaptive',\n", + " penalty='elasticnet', random_state=42)</pre></div> </div></div></div></div>" + ], + "text/plain": [ + "SGDRegressor(alpha=1e-05, eta0=0.001, learning_rate='adaptive',\n", + " penalty='elasticnet', random_state=42)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.linear_model import SGDRegressor\n", + "\n", + "# Initialize the model\n", + "model = SGDRegressor(\n", + " loss='squared_error', # standard linear regression loss\n", + " penalty='elasticnet', # regularization (optional: 'l2', 'l1', 'elasticnet')\n", + " alpha=1e-5, # regularization strength\n", + " learning_rate='adaptive',\n", + " eta0=0.001,\n", + " max_iter=1000,\n", + " random_state=42\n", + ")\n", + "\n", + "# Train the model\n", + "model.fit(X_train_scaled, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "c869d2d8-7d2a-42cb-85cc-a33c31dd4e1a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean Squared Error: 0.0203\n", + "R² Score : 0.2349\n", + "Mean Absolute Error: 0.1038\n" + ] + } + ], + "source": [ + "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n", + "\n", + "# Make predictions\n", + "y_pred = model.predict(X_test_scaled)\n", + "\n", + "# Evaluate\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "\n", + "print(f\"Mean Squared Error: {mse:.4f}\")\n", + "print(f\"R² Score : {r2:.4f}\")\n", + "print(f\"Mean Absolute Error: {mae:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "5f6f268e-e8a5-4f9e-9bc1-02474897aa02", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 20 candidates, totalling 100 fits\n", + "Best Parameters: {'sgd__penalty': 'elasticnet', 'sgd__max_iter': 1500, 'sgd__learning_rate': 'adaptive', 'sgd__eta0': 0.005, 'sgd__alpha': 0.0001}\n", + "Tuned MSE: 0.020302359461082282\n", + "Tuned R² : 0.23538706119742847\n", + "Tuned MAE: 0.10378416921028938\n" + ] + } + ], + "source": [ + "from sklearn.model_selection import RandomizedSearchCV\n", + "from sklearn.linear_model import SGDRegressor\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "# Pipeline with scaler and model\n", + "pipeline = Pipeline([\n", + " ('scaler', StandardScaler()),\n", + " ('sgd', SGDRegressor(loss='squared_error', random_state=42))\n", + "])\n", + "\n", + "# Parameter grid\n", + "param_distributions = {\n", + " 'sgd__penalty': ['l2', 'elasticnet'], # remove 'l1' for stability\n", + " 'sgd__alpha': [1e-5, 1e-4, 1e-3], # reduce strength\n", + " 'sgd__learning_rate': ['constant', 'adaptive'],\n", + " 'sgd__eta0': [0.0005, 0.001, 0.005], # smaller learning rate\n", + " 'sgd__max_iter': [1000, 1500]\n", + "}\n", + "\n", + "# Randomized search\n", + "random_search = RandomizedSearchCV(\n", + " estimator=pipeline,\n", + " param_distributions=param_distributions,\n", + " n_iter=20, # You can increase for better search\n", + " scoring='neg_mean_squared_error',\n", + " cv=5,\n", + " verbose=2,\n", + " random_state=42,\n", + " n_jobs=-1\n", + ")\n", + "\n", + "# Fit to data\n", + "random_search.fit(X_train, y_train)\n", + "\n", + "# Best model and evaluation\n", + "best_model = random_search.best_estimator_\n", + "y_pred = best_model.predict(X_test)\n", + "\n", + "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n", + "print(\"Best Parameters:\", random_search.best_params_)\n", + "print(\"Tuned MSE:\", mean_squared_error(y_test, y_pred))\n", + "print(\"Tuned R² :\", r2_score(y_test, y_pred))\n", + "print(\"Tuned MAE:\", mean_absolute_error(y_test, y_pred))" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "d30ba3b4-d9e1-4185-852c-ee26b13f0fed", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Poly MSE : 0.0242\n", + "Poly R² : 0.0878\n", + "Poly MAE : 0.1235\n" + ] + } + ], + "source": [ + "from sklearn.preprocessing import PolynomialFeatures\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.linear_model import SGDRegressor\n", + "from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error\n", + "\n", + "# Create pipeline: polynomial transformation + scaling + SGD\n", + "poly_pipeline = Pipeline([\n", + " ('poly', PolynomialFeatures(degree=2, include_bias=False)), # You can try degree=2 or 3\n", + " ('scaler', StandardScaler()),\n", + " ('sgd', SGDRegressor(\n", + " loss='squared_error',\n", + " penalty='elasticnet',\n", + " alpha=0.001,\n", + " learning_rate='adaptive',\n", + " eta0=0.0001,\n", + " max_iter=1500,\n", + " random_state=42\n", + " ))\n", + "])\n", + "\n", + "# Train on original train data (X_train, y_train — not scaled or poly-transformed yet)\n", + "poly_pipeline.fit(X_train, y_train)\n", + "\n", + "# Predict on X_test\n", + "y_poly_pred = poly_pipeline.predict(X_test)\n", + "\n", + "# Evaluate\n", + "mse = mean_squared_error(y_test, y_poly_pred)\n", + "r2 = r2_score(y_test, y_poly_pred)\n", + "mae = mean_absolute_error(y_test, y_poly_pred)\n", + "\n", + "print(f\"Poly MSE : {mse:.4f}\")\n", + "print(f\"Poly R² : {r2:.4f}\")\n", + "print(f\"Poly MAE : {mae:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ada8b857-bc21-4019-b8ab-02e06e071094", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}