Fixed crossvalidatoin/classification in pipeline and prepared for others to work on

master
Tuan-Dat Tran 2021-05-17 16:31:29 +00:00
parent 87d51ce0f4
commit bc73108d74
5 changed files with 2006 additions and 1247 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,522 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e4d89124",
"metadata": {},
"source": [
"### Load MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5789ec72",
"metadata": {},
"outputs": [],
"source": [
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f491d383",
"metadata": {},
"outputs": [],
"source": [
"# scikit-learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "575a6a42",
"metadata": {},
"outputs": [],
"source": [
"# common imports\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "921dc114",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sklearn.utils.Bunch"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# import function to scikit-learn datasets\n",
"from sklearn.datasets import fetch_openml\n",
"\n",
"# load specified dataset (MNIST)\n",
"mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
"\n",
"# print type of dataset\n",
"type(mnist)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6271045c",
"metadata": {},
"outputs": [],
"source": [
"X, y = mnist[\"data\"], mnist[\"target\"]"
]
},
{
"cell_type": "markdown",
"id": "37777133",
"metadata": {},
"source": [
"### Fix labels"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "30a441d3",
"metadata": {},
"outputs": [],
"source": [
"# import plotting libraries\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2d9693b1",
"metadata": {},
"outputs": [],
"source": [
"# convert string labels to int\n",
"y = y.astype(np.uint8)"
]
},
{
"cell_type": "markdown",
"id": "182f4b1b",
"metadata": {},
"source": [
"### Prepare data for machine learning"
]
},
{
"cell_type": "markdown",
"id": "77ff6bd1",
"metadata": {},
"source": [
"### Identify Train Set and Test Set"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f8247d13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train: 56000, (56000, 784)\n",
"X_test: 14000, (14000, 784)\n",
"y_train: 56000, (56000,)\n",
"y_test: 14000, (14000,)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1337)\n",
"\n",
"print(f\"X_train: {len(X_train)}, {X_train.shape}\")\n",
"print(f\"X_test: {len(X_test)}, {X_test.shape}\")\n",
"print(f\"y_train: {len(y_train)}, {y_train.shape}\")\n",
"print(f\"y_test: {len(y_test)}, {y_test.shape}\")"
]
},
{
"cell_type": "markdown",
"id": "c4062436",
"metadata": {},
"source": [
"## Pipeline Declaration"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "67645a47",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.model_selection import cross_val_predict\n",
"from sklearn.metrics import classification_report, precision_score\n",
"\n",
"n_neighbors = 3\n",
"n95_components = 0.95\n",
"n99_components = 0.99"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "82b1e834",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"names = ['knn', \n",
" 'scalar+knn', \n",
" 'standard+pca95+knn', \n",
" 'minmax+pca95+knn', \n",
" 'standard+pca99+knn', \n",
" 'minmax+pca99+knn'\n",
" ]\n",
"\n",
"classifiers = [\n",
" Pipeline([('knn', KNeighborsClassifier(n_neighbors=n_neighbors))]),\n",
" Pipeline([\n",
" ('scaler', StandardScaler()),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('standard', StandardScaler()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('minmax', MinMaxScaler()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('standard', StandardScaler()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ]),\n",
" Pipeline([\n",
" ('minmax', MinMaxScaler()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))\n",
" ])\n",
"]\n",
"\n",
"len(names)"
]
},
{
"cell_type": "markdown",
"id": "156dbf2c",
"metadata": {},
"source": [
"# Crossvalidation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8e5168e4",
"metadata": {},
"outputs": [],
"source": [
"# for name, clf in zip(names, classifiers):\n",
"# y_train_pred = cross_val_predict(clf, X_train, y_train, cv=3)\n",
"# print(f\"Pipeline: {name} ({precision_score(y_train, y_train_pred, average='weighted'):.4f})\")\n",
"# print(classification_report(y_train, y_train_pred))\n",
"precs = []"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "db14a027",
"metadata": {},
"outputs": [],
"source": [
"def cv(num):\n",
" name = names[num]\n",
" clf = classifiers[num]\n",
" y_train_pred = cross_val_predict(clf, X_train, y_train, cv=3)\n",
" precision = precision_score(y_train, y_train_pred, average='weighted')\n",
" print(f\"Pipeline: {name} ({precision:.4f})\")\n",
" print(classification_report(y_train, y_train_pred))\n",
" return precision"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b8983bf8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: knn (0.9691)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.98 0.99 0.99 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.98 0.96 0.97 5595\n",
" 3 0.96 0.97 0.96 5679\n",
" 4 0.98 0.96 0.97 5450\n",
" 5 0.96 0.96 0.96 5068\n",
" 6 0.98 0.99 0.98 5542\n",
" 7 0.96 0.97 0.97 5846\n",
" 8 0.99 0.93 0.96 5504\n",
" 9 0.95 0.96 0.96 5530\n",
"\n",
" accuracy 0.97 56000\n",
" macro avg 0.97 0.97 0.97 56000\n",
"weighted avg 0.97 0.97 0.97 56000\n",
"\n"
]
}
],
"source": [
"cv(0)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "62ff42f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: scalar+knn (0.9420)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.99 0.97 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.95 0.93 0.94 5595\n",
" 3 0.92 0.94 0.93 5679\n",
" 4 0.94 0.93 0.94 5450\n",
" 5 0.93 0.92 0.93 5068\n",
" 6 0.96 0.97 0.97 5542\n",
" 7 0.94 0.93 0.94 5846\n",
" 8 0.97 0.89 0.93 5504\n",
" 9 0.91 0.92 0.91 5530\n",
"\n",
" accuracy 0.94 56000\n",
" macro avg 0.94 0.94 0.94 56000\n",
"weighted avg 0.94 0.94 0.94 56000\n",
"\n"
]
}
],
"source": [
"cv(1)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "9b553141",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: standard+pca95+knn (0.9457)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.96 0.99 0.97 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.95 0.94 0.95 5595\n",
" 3 0.93 0.94 0.94 5679\n",
" 4 0.95 0.93 0.94 5450\n",
" 5 0.93 0.92 0.93 5068\n",
" 6 0.96 0.97 0.97 5542\n",
" 7 0.95 0.94 0.94 5846\n",
" 8 0.97 0.89 0.93 5504\n",
" 9 0.92 0.92 0.92 5530\n",
"\n",
" accuracy 0.95 56000\n",
" macro avg 0.95 0.94 0.94 56000\n",
"weighted avg 0.95 0.95 0.95 56000\n",
"\n"
]
}
],
"source": [
"cv(2)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "2a1d6c65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: minmax+pca95+knn (0.9706)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.98 0.99 0.99 5499\n",
" 1 0.96 0.99 0.98 6287\n",
" 2 0.98 0.97 0.97 5595\n",
" 3 0.96 0.96 0.96 5679\n",
" 4 0.98 0.97 0.97 5450\n",
" 5 0.96 0.96 0.96 5068\n",
" 6 0.98 0.99 0.98 5542\n",
" 7 0.97 0.98 0.97 5846\n",
" 8 0.99 0.93 0.96 5504\n",
" 9 0.95 0.96 0.96 5530\n",
"\n",
" accuracy 0.97 56000\n",
" macro avg 0.97 0.97 0.97 56000\n",
"weighted avg 0.97 0.97 0.97 56000\n",
"\n"
]
}
],
"source": [
"cv(3)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "de40f817",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: standard+pca99+knn (0.9424)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.99 0.97 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.95 0.93 0.94 5595\n",
" 3 0.92 0.94 0.93 5679\n",
" 4 0.94 0.93 0.94 5450\n",
" 5 0.93 0.92 0.93 5068\n",
" 6 0.96 0.97 0.97 5542\n",
" 7 0.94 0.93 0.94 5846\n",
" 8 0.97 0.89 0.93 5504\n",
" 9 0.91 0.92 0.92 5530\n",
"\n",
" accuracy 0.94 56000\n",
" macro avg 0.94 0.94 0.94 56000\n",
"weighted avg 0.94 0.94 0.94 56000\n",
"\n"
]
}
],
"source": [
"cv(4)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2c1c26b0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: minmax+pca99+knn (0.9695)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.98 0.99 0.99 5499\n",
" 1 0.95 0.99 0.97 6287\n",
" 2 0.98 0.96 0.97 5595\n",
" 3 0.96 0.97 0.96 5679\n",
" 4 0.98 0.96 0.97 5450\n",
" 5 0.96 0.96 0.96 5068\n",
" 6 0.98 0.99 0.98 5542\n",
" 7 0.96 0.97 0.97 5846\n",
" 8 0.99 0.93 0.96 5504\n",
" 9 0.95 0.96 0.96 5530\n",
"\n",
" accuracy 0.97 56000\n",
" macro avg 0.97 0.97 0.97 56000\n",
"weighted avg 0.97 0.97 0.97 56000\n",
"\n"
]
}
],
"source": [
"cv(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb6b6d69",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,837 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "03d83636",
"metadata": {},
"source": [
"### Load MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "242d14f0",
"metadata": {},
"outputs": [],
"source": [
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cfd3a54a",
"metadata": {},
"outputs": [],
"source": [
"# scikit-learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7e1587b3",
"metadata": {},
"outputs": [],
"source": [
"# common imports\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bbccfc32",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sklearn.utils.Bunch"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# import function to scikit-learn datasets\n",
"from sklearn.datasets import fetch_openml\n",
"\n",
"# load specified dataset (MNIST)\n",
"mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
"\n",
"# print type of dataset\n",
"type(mnist)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6db7c96a",
"metadata": {},
"outputs": [],
"source": [
"X, y = mnist[\"data\"], mnist[\"target\"]"
]
},
{
"cell_type": "markdown",
"id": "459780d0",
"metadata": {},
"source": [
"### Fix labels"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "48c4e861",
"metadata": {},
"outputs": [],
"source": [
"# import plotting libraries\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "72876ab1",
"metadata": {},
"outputs": [],
"source": [
"# convert string labels to int\n",
"y = y.astype(np.uint8)"
]
},
{
"cell_type": "markdown",
"id": "c9dacae4",
"metadata": {},
"source": [
"### Prepare data for machine learning"
]
},
{
"cell_type": "markdown",
"id": "b44b3f87",
"metadata": {},
"source": [
"### Identify Train Set and Test Set"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "51c5da44",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train: 56000, (56000, 784)\n",
"X_test: 14000, (14000, 784)\n",
"y_train: 56000, (56000,)\n",
"y_test: 14000, (14000,)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1337)\n",
"\n",
"print(f\"X_train: {len(X_train)}, {X_train.shape}\")\n",
"print(f\"X_test: {len(X_test)}, {X_test.shape}\")\n",
"print(f\"y_train: {len(y_train)}, {y_train.shape}\")\n",
"print(f\"y_test: {len(y_test)}, {y_test.shape}\")"
]
},
{
"cell_type": "markdown",
"id": "673e237d",
"metadata": {},
"source": [
"## Pipeline Declaration"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8ca34ce2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(11, 11)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import (StandardScaler, \n",
" MinMaxScaler, \n",
" MaxAbsScaler, \n",
" PowerTransformer)\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.model_selection import cross_validate\n",
"from sklearn.metrics import classification_report, accuracy_score\n",
"\n",
"n_neighbors = 3\n",
"n95_components = 0.95\n",
"n99_components = 0.99\n",
"\n",
"names = ['knn (baseline)', \n",
" 'scalar+knn', \n",
" 'minmax+knn', \n",
" 'standard+pca95+knn', \n",
" 'minmax+pca95+knn', # Best so far w/ 97.0429%\n",
" 'standard+pca99+knn', \n",
" 'minmax+pca99+knn',\n",
" 'maxabs+pca95+knn', \n",
" 'maxabs+pca99+knn', # Best so far w/ 97.0429%\n",
" 'power+pca95+knn',\n",
" 'power+pca99+knn',\n",
" ]\n",
"\n",
"classifiers = [\n",
" Pipeline([('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))]),\n",
" Pipeline([\n",
" ('standard', StandardScaler()),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('minmax', MinMaxScaler()),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('standard', StandardScaler()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('minmax', MinMaxScaler()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('standard', StandardScaler()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('minmax', MinMaxScaler()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('maxabs', MaxAbsScaler()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('maxabs', MaxAbsScaler()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('power', PowerTransformer()),\n",
" ('pca', PCA(n_components=n99_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
" Pipeline([\n",
" ('power', PowerTransformer()),\n",
" ('pca', PCA(n_components=n95_components)),\n",
" ('knn', KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=-1))\n",
" ]),\n",
"]\n",
"\n",
"len(names), len(classifiers)"
]
},
{
"cell_type": "markdown",
"id": "f38b2bb2",
"metadata": {},
"source": [
"# Crossvalidation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3465f546",
"metadata": {},
"outputs": [],
"source": [
"accuracies = []"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1f087f92",
"metadata": {},
"outputs": [],
"source": [
"def cv_train(num,cv):\n",
" name = names[num]\n",
" clf = classifiers[num]\n",
" y_train_pred = cross_val_predict(clf, X_train, y_train, cv=cv, n_jobs=-1)\n",
" accuracy = accuracy_score(y_train, y_train_pred, normalize=True)*100\n",
" print(f\"Pipeline: {name} ({accuracy:.4f}%)\")\n",
" print(classification_report(y_train, y_train_pred))\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9cb9b3e7",
"metadata": {},
"outputs": [],
"source": [
"def cv_test(num):\n",
" name = names[num]\n",
" clf = classifiers[num]\n",
" y_test_pred = cross_val_predict(clf, X_test, y_test, cv=5, n_jobs=-1)\n",
" accuracy = accuracy_score(y_test, y_test_pred, normalize=True)*100\n",
" print(f\"Pipeline: {name} ({accuracy:.4f}%)\")\n",
" print(classification_report(y_test, y_test_pred))\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "462e340f",
"metadata": {},
"outputs": [],
"source": [
"def cv(num,cv_arg=10):\n",
" name = names[num]\n",
" clf = classifiers[num]\n",
" clf = clf.fit(X_train, y_train)\n",
" cv = cross_validate(clf, X_train, y_train, cv=cv_arg, n_jobs=-1, return_estimator=True) \n",
" cv_clf = cv['estimator'][np.argmax(cv['test_score'])]\n",
" y_test_pred = cv_clf.predict(X_test)\n",
" accuracy = accuracy_score(y_test, y_test_pred, normalize=True)*100\n",
" print(f\"Pipeline: {name} ({accuracy:.4f}%)\")\n",
" print(classification_report(y_test, y_test_pred))\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "cbca3b1f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: knn (baseline) (96.9857%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.96 0.99 0.98 1404\n",
" 1 0.97 0.99 0.98 1590\n",
" 2 0.98 0.97 0.97 1395\n",
" 3 0.98 0.96 0.97 1462\n",
" 4 0.98 0.97 0.97 1374\n",
" 5 0.96 0.96 0.96 1245\n",
" 6 0.98 0.98 0.98 1334\n",
" 7 0.97 0.97 0.97 1447\n",
" 8 0.99 0.92 0.95 1321\n",
" 9 0.95 0.96 0.96 1428\n",
"\n",
" accuracy 0.97 14000\n",
" macro avg 0.97 0.97 0.97 14000\n",
"weighted avg 0.97 0.97 0.97 14000\n",
"\n"
]
}
],
"source": [
"from sklearn.model_selection import LeaveOneOut\n",
"# accuracies.append(cv(0,5))\n",
"accuracies.append(cv(0,10))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ad92d1f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: scalar+knn (94.2143%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.98 0.96 1404\n",
" 1 0.96 0.99 0.97 1590\n",
" 2 0.93 0.95 0.94 1395\n",
" 3 0.94 0.94 0.94 1462\n",
" 4 0.95 0.93 0.94 1374\n",
" 5 0.93 0.92 0.93 1245\n",
" 6 0.96 0.97 0.96 1334\n",
" 7 0.93 0.94 0.94 1447\n",
" 8 0.97 0.88 0.92 1321\n",
" 9 0.91 0.91 0.91 1428\n",
"\n",
" accuracy 0.94 14000\n",
" macro avg 0.94 0.94 0.94 14000\n",
"weighted avg 0.94 0.94 0.94 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(1,5))\n",
"accuracies.append(cv(1,10))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "881d8a07",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: minmax+knn (96.9857%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.96 0.99 0.98 1404\n",
" 1 0.97 0.99 0.98 1590\n",
" 2 0.98 0.97 0.97 1395\n",
" 3 0.98 0.96 0.97 1462\n",
" 4 0.98 0.97 0.97 1374\n",
" 5 0.96 0.96 0.96 1245\n",
" 6 0.98 0.98 0.98 1334\n",
" 7 0.97 0.97 0.97 1447\n",
" 8 0.99 0.92 0.95 1321\n",
" 9 0.95 0.96 0.96 1428\n",
"\n",
" accuracy 0.97 14000\n",
" macro avg 0.97 0.97 0.97 14000\n",
"weighted avg 0.97 0.97 0.97 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(2,5))\n",
"accuracies.append(cv(2,10))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1402e10b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: standard+pca95+knn (94.7000%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.98 0.97 1404\n",
" 1 0.96 0.99 0.98 1590\n",
" 2 0.94 0.95 0.94 1395\n",
" 3 0.94 0.94 0.94 1462\n",
" 4 0.96 0.94 0.95 1374\n",
" 5 0.93 0.93 0.93 1245\n",
" 6 0.96 0.97 0.97 1334\n",
" 7 0.94 0.95 0.94 1447\n",
" 8 0.97 0.89 0.93 1321\n",
" 9 0.92 0.92 0.92 1428\n",
"\n",
" accuracy 0.95 14000\n",
" macro avg 0.95 0.95 0.95 14000\n",
"weighted avg 0.95 0.95 0.95 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(3,5))\n",
"accuracies.append(cv(3,10))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "24035514",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: minmax+pca95+knn (97.2571%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.97 0.99 0.98 1404\n",
" 1 0.98 0.99 0.99 1590\n",
" 2 0.98 0.97 0.97 1395\n",
" 3 0.98 0.97 0.97 1462\n",
" 4 0.98 0.97 0.97 1374\n",
" 5 0.97 0.96 0.97 1245\n",
" 6 0.98 0.99 0.98 1334\n",
" 7 0.97 0.98 0.97 1447\n",
" 8 0.99 0.93 0.96 1321\n",
" 9 0.95 0.97 0.96 1428\n",
"\n",
" accuracy 0.97 14000\n",
" macro avg 0.97 0.97 0.97 14000\n",
"weighted avg 0.97 0.97 0.97 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(4,5))\n",
"accuracies.append(cv(4,10))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "1c27528e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: standard+pca99+knn (94.2929%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.98 0.96 1404\n",
" 1 0.96 0.99 0.97 1590\n",
" 2 0.94 0.95 0.94 1395\n",
" 3 0.94 0.94 0.94 1462\n",
" 4 0.96 0.93 0.94 1374\n",
" 5 0.93 0.92 0.92 1245\n",
" 6 0.96 0.97 0.96 1334\n",
" 7 0.93 0.94 0.94 1447\n",
" 8 0.97 0.88 0.92 1321\n",
" 9 0.91 0.91 0.91 1428\n",
"\n",
" accuracy 0.94 14000\n",
" macro avg 0.94 0.94 0.94 14000\n",
"weighted avg 0.94 0.94 0.94 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(5,5))\n",
"accuracies.append(cv(5,10))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "46bcb35f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: minmax+pca99+knn (96.9929%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.96 0.99 0.98 1404\n",
" 1 0.97 0.99 0.98 1590\n",
" 2 0.97 0.97 0.97 1395\n",
" 3 0.97 0.96 0.97 1462\n",
" 4 0.98 0.97 0.97 1374\n",
" 5 0.97 0.96 0.96 1245\n",
" 6 0.98 0.98 0.98 1334\n",
" 7 0.97 0.97 0.97 1447\n",
" 8 0.99 0.93 0.96 1321\n",
" 9 0.95 0.96 0.96 1428\n",
"\n",
" accuracy 0.97 14000\n",
" macro avg 0.97 0.97 0.97 14000\n",
"weighted avg 0.97 0.97 0.97 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(6,5))\n",
"accuracies.append(cv(6,10))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "45d8092d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: maxabs+pca95+knn (96.9929%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.96 0.99 0.98 1404\n",
" 1 0.97 0.99 0.98 1590\n",
" 2 0.97 0.97 0.97 1395\n",
" 3 0.97 0.96 0.97 1462\n",
" 4 0.98 0.97 0.97 1374\n",
" 5 0.97 0.96 0.96 1245\n",
" 6 0.98 0.98 0.98 1334\n",
" 7 0.97 0.97 0.97 1447\n",
" 8 0.99 0.93 0.96 1321\n",
" 9 0.95 0.96 0.96 1428\n",
"\n",
" accuracy 0.97 14000\n",
" macro avg 0.97 0.97 0.97 14000\n",
"weighted avg 0.97 0.97 0.97 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(7,5))\n",
"accuracies.append(cv(7,10))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a805b3fd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: maxabs+pca99+knn (97.2571%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.97 0.99 0.98 1404\n",
" 1 0.98 0.99 0.99 1590\n",
" 2 0.98 0.97 0.97 1395\n",
" 3 0.98 0.97 0.97 1462\n",
" 4 0.98 0.97 0.97 1374\n",
" 5 0.97 0.96 0.97 1245\n",
" 6 0.98 0.99 0.98 1334\n",
" 7 0.97 0.98 0.97 1447\n",
" 8 0.99 0.93 0.96 1321\n",
" 9 0.95 0.97 0.96 1428\n",
"\n",
" accuracy 0.97 14000\n",
" macro avg 0.97 0.97 0.97 14000\n",
"weighted avg 0.97 0.97 0.97 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(8,5))\n",
"accuracies.append(cv(8,10))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "3af8abf8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/jupyterhub/lib/python3.8/site-packages/sklearn/preprocessing/_data.py:3237: RuntimeWarning: divide by zero encountered in log\n",
" loglike = -n_samples / 2 * np.log(x_trans.var())\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: power+pca95+knn (94.1071%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.98 0.97 1404\n",
" 1 0.96 0.99 0.97 1590\n",
" 2 0.94 0.95 0.94 1395\n",
" 3 0.94 0.93 0.94 1462\n",
" 4 0.95 0.93 0.94 1374\n",
" 5 0.93 0.91 0.92 1245\n",
" 6 0.95 0.97 0.96 1334\n",
" 7 0.94 0.94 0.94 1447\n",
" 8 0.95 0.89 0.92 1321\n",
" 9 0.90 0.92 0.91 1428\n",
"\n",
" accuracy 0.94 14000\n",
" macro avg 0.94 0.94 0.94 14000\n",
"weighted avg 0.94 0.94 0.94 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(9,5))\n",
"accuracies.append(cv(9,10)) # likes to die"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "d971b4df",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/jupyterhub/lib/python3.8/site-packages/sklearn/preprocessing/_data.py:3237: RuntimeWarning: divide by zero encountered in log\n",
" loglike = -n_samples / 2 * np.log(x_trans.var())\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline: power+pca99+knn (94.4429%)\n",
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.98 0.97 1404\n",
" 1 0.96 0.99 0.98 1590\n",
" 2 0.94 0.95 0.94 1395\n",
" 3 0.94 0.93 0.94 1462\n",
" 4 0.95 0.93 0.94 1374\n",
" 5 0.94 0.91 0.93 1245\n",
" 6 0.95 0.97 0.96 1334\n",
" 7 0.94 0.94 0.94 1447\n",
" 8 0.95 0.90 0.93 1321\n",
" 9 0.91 0.92 0.92 1428\n",
"\n",
" accuracy 0.94 14000\n",
" macro avg 0.94 0.94 0.94 14000\n",
"weighted avg 0.94 0.94 0.94 14000\n",
"\n"
]
}
],
"source": [
"# accuracies.append(cv(10,5))\n",
"accuracies.append(cv(10,10)) # likes to die"
]
},
{
"cell_type": "markdown",
"id": "281e0f59",
"metadata": {},
"source": [
"# Auswertung"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "e3eeabc7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum accuracy (97.2571%) at: ['minmax+pca95+knn', 'maxabs+pca99+knn']\n"
]
}
],
"source": [
"print(f\"Maximum accuracy ({max(accuracies):.6}%) at: {[names[n] for n in np.where(accuracies==max(accuracies))[0]]}\")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "7754b1e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"name: knn (baseline) (96.9857%)\n",
"name: scalar+knn (94.2143%)\n",
"name: minmax+knn (96.9857%)\n",
"name: standard+pca95+knn (94.7000%)\n",
"name: minmax+pca95+knn (97.2571%)\n",
"name: standard+pca99+knn (94.2929%)\n",
"name: minmax+pca99+knn (96.9929%)\n",
"name: maxabs+pca95+knn (96.9929%)\n",
"name: maxabs+pca99+knn (97.2571%)\n",
"name: power+pca95+knn (94.1071%)\n",
"name: power+pca99+knn (94.4429%)\n"
]
}
],
"source": [
"for n, a in zip(names, accuracies):\n",
" print(f\"name: {n:20} ({a:.4f}%)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22316563",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,410 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "804dacb6",
"metadata": {},
"source": [
"### Load MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7d09885b",
"metadata": {},
"outputs": [],
"source": [
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "bf4121a0",
"metadata": {},
"outputs": [],
"source": [
"# scikit-learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "71d91fd8",
"metadata": {},
"outputs": [],
"source": [
"# common imports\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1dc68441",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sklearn.utils.Bunch"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# import function to scikit-learn datasets\n",
"from sklearn.datasets import fetch_openml\n",
"\n",
"# load specified dataset (MNIST)\n",
"mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
"\n",
"# print type of dataset\n",
"type(mnist)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2c7a4966",
"metadata": {},
"outputs": [],
"source": [
"X, y = mnist[\"data\"], mnist[\"target\"]"
]
},
{
"cell_type": "markdown",
"id": "e2684670",
"metadata": {},
"source": [
"### Fix labels"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "dbdbc64f",
"metadata": {},
"outputs": [],
"source": [
"# import plotting libraries\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"from math import isqrt, sqrt"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4c94aaf6",
"metadata": {},
"outputs": [],
"source": [
"# convert string labels to int\n",
"y = y.astype(np.uint8)"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "f1ba6703",
"metadata": {},
"outputs": [],
"source": [
"# function to quickly plot an image\n",
"def plot_digit(data):\n",
" image = data.reshape(28, 28)\n",
" plt.imshow(image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "markdown",
"id": "eec5415d",
"metadata": {},
"source": [
"### Prepare data for machine learning"
]
},
{
"cell_type": "markdown",
"id": "27ed1cdb",
"metadata": {},
"source": [
"### Identify Train Set and Test Set"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "09446324",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train: 56000, (56000, 784)\n",
"X_test: 14000, (14000, 784)\n",
"y_train: 56000, (56000,)\n",
"y_test: 14000, (14000,)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1337)\n",
"\n",
"print(f\"X_train: {len(X_train)}, {X_train.shape}\")\n",
"print(f\"X_test: {len(X_test)}, {X_test.shape}\")\n",
"print(f\"y_train: {len(y_train)}, {y_train.shape}\")\n",
"print(f\"y_test: {len(y_test)}, {y_test.shape}\")"
]
},
{
"cell_type": "markdown",
"id": "2c3041ac",
"metadata": {},
"source": [
"## Pipeline Declaration"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "99f24362",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.model_selection import cross_val_predict\n",
"from sklearn.metrics import classification_report, accuracy_score\n",
"\n",
"n_neighbors = 3\n",
"n95_components = 0.95\n",
"n99_components = 0.99"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "a6ee7588",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 3)"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"names = ['scaler', \n",
" 'minmax', \n",
" 'maxabs', \n",
" ]\n",
"\n",
"classifiers = [\n",
" Pipeline([('scaler', StandardScaler())]),\n",
" Pipeline([('minmax', MinMaxScaler())]),\n",
" Pipeline([('maxabs', MaxAbsScaler())]),\n",
"]\n",
"\n",
"len(names), len(classifiers)"
]
},
{
"cell_type": "markdown",
"id": "650c96b4",
"metadata": {},
"source": [
"# Crossvalidation"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "584cb66b",
"metadata": {},
"outputs": [],
"source": [
"def cv(num):\n",
" name = names[num]\n",
" clf = classifiers[num]\n",
" i = 10000\n",
" _X_train = clf.fit_transform(X_train, y_train)\n",
" print(y_train[i])\n",
" plot_digit(_X_train[i])\n",
" return _X_train[i]"
]
},
{
"cell_type": "code",
"execution_count": 128,
"id": "0b815be6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGuElEQVR4nO3dS4jNcR/H8TOPy9ggyXVHKQtKUohsFKGsrWRBiFwWMxasXGpqIpKVHaVcFoqdLGaShYUdi9lYTIlYSFEUnt2z8v/+e87MmM/h9Vr69Dtzanp3yq/zn75fv351gDz/me43APyeOCGUOCGUOCGUOCHUzJbdf+XC1Ov73T/65IRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQbX8CkCkwPj7euI2MjJRnHz16VO737t0r9x07dpT748ePG7dZs2aVZ5lcPjkhlDghlDghlDghlDghlDghlDghVN+vX7+qvRz/Vrdv3y73u3fvTuj1nz592rh9+/ZtQq/d8vvs9PX1lfvu3bsbt4cPH5ZnZ850bd6l3/5SfHJCKHFCKHFCKHFCKHFCKHFCKHFCqH/ynvPWrVvlfuLEiXL//PnzhH7+zp07G7dVq1aVZ/fu3Vvu169fL/e274NWhoaGyn1wcLDr1/7HueeEXiJOCCVOCCVOCCVOCCVOCNWzVyk/fvwo9wsXLjRuV69eLc+2XZX09/eX+8DAQLmfO3eucZs9e3Z5ts3Pnz/Lfdu2beX+/Pnzxm3Dhg3l2RcvXpQ7jVylQC8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4Tq2WcZtt2pnT9/vuvXXrFiRbm3PRqz7T5wOk3k8ZWLFy+exHdCG5+cEEqcEEqcEEqcEEqcEEqcEEqcEKpn7znb7usWLFjQuB04cKA8e+TIkXJve3zldHr//n25j46Odv3amzZt6vos/z+fnBBKnBBKnBBKnBBKnBBKnBBKnBCqZ59b2+bDhw+N26JFi/7gO/mzDh8+XO43b94s9yVLljRub968Kc/OmTOn3GnkubXQS8QJocQJocQJocQJocQJocQJoXr2+5xtevUu8+PHj+V+9OjRcn/w4EG59/X99krtf6p7UveYf5ZPTgglTgglTgglTgglTgglTgj1135lLNnly5cbt0uXLpVnP336VO4tv8/Wq5T+/v7GbdeuXeXZtkeO7t27t9z/Yb4yBr1EnBBKnBBKnBBKnBBKnBBKnBDKPecUGBsbK/f169c3bl+/fp3Qz57oPedEtL328ePHy/3atWuT+XZ6iXtO6CXihFDihFDihFDihFDihFDihFDuOafBsmXLGre2R2OeOnWq3IeHh8v9+/fv5V49WvPMmTPl2bdv35b7z58/y31oaKhxGxwcLM9O5f3tH+CeE3qJOCGUOCGUOCGUOCGUOCGUOCGUe85p8PTp08Zt/vz55dkNGzZM9tuZNNXzeDudTmdgYKDr175z506579u3r+vXDuCeE3qJOCGUOCGUOCGUOCGUOCGUOCGUe04mzbt378q97Y62+j7o9u3by7NPnjwp93DuOaGXiBNCiRNCiRNCiRNCiRNCzZzuN8DfY+nSpeV+7Nixcj979uxkvp2e55MTQokTQokTQokTQokTQokTQokTQvnKWBfGxsbK/eDBg+W+ZcuWxu3ixYvl2RkzZpR7spcvX5Z79ZWyefPmlWfHx8fLfe7cueU+zXxlDHqJOCGUOCGUOCGUOCGUOCGUOCGU73N24fbt2+X+7NmzrveVK1eWZw8dOlTuyV6/ft312UWLFpX7rFmzun7tVD45IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IZR7zi5s3Lix3Pv7+8v927dvjdvp06fLs5s3by73NWvWlPtU+vr1a7m33Q9X1q9fX+5z5szp+rVT+eSEUOKEUOKEUOKEUOKEUOKEUB6NOQXu379f7vv372/cqmuWTqf9ymDt2rXlfvLkyXKv/ozfly9fyrNXrlwp95GRkXKfP39+4/bq1avy7PLly8s9nEdjQi8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4RyzzkNRkdHG7c9e/aUZ9vuGlt+n52+vt9eqf0R1T1mp9PpPHr0qHHbunXrZL+dJO45oZeIE0KJE0KJE0KJE0KJE0KJE0K55wzT9n3OGzdulHvbnx98+PBhuVffF123bl15dvXq1eU+PDxc7gsXLiz3v5h7Tugl4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ7jlh+rnnhF4iTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgg1s2X/7Z8mA6aeT04IJU4IJU4IJU4IJU4IJU4I9V9xQCui+SkYGAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"i = 10000\n",
"print(y_train[i])\n",
"plot_digit(X_train[i])"
]
},
{
"cell_type": "code",
"execution_count": 132,
"id": "8640f2ad",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAJdElEQVR4nO3d32uOfxzH8WumyczISMx23zFR8ytRiHY4O3PqTHHIP+IPcUJxIpRi2YEQRZO1rbBi+dXNNJsN3zNHu17vtY+P6zXf5+Pw++ozt22v71XevT9X069fvwoAflZU/QEALIxyAqYoJ2CKcgKmKCdgaqUKh4aG+KdcILPjx483LfTfeXICpignYIpyAqYoJ2CKcgKmKCdginICpuSc81+1YsW/+/+kpqYFR2a/LdctpJ8/f1b9Ef66f/e3FFjmKCdginICpignYIpyAqYoJ2CKcgKm/tk5Z8osM5oVpp5XecrZqqXOUNX56OcZ/dmpn62KOStPTsAU5QRMUU7AFOUETFFOwBTlBEzZjlJS17pyjitSc/V3Szm7mPMpUscVUa7GFdEoI3XUEX029X3PNWbhyQmYopyAKcoJmKKcgCnKCZiinIApygmYqmzOmXOOGX391Flic3NzUv7hw4fSbHx8XJ4dHR2V+cjIiMx7e3tlfubMmdIs+r5F8765ubkln//x44c8G+Wpc9KUdbalzkF5cgKmKCdginICpignYIpyAqYoJ2CKcgKmss45c15PGX1tlUdzyMuXL8v8+fPnMt+6davMG41Gadbe3i7PtrW1yXzv3r0yj/7uQ0NDpdnAwIA8G80aozmnyr9//y7PRr8PqXPSKM+BJydginICpignYIpyAqYoJ2CKcgKmKCdgqrJ9zpxzzKLQ87yrV6/Ks9evX5d5Z2enzGdmZmR+9OjRJX/tXbt2yTza55yYmJD558+fS7OxsTF5NtoVnZ2dlXnKnbvRHDSScuduNANd6ryfJydginICpignYIpyAqYoJ2CKcgKmkkYpOVfCUq9hvHjxYmn29OlTebZWq8l827ZtMu/v75d5X19fabZ69Wp5tqWlRea7d++WeTRGmpycLM1evXolzx48eFDmOV8BmHo1ZsrvY3Q2+nuX4ckJmKKcgCnKCZiinIApygmYopyAKcoJmMq6MpYyG4pmqIODgzK/cuVKaXb48GF5tqenR+Znz56Veb1el7maVUZXV6a+yi7l+spoBptzDTB1hTBlHa0qPDkBU5QTMEU5AVOUEzBFOQFTlBMwRTkBU5VdjZmqtbVV5l1dXaXZqVOn5NmTJ0/KPHrFX8qe6/z8vMyjOebXr19lPjo6KnM1Z+3o6JBno8+WIuXqytz5Uvc1Izw5AVOUEzBFOQFTlBMwRTkBU5QTMEU5AVPLds555MgRmV+6dKk027x5szy7atUqmafuDqqdymjfMsrv3Lkj82/fvsl83bp1pdm+ffvk2dRd05RZYq5Z49/6+gvhyQmYopyAKcoJmKKcgCnKCZiinIApygmYsp1zps6V1q9fX5qlvhs09X2Mat43NTUlz964cUPmY2NjMo/unt2/f39ptnJl2q9Lyp26KTPS3HnKz1vhyQmYopyAKcoJmKKcgCnKCZiinICpykYpqaOSlFe6pb6qLnWU8vjx49Ls3r178uzs7KzM29raZB6tww0PD5dm09PT8uyOHTtkvmXLFplXcf3kYrEyBuA3ygmYopyAKcoJmKKcgCnKCZiinICprHNONRtKmVMuhppVpq6ERRqNhszv379fmkXrRWvWrJF5tBIWrX2p79vExIQ8G71esF6vy1xdd5p7zpjyO8ErAIH/GcoJmKKcgCnKCZiinIApygmYopyAqaQ5ZzSTUzOz3PucKk+dY0afXb1GryjinUolevVhX1+fzKM5p5pVRrumMzMzMld7rEWhr8Y8dOiQPBvJPVfPgScnYIpyAqYoJ2CKcgKmKCdginICpignYMr2FYCRnDuZuV8n19/fX5q1trbKs11dXTJvbm6WeXQnr7p7tru7W55Ve6pFURQ3b96U+eDgYGnW0tIiz/b29so8VRVzUp6cgCnKCZiinIApygmYopyAKcoJmKKcgKnK5pypc8qc94xGudo7XIzOzs7SLNq3jP7s6PsSnVd5dHbnzp0yj/ZBv3z5Upo9ePBAns0951RyzUB5cgKmKCdginICpignYIpyAqYoJ2DKdmUs92v6lOjKz5xXa6ZeGRqdjz67Oh99X6LXDx44cEDmb9++XdLn+hN5TtGaXum5P/w5APwhlBMwRTkBU5QTMEU5AVOUEzBFOQFTy3ZlLKdoJvbx40eZ3759W+bqessTJ07Is9HVlymvZcytVqvJXK2kjY+Py7PT09Myj1bxUuag0dnoZ1KGJydginICpignYIpyAqYoJ2CKcgKmKCdgKuucU80qc885c/7ZL168kPm7d+9k/unTp9Js06ZN8my0ExnNMaM5aco+Z+T9+/cyn52dLc1Sf2bO+55leHICpignYIpyAqYoJ2CKcgKmKCdginICpmzvra1SNCvs7u6W+cjIiMzVTO7Ro0fybE9Pj8yjOWnKKwKjWaCaUxZFUTx58kTmjUajNNuwYYM8m/rqxEgVc1CenIApygmYopyAKcoJmKKcgCnKCZiyHaVU+cq3aNywfft2mQ8MDMh8cHCwNIvGONeuXZP5xo0bZb5nzx6Zq5HE1NSUPHv37l2ZP3z4UObqytELFy7Is9E6W2rOKAXAb5QTMEU5AVOUEzBFOQFTlBMwRTkBU5XNOVNfmxatAKk85WxRxNdLRnPQ9vb20uzWrVvy7Pz8vMzfvHkj85cvX8pcvUpPrXQVRVFMTk7KPPps586dK806Ojrk2ehn5jjHjPDkBExRTsAU5QRMUU7AFOUETFFOwBTlBExlnXOmvE4u5QrHoiiKubk5mecUzczUNY+nT5+WZ589eybz169fy3x4eFjmamczmu9Gu6Tnz5+X+dq1a0uz1Dlm6pwz56sRy/DkBExRTsAU5QRMUU7AFOUETFFOwBTlBExlnXOq+U90P2uVc6nUGWr0Ojo1L4zmu/V6Xea1Wk3mx44dk7n6u0e7pKl7siqvco65mK+fA09OwBTlBExRTsAU5QRMUU7AFOUETFFOwJTt+zlT9z2V1JlWNO+LZrhRnlPKPDB1lpgyq6zyfa1V4ckJmKKcgCnKCZiinIApygmYopyAqcpGKdE/q0fjhpwrPqlfOxrzpIyBUqWMHHKPM1LGHamjkipWwiI8OQFTlBMwRTkBU5QTMEU5AVOUEzBFOQFTy3ZlLKLmpNEVjalyzjFzz0irnDUqjnPI3HhyAqYoJ2CKcgKmKCdginICpignYIpyAqaaluOVgcD/AU9OwBTlBExRTsAU5QRMUU7AFOUETP0Hj4dT7WBQ3HQAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"a = cv(0)"
]
},
{
"cell_type": "code",
"execution_count": 133,
"id": "3ef8cf89",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGuElEQVR4nO3dS4jNcR/H8TOPy9ggyXVHKQtKUohsFKGsrWRBiFwWMxasXGpqIpKVHaVcFoqdLGaShYUdi9lYTIlYSFEUnt2z8v/+e87MmM/h9Vr69Dtzanp3yq/zn75fv351gDz/me43APyeOCGUOCGUOCGUOCHUzJbdf+XC1Ov73T/65IRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQbX8CkCkwPj7euI2MjJRnHz16VO737t0r9x07dpT748ePG7dZs2aVZ5lcPjkhlDghlDghlDghlDghlDghlDghVN+vX7+qvRz/Vrdv3y73u3fvTuj1nz592rh9+/ZtQq/d8vvs9PX1lfvu3bsbt4cPH5ZnZ850bd6l3/5SfHJCKHFCKHFCKHFCKHFCKHFCKHFCqH/ynvPWrVvlfuLEiXL//PnzhH7+zp07G7dVq1aVZ/fu3Vvu169fL/e274NWhoaGyn1wcLDr1/7HueeEXiJOCCVOCCVOCCVOCCVOCNWzVyk/fvwo9wsXLjRuV69eLc+2XZX09/eX+8DAQLmfO3eucZs9e3Z5ts3Pnz/Lfdu2beX+/Pnzxm3Dhg3l2RcvXpQ7jVylQC8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4Tq2WcZtt2pnT9/vuvXXrFiRbm3PRqz7T5wOk3k8ZWLFy+exHdCG5+cEEqcEEqcEEqcEEqcEEqcEEqcEKpn7znb7usWLFjQuB04cKA8e+TIkXJve3zldHr//n25j46Odv3amzZt6vos/z+fnBBKnBBKnBBKnBBKnBBKnBBKnBCqZ59b2+bDhw+N26JFi/7gO/mzDh8+XO43b94s9yVLljRub968Kc/OmTOn3GnkubXQS8QJocQJocQJocQJocQJocQJoXr2+5xtevUu8+PHj+V+9OjRcn/w4EG59/X99krtf6p7UveYf5ZPTgglTgglTgglTgglTgglTgj1135lLNnly5cbt0uXLpVnP336VO4tv8/Wq5T+/v7GbdeuXeXZtkeO7t27t9z/Yb4yBr1EnBBKnBBKnBBKnBBKnBBKnBDKPecUGBsbK/f169c3bl+/fp3Qz57oPedEtL328ePHy/3atWuT+XZ6iXtO6CXihFDihFDihFDihFDihFDihFDuOafBsmXLGre2R2OeOnWq3IeHh8v9+/fv5V49WvPMmTPl2bdv35b7z58/y31oaKhxGxwcLM9O5f3tH+CeE3qJOCGUOCGUOCGUOCGUOCGUOCGUe85p8PTp08Zt/vz55dkNGzZM9tuZNNXzeDudTmdgYKDr175z506579u3r+vXDuCeE3qJOCGUOCGUOCGUOCGUOCGUOCGUe04mzbt378q97Y62+j7o9u3by7NPnjwp93DuOaGXiBNCiRNCiRNCiRNCiRNCzZzuN8DfY+nSpeV+7Nixcj979uxkvp2e55MTQokTQokTQokTQokTQokTQokTQvnKWBfGxsbK/eDBg+W+ZcuWxu3ixYvl2RkzZpR7spcvX5Z79ZWyefPmlWfHx8fLfe7cueU+zXxlDHqJOCGUOCGUOCGUOCGUOCGUOCGU73N24fbt2+X+7NmzrveVK1eWZw8dOlTuyV6/ft312UWLFpX7rFmzun7tVD45IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IZR7zi5s3Lix3Pv7+8v927dvjdvp06fLs5s3by73NWvWlPtU+vr1a7m33Q9X1q9fX+5z5szp+rVT+eSEUOKEUOKEUOKEUOKEUOKEUB6NOQXu379f7vv372/cqmuWTqf9ymDt2rXlfvLkyXKv/ozfly9fyrNXrlwp95GRkXKfP39+4/bq1avy7PLly8s9nEdjQi8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4RyzzkNRkdHG7c9e/aUZ9vuGlt+n52+vt9eqf0R1T1mp9PpPHr0qHHbunXrZL+dJO45oZeIE0KJE0KJE0KJE0KJE0KJE0K55wzT9n3OGzdulHvbnx98+PBhuVffF123bl15dvXq1eU+PDxc7gsXLiz3v5h7Tugl4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ7jlh+rnnhF4iTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgg1s2X/7Z8mA6aeT04IJU4IJU4IJU4IJU4IJU4I9V9xQCui+SkYGAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"a = cv(1)"
]
},
{
"cell_type": "code",
"execution_count": 134,
"id": "fe0246a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGuElEQVR4nO3dS4jNcR/H8TOPy9ggyXVHKQtKUohsFKGsrWRBiFwWMxasXGpqIpKVHaVcFoqdLGaShYUdi9lYTIlYSFEUnt2z8v/+e87MmM/h9Vr69Dtzanp3yq/zn75fv351gDz/me43APyeOCGUOCGUOCGUOCHUzJbdf+XC1Ov73T/65IRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQbX8CkCkwPj7euI2MjJRnHz16VO737t0r9x07dpT748ePG7dZs2aVZ5lcPjkhlDghlDghlDghlDghlDghlDghVN+vX7+qvRz/Vrdv3y73u3fvTuj1nz592rh9+/ZtQq/d8vvs9PX1lfvu3bsbt4cPH5ZnZ850bd6l3/5SfHJCKHFCKHFCKHFCKHFCKHFCKHFCqH/ynvPWrVvlfuLEiXL//PnzhH7+zp07G7dVq1aVZ/fu3Vvu169fL/e274NWhoaGyn1wcLDr1/7HueeEXiJOCCVOCCVOCCVOCCVOCNWzVyk/fvwo9wsXLjRuV69eLc+2XZX09/eX+8DAQLmfO3eucZs9e3Z5ts3Pnz/Lfdu2beX+/Pnzxm3Dhg3l2RcvXpQ7jVylQC8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4Tq2WcZtt2pnT9/vuvXXrFiRbm3PRqz7T5wOk3k8ZWLFy+exHdCG5+cEEqcEEqcEEqcEEqcEEqcEEqcEKpn7znb7usWLFjQuB04cKA8e+TIkXJve3zldHr//n25j46Odv3amzZt6vos/z+fnBBKnBBKnBBKnBBKnBBKnBBKnBCqZ59b2+bDhw+N26JFi/7gO/mzDh8+XO43b94s9yVLljRub968Kc/OmTOn3GnkubXQS8QJocQJocQJocQJocQJocQJoXr2+5xtevUu8+PHj+V+9OjRcn/w4EG59/X99krtf6p7UveYf5ZPTgglTgglTgglTgglTgglTgj1135lLNnly5cbt0uXLpVnP336VO4tv8/Wq5T+/v7GbdeuXeXZtkeO7t27t9z/Yb4yBr1EnBBKnBBKnBBKnBBKnBBKnBDKPecUGBsbK/f169c3bl+/fp3Qz57oPedEtL328ePHy/3atWuT+XZ6iXtO6CXihFDihFDihFDihFDihFDihFDuOafBsmXLGre2R2OeOnWq3IeHh8v9+/fv5V49WvPMmTPl2bdv35b7z58/y31oaKhxGxwcLM9O5f3tH+CeE3qJOCGUOCGUOCGUOCGUOCGUOCGUe85p8PTp08Zt/vz55dkNGzZM9tuZNNXzeDudTmdgYKDr175z506579u3r+vXDuCeE3qJOCGUOCGUOCGUOCGUOCGUOCGUe04mzbt378q97Y62+j7o9u3by7NPnjwp93DuOaGXiBNCiRNCiRNCiRNCiRNCzZzuN8DfY+nSpeV+7Nixcj979uxkvp2e55MTQokTQokTQokTQokTQokTQokTQvnKWBfGxsbK/eDBg+W+ZcuWxu3ixYvl2RkzZpR7spcvX5Z79ZWyefPmlWfHx8fLfe7cueU+zXxlDHqJOCGUOCGUOCGUOCGUOCGUOCGU73N24fbt2+X+7NmzrveVK1eWZw8dOlTuyV6/ft312UWLFpX7rFmzun7tVD45IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IZR7zi5s3Lix3Pv7+8v927dvjdvp06fLs5s3by73NWvWlPtU+vr1a7m33Q9X1q9fX+5z5szp+rVT+eSEUOKEUOKEUOKEUOKEUOKEUB6NOQXu379f7vv372/cqmuWTqf9ymDt2rXlfvLkyXKv/ozfly9fyrNXrlwp95GRkXKfP39+4/bq1avy7PLly8s9nEdjQi8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4RyzzkNRkdHG7c9e/aUZ9vuGlt+n52+vt9eqf0R1T1mp9PpPHr0qHHbunXrZL+dJO45oZeIE0KJE0KJE0KJE0KJE0KJE0K55wzT9n3OGzdulHvbnx98+PBhuVffF123bl15dvXq1eU+PDxc7gsXLiz3v5h7Tugl4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ7jlh+rnnhF4iTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgglTgg1s2X/7Z8mA6aeT04IJU4IJU4IJU4IJU4IJU4I9V9xQCui+SkYGAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"a = cv(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87a073e1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long