iui-group-l-name-zensiert/0-pilot-project/MNIST-kNN-base.ipynb

760 lines
71 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"id": "2507dc1b",
"metadata": {},
"source": [
"### Load MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d214bb2f",
"metadata": {},
"outputs": [],
"source": [
"# Python ≥3.6 is required: later cells use f-strings, which do not exist in 3.5\n",
"import sys\n",
"assert sys.version_info >= (3, 6), \"this notebook needs Python 3.6 or newer\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "056cad96",
"metadata": {},
"outputs": [],
"source": [
"# scikit-learn ≥0.22 is required: fetch_openml(as_frame=...) and\n",
"# confusion_matrix(normalize=...) used below were added in 0.22\n",
"import re\n",
"import sklearn\n",
"\n",
"# compare numerically; a plain string comparison ranks e.g. \"0.9\" above \"0.20\"\n",
"sklearn_major, sklearn_minor = (int(part) for part in re.match(r\"(\\d+)\\.(\\d+)\", sklearn.__version__).groups())\n",
"assert (sklearn_major, sklearn_minor) >= (0, 22), \"scikit-learn >= 0.22 required, found \" + sklearn.__version__"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "80c92d8a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np  # common imports: numerical arrays"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e07cdb1a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sklearn.utils.Bunch"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import fetch_openml\n",
"\n",
"# download MNIST from OpenML; as_frame=False returns NumPy arrays instead of a DataFrame\n",
"mnist = fetch_openml(\"mnist_784\", version=1, as_frame=False)\n",
"\n",
"# the result is a dict-like container holding (among others) \"data\" and \"target\"\n",
"type(mnist)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8ccbd6b7",
"metadata": {},
"outputs": [],
"source": [
"# feature matrix (one flattened 28x28 image per row) and label vector\n",
"X = mnist[\"data\"]\n",
"y = mnist[\"target\"]"
]
},
{
"cell_type": "markdown",
"id": "78c78c04",
"metadata": {},
"source": [
"### Plot data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2963a0bf",
"metadata": {},
"outputs": [],
"source": [
"# plotting libraries\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a7f1c08b",
"metadata": {},
"outputs": [],
"source": [
"# the labels are loaded as strings; cast them to unsigned 8-bit integers\n",
"y = y.astype(\"uint8\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b33ff35b",
"metadata": {},
"outputs": [],
"source": [
"def plot_digit(data):\n",
"    \"\"\"Display one flattened MNIST image as a 28x28 grayscale picture without axes.\"\"\"\n",
"    plt.imshow(data.reshape(28, 28), cmap=mpl.cm.binary, interpolation=\"nearest\")\n",
"    plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "89a76fba",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGPklEQVR4nO3dT4hN/x/H8d9o/FspmUHZkY2ylI2NoRhTLGQpG0lSygZplpTZScrGRrMQUrOZlNlNipWYpmxZ2fgzhRjMb/crNed9/e4Y8zrj8Vh6de495fv8nvLp3tszNzf3HyDPiqW+AWB+4oRQ4oRQ4oRQ4oRQvR12/5QLi69nvj/05IRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQvUt9A2305s2bcr969Wq5v3z5snGbnJzs6p5+1+HDh8v9/fv3jduOHTvKa3ft2lXuJ06cKHd+5ckJocQJocQJocQJocQJocQJocQJoXrm5uaqvRzb6vnz5+V+7dq1cn/y5Em5v379+v+9pf/ZsGFDuW/fvr3cO93bYurr6yv3t2/f/qU7aZ2e+f7QkxNCiRNCiRNCiRNCiRNCiRNCtfYjY3fu3Cn306dPN26zs7PltZ32gYGBch8bGyv3bdu2NW4rVtT/v+ztrf/Kvn37Vu4HDhwo98X+yBq/z5MTQokTQokTQokTQokTQokTQokTQrX2nPPjx4/l/vnz565fe+PGjeU+MjJS7jt37uz6vReq0zlop3PUhRgaGlq01/4XeXJCKHFCKHFCKHFCKHFCKHFCKHFCqNZ+NeaPHz/Kvfopu05WrlxZ7uvWrev6tRfb1NRUuXc6i6y+1nPNmjXltffv3y/3wcHBcv+H+WpMaBNxQihxQihxQihxQihxQihxQqjWnnMyv9WrV5d7p+/krc4yL1y4UF47PDxc7jRyzgltIk4IJU4IJU4IJU4IJU4IJU4I1drvrW2zmZmZxu3u3bvltVeuXCn3TueYq1atKvdLly41bpcvXy6v5c/y5IRQ4oRQ4oRQ4oRQ4oRQ4oRQjlK68OnTp3I/efJkuY+PjzdunX7acKH27NlT7sePH1/U9+f3eXJCKHFCKHFCKHFCKHFCKHFCKHFCKF+N2YUPHz6U+6ZNm8r958+fjdv379+7uaU/pr+/v3Fbv359ee2pU6fK/ezZs+W+YsU/+6zw1ZjQJuKEUOKEUOKEUOKEUOKEUOKEUM45l8DU1FTj9uzZswW99vXr18v9xYsXC3r9hRgYGCj30dHRxq06f10GnHNCm4gTQokTQokTQokTQokTQokTQjnnXGa+fPlS7tPT0+X++PHjxu3ixYtd3dPvGhsba9yGhoYW9b2XmHNOaBNxQihxQihxQihxQihxQihxQijnnPyi+u9hcHCwvPbRo0cLeu/z5883biMjIwt67XDOOaFNxAmhxAmhxAmhxAmhxAmhepf6BsjS0zPvv+p33P6ErVu3Lurrt40nJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4Ryzskv7t2717hNTEws6nvv27dvUV+/bTw5IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IZRzzn/M5ORkuQ8PDzdus7OzC3rvI0eOlPvmzZsX9PrLjScnhBInhBInhBInhBInhBInhBInhPITgMvM7du3y/3MmTPl/vXr167fe8uWLeX+6tWrcl+7dm3X791yfgIQ2kScEEqcEEqcEEqcEEqcEMpHxsJMT0+X+40bN8r91q1b5d7h6KzU19dX7g8ePCj3f/iopCuenBBKnBBKnBBKnBBKnBBKnBBKnBBq2Z5zVueF4+Pj5bUHDx4s93fv3pX706dPy31qaqpxe/jwYXntzMxMuXfS21v/lR86dKhxu3nzZnmtr7b8sz
w5IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IdSy/WrM/fv3N24TExN/8U7+rt27d5f7uXPnyv3YsWN/8G74Tb4aE9pEnBBKnBBKnBBKnBBKnBBKnBBq2X6e8+jRo41b8jlnf39/uY+Ojpb73r17y72nZ94jNQJ5ckIocUIocUIocUIocUIocUIocUKoZft5TmgRn+eENhEnhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhBInhOr0E4B+Lw6WiCcnhBInhBInhBInhBInhBInhPovMLcDdQGgUUMAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sample_index = 10000\n",
"plot_digit(X[sample_index])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0f59fcd5",
"metadata": {},
"outputs": [],
"source": [
"def plot_digits(instances, **options):\n",
"    \"\"\"Plot several flattened 28x28 digits side by side in a single row.\n",
"\n",
"    Extra keyword arguments are forwarded unchanged to ``plt.imshow``.\n",
"    \"\"\"\n",
"    side = 28\n",
"    strip = np.hstack([instance.reshape(side, side) for instance in instances])\n",
"    plt.imshow(strip, cmap=mpl.cm.binary, **options)\n",
"    plt.axis(\"off\")"
]
},
{
"cell_type": "markdown",
"id": "ff272b79",
"metadata": {},
"source": [
"### Prepare data for machine learning"
]
},
{
"cell_type": "markdown",
"id": "a330fd86",
"metadata": {},
"source": [
"### Split into training and test sets"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "36f7f273",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train: 56000, (56000, 784)\n",
"X_test: 14000, (14000, 784)\n",
"y_train: 56000, (56000,)\n",
"y_test: 14000, (14000,)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFiElEQVR4nO3dsUtVfRzH8XMfosGGFglcUhCiEFxcw/oDrJbwf4jIxVZx0EUHy6GhtcYIxwhaiv4GG1zChlYdbBDiPsuzRJ7fKY/nuZ+jr9f4fLn3dyje/OD5cjuD4XBYAXn+GfUDACcTJ4QSJ4QSJ4QSJ4S61DD3v3Khe4OT/qObE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0JdGvUD1Pn69WtxvrCwUJzv7u7Wzra3t4ufnZycLM7v379fnPfZq1evameHh4etvvvJkyetPn/RuDkhlDghlDghlDghlDghlDghlDgh1GA4HJbmxWGXfvz4UZw/ffq0OH/58uWpz75y5UpxPjExcervrqqqKv2ZDwaDVt/d5uyqqqpv377Vzo6Pj1udPT09XZx/+PChdnb9+vVWZ4c78S/dzQmhxAmhxAmhxAmhxAmhxAmhxAmhYvecTZp+7zk3N1c7Ozg4ONuH+UvJe84uz286u7S73tzcPOvHSWLPCX0iTgglTgglTgglTgglTgglTggV++/WNpmamirOd3Z2amcvXrwofrZp1/fmzZvivEuPHj0qzm/evFmcLy0tneXjnKlbt26N+hGiuDkhlDghlDghlDghlDghlDghlDghVG9/z0k33r17Vztreidqk6bfc+7t7dXOmv7N257ze07oE3FCKHFCKHFCKHFCKHFCqN7+ZIxurK2tjfoR+I+bE0KJE0KJE0KJE0KJE0KJE0KJE0LZc14wTa8//P79e2dnP3jwoDifmJjo7Ow+cnNCKHFCKHFCKHFCKHFCKHFCKHFCKHvOC+b169fF+f7+fmdnT05OFudjY2Odnd1Hbk4IJU4IJU4IJU4IJU4IJU4IJU4IZc95wSwtLRXng8GJb6M7E7dv3+7su88jNyeEEieEEieEEieEEieEEieEEieEsuc8Z3Z3d4vzpj1mmz3nyspKcf7w4cNTf/dF5OaEUOKEUOKEUOKEUOKEUOKEUFYpPXN8fFycb25u/k9P8rvx8fGRnX0euTkhlDghlDghlDghlDghlDghlDghlD1nz2xtbRXnTa/4a+PatWvF+fz8fGdnX0RuTgglTgglTgglTgglTgglTgglTghlz9kznz59Ks6Hw2Grecm9e/eK89nZ2VN/N79zc0IocUIocUIocUIocUIocUIocUIoe84wHz9+LM4/f/5cnLd5hV9VVdXdu3drZ8+fP2/13fwdNyeEEieEEieEEieEEieEEieEskoZgYODg9rZxsZG8bNHR0dn/DS/mpqaqp2NjY11eja/cnNCKHFCKHFCKHFCKHFCKHFCKHFCKHvOESi9pu/9+/ednn3nzp3i/NmzZ52ez59zc0IocUIocUIocUIocUIocUIocUIoe84RKL3Gr80r+v7E48ePi/OrV692ej5/zs0JocQJocQJocQJocQJocQJocQJoQYNe7Vul27n1NraWnG+urpaO2v7Cr8mP3/+7PT7OZUT/9LdnBBKnBBKnBBKnBBKnBBKnBBKnBDKnrMDX758Kc5nZmZqZ017zsuXLxfny8vLxfn6+npxzkjYc0KfiBNCiRNCiRNCiRNCiRNCWaWMwOLiYu3s7du3xc/euHGjOG9a4xDJKgX6RJwQSpwQSpwQSpwQSpwQSpwQyp
4TRs+eE/pEnBBKnBBKnBBKnBBKnBBKnBDqUsO82/fRAbXcnBBKnBBKnBBKnBBKnBBKnBDqX1Mfv8Wjc6DfAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# number of neighbours used by the kNN classifier below\n",
"k = 3\n",
"\n",
"# hold out 20% of the images for testing; the fixed seed keeps the split reproducible\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1337)\n",
"\n",
"for name, arr in ((\"X_train\", X_train), (\"X_test\", X_test), (\"y_train\", y_train), (\"y_test\", y_test)):\n",
"    print(f\"{name}: {len(arr)}, {arr.shape}\")\n",
"\n",
"plot_digit(X_test[10000])"
]
},
{
"cell_type": "markdown",
"id": "b3267043",
"metadata": {},
"source": [
"## Train kNN classifier"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "afb277c8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 107 ms, sys: 3.58 ms, total: 111 ms\n",
"Wall time: 110 ms\n"
]
},
{
"data": {
"text/plain": [
"KNeighborsClassifier(n_neighbors=3)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"# k (number of neighbours) is set in the train/test split cell;\n",
"# fit() returns the estimator itself, so construction and training can be chained\n",
"classifier = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)\n",
"classifier"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e30c9192",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
}
],
"source": [
"# take one test image and print its true label\n",
"td = 4000\n",
"test_digit, test_label = X_test[td], y_test[td]\n",
"print(test_label)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "2e4676ab",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFf0lEQVR4nO3dsUtVfRzH8eODYA2ZoBYtLUVjBIHgYotbQwi51VjQFEFD6BBE/4CCtImTs6NEUNQi4iY0VVMQNES5RBD4bA/E4/leu3rzc/P1GvtwbhfizYF+nHMHdnd3GyDPP0f9BYC9iRNCiRNCiRNCiRNCDXbY/Vcu9N7AXn/ozgmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhxAmhBo/6C/CrM2fOlPv09PSBPn92drbcZ2ZmDvT5HB53TgglTgglTgglTgglTgglTgjlKCXMjRs3yv358+flfvXq1XK/fft2uT99+rR1u3fvXnnt0NBQufN73DkhlDghlDghlDghlDghlDghlDgh1MDu7m61lyN/3s7OTrmvrKyU+/3798v9xIkTrdvHjx/La0dHR8udVgN7/aE7J4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4TyPGefGR4eLvfz588f6PMfP37cup06depAn83vceeEUOKEUOKEUOKEUOKEUOKEUOKEUJ7n/Mtsb2+X+9u3b8t9amqqdRsfHy+vHRx0bN4lz3NCPxEnhBInhBInhBInhBInhBInhHIwdQS+ffvWuj158qS89sOHD+W+trZW7iMjI+W+ubnZup07d668lsPlzgmhxAmhxAmhxAmhxAmhxAmhHKX0QHVU0jRNMzs727q9ePHiQH/3wMCeTx/95+vXr+W+urrauj169Ki8dmhoqNz5Pe6cEEqcEEqcEEqcEEqcEEqcEEqcEMqrMXtga2ur3CcmJv7QN/m/Dv/ezeXLl1u3TmewnV6dSSuvxoR+Ik4IJU4IJU4IJU4IJU4IJU4I5XnOHlhfX+/62unp6XK/e/du15/dNE3z+vXrcn/58mXrtrCwUF57/fr1cp+cnCx3fuXOCaHECaHECaHECaHECaHECaHECaE8z9kDP378KPf379+3bpcuXSqvHRzs7dH0zs5O61a907ZpmmZsbKzcb9682dV3OgY8zwn9RJwQSpwQSpwQSpwQSpwQSpwQyjkn+/b58+dy7/T7nc+ePSv3Y/z7ns45oZ+IE0KJE0KJE0KJE0KJE0J5NSaHZmVlpdwvXrxY7nNzc4f4bfqfOyeEEieEEieEEieEEieEEieEEieEcs7JvnX6+cDTp0+Xe6dHzviVOyeEEieEEieEEieEEieEEieEEieE8mpM9u3s2bPlvrS0VO5Xrlwp907Pe/7FvBoT+ok4IZQ4IZQ4IZQ4IZQ4IZQ4IZTnOdm3kydPlvuFCxfK/RifY3bFnRNCiRNCiRNCiRNCiRNCiRNCOUo5Zt68eVPuU1NTrVunV1uOj4939Z3YmzsnhBInhBInhBInhBInhBInhBInhHLO+Zd59+5duW9sbJT7q1evWreRkZEuvhHdcueEUOKEUOKEUOKEUOKEUOKEUOKEUH4CsAd+/vxZ7l++fGndlpeXy2tHR0fL/fv37+V+586dcu/0+kt6wk8AQj8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4TyPGcPLC4ulvvDhw+7/uwHDx6U+61bt8rdOWb/cOeEUOKEUOKEUOKEUOKEUOKEUOKEUJ7n7IFr166V+6dPn1q3+fn58tqZmZlyHx4eLncieZ4T+ok4IZQ4IZQ4IZQ4IZQ4IZSjFDh6jlKgn4gTQokTQokTQokTQokTQokTQokTQokTQokTQo
kTQokTQokTQokTQokTQnX6CcA9nzMDes+dE0KJE0KJE0KJE0KJE0KJE0L9C3ketmqhMN1TAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# show the selected test image\n",
"plot_digit(data=test_digit)"
]
},
{
"cell_type": "markdown",
"id": "2d53bd60",
"metadata": {},
"source": [
"### Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cae65dc4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy Train 98.57321428571429\n",
"CPU times: user 7min 43s, sys: 14min 20s, total: 22min 4s\n",
"Wall time: 51.7 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# training accuracy -- optimistic for kNN: every training image is stored in the model,\n",
"# so it typically finds itself at distance 0 and votes for its own label.\n",
"# Labels are compared directly (subtracting uint8 arrays relies on wrap-around).\n",
"train_mismatch = classifier.predict(X_train) != y_train\n",
"wrong_images = X_train[train_mismatch]\n",
"percentage = (1 - len(wrong_images) / len(X_train)) * 100\n",
"print(\"Accuracy Train \" + str(percentage))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b551da35",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy Test 97.14285714285714\n",
"CPU times: user 1min 54s, sys: 3min 39s, total: 5min 34s\n",
"Wall time: 13.2 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# test accuracy on images the classifier was not fitted on;\n",
"# labels are compared directly (subtracting uint8 arrays relies on wrap-around)\n",
"test_mismatch = classifier.predict(X_test) != y_test\n",
"wrong_images = X_test[test_mismatch]\n",
"percentage = (1 - len(wrong_images) / len(X_test)) * 100\n",
"print(\"Accuracy Test \" + str(percentage))"
]
},
{
"cell_type": "markdown",
"id": "22c0058c",
"metadata": {},
"source": [
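"Accuracy alone can be misleading when the digit classes are unevenly distributed: poor performance on a rare digit barely lowers the overall score. The following cells therefore also use cross-validation, precision, recall, F1 score and the confusion matrix."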
]
},
{
"cell_type": "markdown",
"id": "351768b0",
"metadata": {},
"source": [
"#### Cross Validation\n",
"[Find more information on cross validation here.](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "7d943683",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.96866127 0.9686077 0.9694632 ]\n",
"CPU times: user 4min 38s, sys: 10min, total: 14min 38s\n",
"Wall time: 33.7 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"# accuracy of each of the 3 cross-validation folds on the training set\n",
"cv_scores = cross_val_score(classifier, X_train, y_train, cv=3, scoring=\"accuracy\")\n",
"print(cv_scores)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "d49c670c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[5 8 0 ... 9 3 7]\n",
"CPU times: user 5min 52s, sys: 11min 7s, total: 17min\n",
"Wall time: 41.3 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"from sklearn.model_selection import cross_val_predict\n",
"\n",
"# out-of-fold predictions: each training image is predicted by a model fitted on the\n",
"# other folds, so the metrics below are not inflated like the training accuracy\n",
"# NOTE(review): cv=5 here but cv=3 in the cross_val_score cell -- intentional?\n",
"y_train_pred = cross_val_predict(classifier, X_train, y_train, cv=5)\n",
"print(y_train_pred)"
]
},
{
"cell_type": "markdown",
"id": "182a1c52",
"metadata": {},
"source": [
"#### Accuracy"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "1d66c93e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.970375"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"# share of out-of-fold predictions that equal the true label\n",
"accuracy_score(y_true=y_train, y_pred=y_train_pred)"
]
},
{
"cell_type": "markdown",
"id": "2c1c857b",
"metadata": {},
"source": [
"#### Precision"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "327ca201",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9705500243024229"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import precision_score\n",
"\n",
"# per-digit precision, averaged with weights proportional to each digit's support\n",
"precision_score(y_true=y_train, y_pred=y_train_pred, average=\"weighted\")"
]
},
{
"cell_type": "markdown",
"id": "1a051995",
"metadata": {},
"source": [
"#### Recall"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "0852a4e4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.970375"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import recall_score\n",
"\n",
"# support-weighted recall; this always equals plain accuracy (compare the two outputs)\n",
"recall_score(y_true=y_train, y_pred=y_train_pred, average=\"weighted\")"
]
},
{
"cell_type": "markdown",
"id": "4848e2f8",
"metadata": {},
"source": [
"#### F1 Score"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "d5d2fe0e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9703221334372104"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import f1_score\n",
"\n",
"# per-digit F1 (harmonic mean of precision and recall), support-weighted\n",
"f1_score(y_true=y_train, y_pred=y_train_pred, average=\"weighted\")"
]
},
{
"cell_type": "markdown",
"id": "3acab292",
"metadata": {},
"source": [
"#### Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f8f613b8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[5464 4 3 0 0 7 16 2 0 3]\n",
" [ 1 6252 9 2 3 1 4 10 2 3]\n",
" [ 35 49 5396 14 5 4 11 67 11 3]\n",
" [ 5 15 36 5475 2 67 3 30 30 16]\n",
" [ 4 45 4 1 5260 0 13 9 2 112]\n",
" [ 17 9 4 60 10 4887 46 5 9 21]\n",
" [ 23 17 1 0 5 14 5480 0 2 0]\n",
" [ 3 67 16 0 12 1 0 5691 1 55]\n",
" [ 20 69 28 85 27 69 19 15 5134 38]\n",
" [ 10 16 10 44 61 15 3 63 6 5302]]\n"
]
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# confusion matrix of raw counts: rows are true digits, columns are predicted digits\n",
"counts = confusion_matrix(y_train, y_train_pred)\n",
"print(counts)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "57dc8b25",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[9.93635206e-01 7.27404983e-04 5.45553737e-04 0.00000000e+00\n",
" 0.00000000e+00 1.27295872e-03 2.90961993e-03 3.63702491e-04\n",
" 0.00000000e+00 5.45553737e-04]\n",
" [1.59058374e-04 9.94432957e-01 1.43152537e-03 3.18116749e-04\n",
" 4.77175123e-04 1.59058374e-04 6.36233498e-04 1.59058374e-03\n",
" 3.18116749e-04 4.77175123e-04]\n",
" [6.25558534e-03 8.75781948e-03 9.64432529e-01 2.50223414e-03\n",
" 8.93655049e-04 7.14924039e-04 1.96604111e-03 1.19749777e-02\n",
" 1.96604111e-03 5.36193029e-04]\n",
" [8.80436697e-04 2.64131009e-03 6.33914422e-03 9.64078183e-01\n",
" 3.52174679e-04 1.17978517e-02 5.28262018e-04 5.28262018e-03\n",
" 5.28262018e-03 2.81739743e-03]\n",
" [7.33944954e-04 8.25688073e-03 7.33944954e-04 1.83486239e-04\n",
" 9.65137615e-01 0.00000000e+00 2.38532110e-03 1.65137615e-03\n",
" 3.66972477e-04 2.05504587e-02]\n",
" [3.35438043e-03 1.77584846e-03 7.89265983e-04 1.18389897e-02\n",
" 1.97316496e-03 9.64285714e-01 9.07655880e-03 9.86582478e-04\n",
" 1.77584846e-03 4.14364641e-03]\n",
" [4.15012631e-03 3.06748466e-03 1.80440274e-04 0.00000000e+00\n",
" 9.02201371e-04 2.52616384e-03 9.88812703e-01 0.00000000e+00\n",
" 3.60880549e-04 0.00000000e+00]\n",
" [5.13171399e-04 1.14608279e-02 2.73691413e-03 0.00000000e+00\n",
" 2.05268560e-03 1.71057133e-04 0.00000000e+00 9.73486144e-01\n",
" 1.71057133e-04 9.40814232e-03]\n",
" [3.63372093e-03 1.25363372e-02 5.08720930e-03 1.54433140e-02\n",
" 4.90552326e-03 1.25363372e-02 3.45203488e-03 2.72529070e-03\n",
" 9.32776163e-01 6.90406977e-03]\n",
" [1.80831826e-03 2.89330922e-03 1.80831826e-03 7.95660036e-03\n",
" 1.10307414e-02 2.71247740e-03 5.42495479e-04 1.13924051e-02\n",
" 1.08499096e-03 9.58770344e-01]]\n"
]
}
],
"source": [
"# normalize=\"true\" divides every row by its total, so the diagonal is per-digit recall\n",
"cm = confusion_matrix(y_train, y_train_pred, normalize=\"true\")\n",
"print(cm)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "a0891a45",
"metadata": {},
"outputs": [],
"source": [
"# table and heatmap libraries for the confusion-matrix plot\n",
"import pandas as pd\n",
"import seaborn as sn"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c584cc3d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAGpCAYAAAB8smdHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAACNGUlEQVR4nOzde1xUdf4/8NdHEO8SUoAJYrjk3cy01RQvgzNchgFF8FJR39KoXxdUvIKrpZvYtpta7dZq5WZeai3vgjdAkBJT0hU1TLxwFXFVVERgmJn37w/kLCO3AebCxPv5eJxHzpxzPud1Pn3O4cPnc2YQRATGGGOMMWvTxtIBGGOMMcaagjsxjDHGGLNK3IlhjDHGmFXiTgxjjDHGrBJ3YhhjjDFmlWwtHaAuQgir+9gUf9KLMcZYMwmzHsyIP2uJyKzZAR6JYYwxxpiVarEjMYwxxhgzLSHMPnhiVNyJYYwxxlopa+/E8HQSY4wxxqwSj8QwxhhjrZS1j8RwJ4Yxxhhrpdq0se4JGetOzxhjjLFWi0diGGOMsVaKp5MYY4wxZpWsvRPD00mMMcYYs0o8EsMYY4y1UtY+EsOdGMYYY6yVsvZODE8nMcYYY8wq8UgMY4wx1krxSEwL4uPjg/PnzyMzMxMLFy6ssb5nz56Ij4/H6dOncfjwYfTo0UNa98EHH+DMmTM4c+YMpkyZ0uwsR44cgY+PD+RyOdatW1djvVqtxuzZsyGXyxEaGoq8vDxp3dq1ayGXy+Hj44OUlJQGy3z++ecRFBSEoKAgjB49Gm+++SYAID4+HiqVCkFBQQgODkZaWlqLybxo0SLIZDIpd0ZGBgDg559/xjPPPCO9//e//73ezMbQ0Hm3RJbKbIo2EhUVhZEjRyIgIECvrDVr1kjt99VXX0VhYWGLyFxQUICwsDD4+/tDqVRiw4YN0vbnz5/H1KlToVKp8MYbb+DevXsWyVhfmdHR0QgMDIRKpUJERARKSkqkdXFxcdJ5zZ0716DszWHOdtxa728NadOmjdEWiyCiFrkAoMYsbdq0oYsXL9ITTzxBbdu2pf/85z/Ur18/vW22bt1KL730EgGg8ePH0zfffEMAyN/fnw4ePEg2NjbUsWNHOn78OHXp0qVRx6+sykoajYa8vb0pJyeHysvLSaVSUWZmJlW3adMmWrJkCRER7d27l2bNmkVERJmZmaRSqai8vJxycnLI29ubNBqNQWUSEb399tu0Y8cOIiK6d+8e6XQ6IiLKyMggHx+fGttbKvPChQtp3759NXIcO3aMwsPD68xpbIbWa0tiqcymaCNERMePH6ezZ8+SUqnUK6u4uFj694YNG6RyLZ25sLCQzp49K2VUKBRSmcHBwfTzzz8TEdH3339Pq1evtkjG+sqsXq8xMTG0du1aIiK6cuUKBQUF0e3bt4mI6MaNGwbVcVOZsx1b2f3NrD9rO3fuTMZazJ2diEw3EiOE6CuEWCiE+OTBslAI0c9Ux3v22Wdx8eJFXLlyBRUVFfjuu+8QFBSkt03//v2RmJgIADh8+LC0vn///jhy5Ai0Wi3u37+P9PR0+Pr6NjlLeno63N3d4ebmBjs7OyiVSiQkJOhtk5iYiEmTJgGoHEFKTU0FESEhIQFKpRJ2dnZwc3ODu7s70tPTDSrz3r17OHbsGCZMmAAA6NSpkzRUWFpaWu+woaUyW5o1ZHyYpTKboo0AwPDhw2Fvb1/jeJ07d5b+3VD7NWdmJycnDBgwQMro4eEhjRJlZWVh+PDhAIBRo0bh4MGDFslYX5lV9UpEKCsrk46xdetWvPDCC9L/C0dHx8ZVdiOZsx231vubIYQQRlsswSSdGCHEQgDfARAAjj9YBIBvhRCLTHHMHj16IDc3V3qdl5enN10EAKdPn0ZwcDAAYNKkSejatSu6deuG06dPw9fXFx06dICjoyPGjx8PNze3JmcpLCyEi4uL9NrZ2bnGUHhhYSG6d+8OALC1tUWXLl1QVF
RU576GlBkfH4+RI0fq3fwPHToEX19fvP7664iJiWlRmVevXg2VSoWYmBio1Wrp/f/85z8IDAzEzJkzkZmZWWdmYzDkvFsaS2U2RRtpyOrVqzF27Fjs2bMHs2bNanGZ8/LykJGRgaeeegoA4OnpKf0g279/PwoKCiySsaEyo6KiMGrUKFy+fBlhYWEAKjtgV65cwbRp0zBlyhQcOXKkwezNYc523Frvb4bgTkztZgAYTkQfENGmB8sHAJ59sK5WQohwIUSaEKL+hzeaaN68eRg7dixOnjyJsWPHIi8vD1qtFocOHUJcXByOHj2Kb7/9FqmpqdBqtaaIYFJ79+6FUqnUe08ul2P//v34xz/+gY8//thCyWqKjIzE/v37sW3bNty5c0eaTx4wYAASExOxe/duhIWF4a233rJwUmZJc+bMQXJyMlQqFTZt2mTpOHpKSkoQERGB6Oho6ReHFStWYMuWLQgODkZJSQns7OwsnLJ2K1euREpKCnr37o24uDgAgFarRXZ2NjZu3IiPPvoIS5Yswd27dy2c1Drx/c18TNWJ0QF4vJb3uz9YVysiWkdEw4hoWGMPmJ+frzd64urqivz8fL1tCgoKMHnyZAwdOhSLFy8GANy5cwcAEBMTg6effhoKhQJCCFy4cKGxESTOzs64du2a9LqwsBDOzs41tqn6LU2j0aC4uBgODg517ttQmbdu3cKZM2cwbty4WjMNHz4cubm5uHXrVovI7OTkBCEE7OzsEBwcjDNnzgCoHOru1KkTAGDs2LHQaDR1ZjYGQ867pbFUZlO0EUOpVCqDpmbMlbmiogIRERFQqVRQKBTSNr1798b69euxfft2KJVKg0Z0LXG/AAAbGxsolUqpXp2dnSGTydC2bVu4ubmhV69eyMrKajB/U5mzHbfW+5sheCSmdrMBJAgh9gkh1j1Y9gNIAND4MWEDnDhxAp6enujVqxfatm2LadOmYffu3XrbODo6ShUdFRWF9evXA6h8Ortbt24AgEGDBmHw4MFNumFWGTRoELKyspCbmwu1Wo3Y2FjIZDK9bWQyGXbs2AEAOHDgAEaMGAEhBGQyGWJjY6FWq5Gbm4usrCwMHjy4wTIPHDiAcePGoV27dtJ72dnZVQ9J49y5c1Cr1XBwcGgRma9fvw6gcl4+Pj4enp6eAID//ve/Uub09HTodLo6MxuDIefd0lgqsynaSH2q/wBNSEiAh4dHi8hMRFi8eDE8PDzwyiuv6JV18+ZNAIBOp8Pnn3+OadOmWSRjXWUSEbKzswFUXnuJiYlSvU6YMAHHjx8HUPlLUVZWVrOm1Y1x3uY81u/x/mYIa+/EmPLTRW0AjAAw+cEyAoBNI/Zv9KeD/Pz86LfffqOLFy9SdHQ0AaBly5aRSqUiADR58mS6cOEC/fbbb/TFF1+QnZ0dAaB27drRuXPn6Ny5c5SamkpPPfVUo4+Nap9OIiJKSkoihUJB3t7e9NlnnxER0Zo1ayg+Pp6IiMrKyuidd96hCRMm0OTJkyknJ0fa97PPPiNvb29SKBSUlJRUb5lVXnzxRUpOTtZ7b+3ateTv70+BgYE0ZcoUOnHiBNXHnJnDwsIoICCAlEolzZ07l+7du0dERBs3biR/f39SqVQUGhpKv/zyS72ZjaG+em2pLJXZFG1kzpw5NGrUKOrfvz95eXnR1q1biajyk3ZKpZICAgLo9ddfp2vXrrWIzCdOnKAnn3ySAgICKDAwkAIDA6V1X3/9NSkUClIoFPTXv/5V+nSgJeq1tjK1Wi1NnTpVuvYiIyOlTyvpdDqKiYkhPz8/CggIoL179za+shvJnO3Yiu5vZv10zyOPPELGWsydnYgg6EGvsKURQrTMYPVoqXXJGGPMaph1SKNbt25G+8F169Ytsw/H8Df2MsYYY62Uxb6kzkisOz1jjDHGWi0eiWGMMcZaKWv/20nciWGMMcZaKWvvxPB0EmOMMcasEo/EMMYYY62UtY/EcCeGMcYYa6W4E8MYY4wxq2
TtnRh+JoYxxhhjVolHYhhjjLFWytq/7I47MYwxxlgrxdNJjDHGGGMWwCMxjDHGWCtl7SMx3IlhjDHGWinuxJgIkdH
"text/plain": [
"<Figure size 720x504 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# a set is unordered (and newer pandas versions reject it as index/columns),\n",
"# so use an ordered list matching the sorted 0-9 row/column order of cm\n",
"digit_labels = list(range(10))\n",
"\n",
"df_cm = pd.DataFrame(cm, index=digit_labels, columns=digit_labels)\n",
"fig, ax = plt.subplots(figsize=(10, 7))\n",
"sn.heatmap(df_cm, annot=True, cmap=\"Greys\", ax=ax)\n",
"ax.set_ylabel(\"True Label\")\n",
"ax.set_xlabel(\"Predicted Label\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "dc855642",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: knn (97.0375%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.98 0.99 0.99 5499\n",
" 1 0.96 0.99 0.97 6287\n",
" 2 0.98 0.96 0.97 5595\n",
" 3 0.96 0.96 0.96 5679\n",
" 4 0.98 0.97 0.97 5450\n",
" 5 0.96 0.96 0.96 5068\n",
" 6 0.98 0.99 0.98 5542\n",
" 7 0.97 0.97 0.97 5846\n",
" 8 0.99 0.93 0.96 5504\n",
" 9 0.95 0.96 0.96 5530\n",
"\n",
" accuracy 0.97 56000\n",
" macro avg 0.97 0.97 0.97 56000\n",
"weighted avg 0.97 0.97 0.97 56000\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"# overall out-of-fold accuracy in percent, followed by per-digit precision/recall/F1\n",
"accuracy = 100 * accuracy_score(y_train, y_train_pred, normalize=True)\n",
"print(f\"Pipeline: knn ({accuracy:.4f}%)\")\n",
"print(classification_report(y_train, y_train_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b27fdf05",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}