iui-group-l-name-zensiert/0-pilot-project/MNIST-kNN-best-pipeline.ipynb

523 lines
14 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "e4d89124",
"metadata": {},
"source": [
"### Load MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5789ec72",
"metadata": {},
"outputs": [],
"source": [
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f491d383",
"metadata": {},
"outputs": [],
"source": [
"# scikit-learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "575a6a42",
"metadata": {},
"outputs": [],
"source": [
"# common imports\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "921dc114",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sklearn.utils.Bunch"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# import function to scikit-learn datasets\n",
"from sklearn.datasets import fetch_openml\n",
"\n",
"# load specified dataset (MNIST)\n",
"mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
"\n",
"# print type of dataset\n",
"type(mnist)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6271045c",
"metadata": {},
"outputs": [],
"source": [
"X, y = mnist[\"data\"], mnist[\"target\"]"
]
},
{
"cell_type": "markdown",
"id": "37777133",
"metadata": {},
"source": [
"### Fix labels"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "30a441d3",
"metadata": {},
"outputs": [],
"source": [
"# import plotting libraries\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2d9693b1",
"metadata": {},
"outputs": [],
"source": [
"# convert string labels to int\n",
"y = y.astype(np.uint8)"
]
},
{
"cell_type": "markdown",
"id": "182f4b1b",
"metadata": {},
"source": [
"### Prepare data for machine learning"
]
},
{
"cell_type": "markdown",
"id": "77ff6bd1",
"metadata": {},
"source": [
"### Identify Train Set and Test Set"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f8247d13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train: 56000, (56000, 784)\n",
"X_test: 14000, (14000, 784)\n",
"y_train: 56000, (56000,)\n",
"y_test: 14000, (14000,)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1337)\n",
"\n",
"print(f\"X_train: {len(X_train)}, {X_train.shape}\")\n",
"print(f\"X_test: {len(X_test)}, {X_test.shape}\")\n",
"print(f\"y_train: {len(y_train)}, {y_train.shape}\")\n",
"print(f\"y_test: {len(y_test)}, {y_test.shape}\")"
]
},
{
"cell_type": "markdown",
"id": "c4062436",
"metadata": {},
"source": [
"## Pipeline Declaration"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "67645a47",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.model_selection import cross_val_predict\n",
"from sklearn.metrics import classification_report, precision_score\n",
"\n",
"n_neighbors = 3\n",
"n95_components = 0.95\n",
"n99_components = 0.99"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "82b1e834",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"names = ['knn', \n",
" 'scalar+knn', \n",
" 'standard+pca95+knn', \n",
" 'minmax+pca95+knn', \n",
" 'standard+pca99+knn', \n",
" 'minmax+pca99+knn'\n",
" ]\n",
"\n",
"classifiers = [\n",
" Pipeline([('knn', KNeighborsClassifier(n_neighbors=n_neighbors))]),\n",
" Pipeline([\n",
" ('scaler', StandardScaler()),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('standard', StandardScaler()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('minmax', MinMaxScaler()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('standard', StandardScaler()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('minmax', MinMaxScaler()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ])\n",
"]\n",
"\n",
"len(names)"
]
},
{
"cell_type": "markdown",
"id": "156dbf2c",
"metadata": {},
"source": [
"# Crossvalidation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8e5168e4",
"metadata": {},
"outputs": [],
"source": [
"# for name, clf in zip(names, classifiers):\n",
"# y_train_pred = cross_val_predict(clf, X_train, y_train, cv=3)\n",
"# print(f\"Pipeline: {name} ({precision_score(y_train, y_train_pred, average='weighted'):.4f})\")\n",
"# print(classification_report(y_train, y_train_pred))\n",
"precs = []"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "db14a027",
"metadata": {},
"outputs": [],
"source": [
"def cv(num):\n",
" name = names[num]\n",
" clf = classifiers[num]\n",
" y_train_pred = cross_val_predict(clf, X_train, y_train, cv=3)\n",
" precision = precision_score(y_train, y_train_pred, average='weighted')\n",
" print(f\"Pipeline: {name} ({precision:.4f})\")\n",
" print(classification_report(y_train, y_train_pred))\n",
" return precision"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b8983bf8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: knn (0.9691)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.98 0.99 0.99 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.98 0.96 0.97 5595\n",
" 3 0.96 0.97 0.96 5679\n",
" 4 0.98 0.96 0.97 5450\n",
" 5 0.96 0.96 0.96 5068\n",
" 6 0.98 0.99 0.98 5542\n",
" 7 0.96 0.97 0.97 5846\n",
" 8 0.99 0.93 0.96 5504\n",
" 9 0.95 0.96 0.96 5530\n",
"\n",
" accuracy 0.97 56000\n",
" macro avg 0.97 0.97 0.97 56000\n",
"weighted avg 0.97 0.97 0.97 56000\n",
"\n"
]
}
],
"source": [
"cv(0)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "62ff42f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: scalar+knn (0.9420)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.99 0.97 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.95 0.93 0.94 5595\n",
" 3 0.92 0.94 0.93 5679\n",
" 4 0.94 0.93 0.94 5450\n",
" 5 0.93 0.92 0.93 5068\n",
" 6 0.96 0.97 0.97 5542\n",
" 7 0.94 0.93 0.94 5846\n",
" 8 0.97 0.89 0.93 5504\n",
" 9 0.91 0.92 0.91 5530\n",
"\n",
" accuracy 0.94 56000\n",
" macro avg 0.94 0.94 0.94 56000\n",
"weighted avg 0.94 0.94 0.94 56000\n",
"\n"
]
}
],
"source": [
"cv(1)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "9b553141",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: standard+pca95+knn (0.9457)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.96 0.99 0.97 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.95 0.94 0.95 5595\n",
" 3 0.93 0.94 0.94 5679\n",
" 4 0.95 0.93 0.94 5450\n",
" 5 0.93 0.92 0.93 5068\n",
" 6 0.96 0.97 0.97 5542\n",
" 7 0.95 0.94 0.94 5846\n",
" 8 0.97 0.89 0.93 5504\n",
" 9 0.92 0.92 0.92 5530\n",
"\n",
" accuracy 0.95 56000\n",
" macro avg 0.95 0.94 0.94 56000\n",
"weighted avg 0.95 0.95 0.95 56000\n",
"\n"
]
}
],
"source": [
"cv(2)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "2a1d6c65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: minmax+pca95+knn (0.9706)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.98 0.99 0.99 5499\n",
" 1 0.96 0.99 0.98 6287\n",
" 2 0.98 0.97 0.97 5595\n",
" 3 0.96 0.96 0.96 5679\n",
" 4 0.98 0.97 0.97 5450\n",
" 5 0.96 0.96 0.96 5068\n",
" 6 0.98 0.99 0.98 5542\n",
" 7 0.97 0.98 0.97 5846\n",
" 8 0.99 0.93 0.96 5504\n",
" 9 0.95 0.96 0.96 5530\n",
"\n",
" accuracy 0.97 56000\n",
" macro avg 0.97 0.97 0.97 56000\n",
"weighted avg 0.97 0.97 0.97 56000\n",
"\n"
]
}
],
"source": [
"cv(3)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "de40f817",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: standard+pca99+knn (0.9424)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.99 0.97 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.95 0.93 0.94 5595\n",
" 3 0.92 0.94 0.93 5679\n",
" 4 0.94 0.93 0.94 5450\n",
" 5 0.93 0.92 0.93 5068\n",
" 6 0.96 0.97 0.97 5542\n",
" 7 0.94 0.93 0.94 5846\n",
" 8 0.97 0.89 0.93 5504\n",
" 9 0.91 0.92 0.92 5530\n",
"\n",
" accuracy 0.94 56000\n",
" macro avg 0.94 0.94 0.94 56000\n",
"weighted avg 0.94 0.94 0.94 56000\n",
"\n"
]
}
],
"source": [
"cv(4)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2c1c26b0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: minmax+pca99+knn (0.9695)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.98 0.99 0.99 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.98 0.96 0.97 5595\n",
" 3 0.96 0.97 0.96 5679\n",
" 4 0.98 0.96 0.97 5450\n",
" 5 0.96 0.96 0.96 5068\n",
" 6 0.98 0.99 0.98 5542\n",
" 7 0.96 0.97 0.97 5846\n",
" 8 0.99 0.93 0.96 5504\n",
" 9 0.95 0.96 0.96 5530\n",
"\n",
" accuracy 0.97 56000\n",
" macro avg 0.97 0.97 0.97 56000\n",
"weighted avg 0.97 0.97 0.97 56000\n",
"\n"
]
}
],
"source": [
"cv(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb6b6d69",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}