523 lines
14 KiB
Plaintext
523 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e4d89124",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load MNIST dataset"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": 1,
 "id": "5789ec72",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Python ≥3.6 is required: later cells use f-strings, which were added in 3.6\n",
  "import sys\n",
  "\n",
  "# %-formatting (not an f-string) so this check itself still parses on older Pythons\n",
  "assert sys.version_info >= (3, 6), \"Python >= 3.6 is required, found %s\" % sys.version.split()[0]"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 2,
 "id": "f491d383",
 "metadata": {},
 "outputs": [],
 "source": [
  "# scikit-learn ≥0.20 is required\n",
  "import re\n",
  "import sklearn\n",
  "\n",
  "# compare (major, minor) numerically: a plain string comparison is lexicographic,\n",
  "# so e.g. \"0.9\" >= \"0.20\" would wrongly pass\n",
  "sklearn_major, sklearn_minor = map(int, re.match(r\"(\\d+)\\.(\\d+)\", sklearn.__version__).groups())\n",
  "assert (sklearn_major, sklearn_minor) >= (0, 20), \"scikit-learn >= 0.20 is required, found \" + sklearn.__version__"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "575a6a42",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# common imports\n",
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": 4,
 "id": "921dc114",
 "metadata": {},
 "outputs": [
  {
   "data": {
    "text/plain": [
     "sklearn.utils.Bunch"
    ]
   },
   "execution_count": 4,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "# import the function that downloads datasets from OpenML\n",
  "from sklearn.datasets import fetch_openml\n",
  "\n",
  "# load MNIST (784 pixel features per image); as_frame=False returns NumPy arrays instead of a DataFrame\n",
  "mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
  "\n",
  "# display the type of the returned container\n",
  "type(mnist)"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "6271045c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"X, y = mnist[\"data\"], mnist[\"target\"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "37777133",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Convert labels to integers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "30a441d3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# import plotting libraries\n",
|
|
"import matplotlib as mpl\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "2d9693b1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# convert string labels to int\n",
|
|
"y = y.astype(np.uint8)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "182f4b1b",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Prepare data for machine learning"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "77ff6bd1",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Split into training and test sets"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "f8247d13",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"X_train: 56000, (56000, 784)\n",
|
|
"X_test: 14000, (14000, 784)\n",
|
|
"y_train: 56000, (56000,)\n",
|
|
"y_test: 14000, (14000,)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1337)\n",
|
|
"\n",
|
|
"print(f\"X_train: {len(X_train)}, {X_train.shape}\")\n",
|
|
"print(f\"X_test: {len(X_test)}, {X_test.shape}\")\n",
|
|
"print(f\"y_train: {len(y_train)}, {y_train.shape}\")\n",
|
|
"print(f\"y_test: {len(y_test)}, {y_test.shape}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c4062436",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Pipeline Declaration"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "67645a47",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.pipeline import Pipeline\n",
|
|
"from sklearn.decomposition import PCA\n",
|
|
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
|
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
|
"from sklearn.model_selection import cross_val_predict\n",
|
|
"from sklearn.metrics import classification_report, precision_score\n",
|
|
"\n",
|
|
"n_neighbors = 3\n",
|
|
"n95_components = 0.95\n",
|
|
"n99_components = 0.99"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": 14,
 "id": "82b1e834",
 "metadata": {},
 "outputs": [
  {
   "data": {
    "text/plain": [
     "6"
    ]
   },
   "execution_count": 14,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "def knn_pipeline(*steps):\n",
  "    \"\"\"Return a Pipeline of the given (name, transformer) steps followed by a k-NN classifier.\"\"\"\n",
  "    return Pipeline([*steps, ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))])\n",
  "\n",
  "\n",
  "# names[i] labels classifiers[i]; each name lists the pipeline's steps in order\n",
  "names = ['knn',\n",
  "         'standard+knn',\n",
  "         'standard+pca95+knn',\n",
  "         'minmax+pca95+knn',\n",
  "         'standard+pca99+knn',\n",
  "         'minmax+pca99+knn'\n",
  "        ]\n",
  "\n",
  "classifiers = [\n",
  "    knn_pipeline(),\n",
  "    knn_pipeline(('standard', StandardScaler())),\n",
  "    knn_pipeline(('standard', StandardScaler()), ('pca', PCA(n_components=n95_components))),\n",
  "    knn_pipeline(('minmax', MinMaxScaler()), ('pca', PCA(n_components=n95_components))),\n",
  "    knn_pipeline(('standard', StandardScaler()), ('pca', PCA(n_components=n99_components))),\n",
  "    knn_pipeline(('minmax', MinMaxScaler()), ('pca', PCA(n_components=n99_components))),\n",
  "]\n",
  "\n",
  "assert len(names) == len(classifiers), \"every pipeline needs exactly one name\"\n",
  "\n",
  "len(names)"
 ]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "156dbf2c",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Cross-validation"
|
|
]
|
|
},
|
|
{
 "cell_type": "markdown",
 "id": "8e5168e4",
 "metadata": {},
 "source": [
  "Each pipeline is evaluated with 3-fold `cross_val_predict` on the training set only; the test set stays held out. `cv(i)` prints the weighted precision and the per-class classification report for pipeline `i`, and returns that precision."
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 27,
 "id": "db14a027",
 "metadata": {},
 "outputs": [],
 "source": [
  "def cv(num, folds=3):\n",
  "    \"\"\"Cross-validate pipeline `num` on the training set and report its precision.\n",
  "\n",
  "    Prints the weighted precision and the per-class classification report of the\n",
  "    out-of-fold predictions, then returns the weighted precision.\n",
  "\n",
  "    Parameters\n",
  "    ----------\n",
  "    num : int\n",
  "        Index into the parallel `names` / `classifiers` lists.\n",
  "    folds : int, default 3\n",
  "        Number of cross-validation folds passed to `cross_val_predict`.\n",
  "\n",
  "    Returns\n",
  "    -------\n",
  "    float\n",
  "        Weighted-average precision over all classes.\n",
  "    \"\"\"\n",
  "    name = names[num]\n",
  "    clf = classifiers[num]\n",
  "    # every training sample is predicted by a model fitted on the other folds\n",
  "    y_train_pred = cross_val_predict(clf, X_train, y_train, cv=folds)\n",
  "    precision = precision_score(y_train, y_train_pred, average='weighted')\n",
  "    print(f\"Pipeline: {name} ({precision:.4f})\")\n",
  "    print(classification_report(y_train, y_train_pred))\n",
  "    return precision"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "b8983bf8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Pipeline: knn (0.9691)\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 0 0.98 0.99 0.99 5499\n",
|
|
" 1 0.95 0.99 0.97 6287\n",
|
|
" 2 0.98 0.96 0.97 5595\n",
|
|
" 3 0.96 0.97 0.96 5679\n",
|
|
" 4 0.98 0.96 0.97 5450\n",
|
|
" 5 0.96 0.96 0.96 5068\n",
|
|
" 6 0.98 0.99 0.98 5542\n",
|
|
" 7 0.96 0.97 0.97 5846\n",
|
|
" 8 0.99 0.93 0.96 5504\n",
|
|
" 9 0.95 0.96 0.96 5530\n",
|
|
"\n",
|
|
" accuracy 0.97 56000\n",
|
|
" macro avg 0.97 0.97 0.97 56000\n",
|
|
"weighted avg 0.97 0.97 0.97 56000\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"cv(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "62ff42f4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Pipeline: scalar+knn (0.9420)\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 0 0.95 0.99 0.97 5499\n",
|
|
" 1 0.95 0.99 0.97 6287\n",
|
|
" 2 0.95 0.93 0.94 5595\n",
|
|
" 3 0.92 0.94 0.93 5679\n",
|
|
" 4 0.94 0.93 0.94 5450\n",
|
|
" 5 0.93 0.92 0.93 5068\n",
|
|
" 6 0.96 0.97 0.97 5542\n",
|
|
" 7 0.94 0.93 0.94 5846\n",
|
|
" 8 0.97 0.89 0.93 5504\n",
|
|
" 9 0.91 0.92 0.91 5530\n",
|
|
"\n",
|
|
" accuracy 0.94 56000\n",
|
|
" macro avg 0.94 0.94 0.94 56000\n",
|
|
"weighted avg 0.94 0.94 0.94 56000\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"cv(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "9b553141",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Pipeline: standard+pca95+knn (0.9457)\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 0 0.96 0.99 0.97 5499\n",
|
|
" 1 0.95 0.99 0.97 6287\n",
|
|
" 2 0.95 0.94 0.95 5595\n",
|
|
" 3 0.93 0.94 0.94 5679\n",
|
|
" 4 0.95 0.93 0.94 5450\n",
|
|
" 5 0.93 0.92 0.93 5068\n",
|
|
" 6 0.96 0.97 0.97 5542\n",
|
|
" 7 0.95 0.94 0.94 5846\n",
|
|
" 8 0.97 0.89 0.93 5504\n",
|
|
" 9 0.92 0.92 0.92 5530\n",
|
|
"\n",
|
|
" accuracy 0.95 56000\n",
|
|
" macro avg 0.95 0.94 0.94 56000\n",
|
|
"weighted avg 0.95 0.95 0.95 56000\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"cv(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "2a1d6c65",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Pipeline: minmax+pca95+knn (0.9706)\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 0 0.98 0.99 0.99 5499\n",
|
|
" 1 0.96 0.99 0.98 6287\n",
|
|
" 2 0.98 0.97 0.97 5595\n",
|
|
" 3 0.96 0.96 0.96 5679\n",
|
|
" 4 0.98 0.97 0.97 5450\n",
|
|
" 5 0.96 0.96 0.96 5068\n",
|
|
" 6 0.98 0.99 0.98 5542\n",
|
|
" 7 0.97 0.98 0.97 5846\n",
|
|
" 8 0.99 0.93 0.96 5504\n",
|
|
" 9 0.95 0.96 0.96 5530\n",
|
|
"\n",
|
|
" accuracy 0.97 56000\n",
|
|
" macro avg 0.97 0.97 0.97 56000\n",
|
|
"weighted avg 0.97 0.97 0.97 56000\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"cv(3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "de40f817",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Pipeline: standard+pca99+knn (0.9424)\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 0 0.95 0.99 0.97 5499\n",
|
|
" 1 0.95 0.99 0.97 6287\n",
|
|
" 2 0.95 0.93 0.94 5595\n",
|
|
" 3 0.92 0.94 0.93 5679\n",
|
|
" 4 0.94 0.93 0.94 5450\n",
|
|
" 5 0.93 0.92 0.93 5068\n",
|
|
" 6 0.96 0.97 0.97 5542\n",
|
|
" 7 0.94 0.93 0.94 5846\n",
|
|
" 8 0.97 0.89 0.93 5504\n",
|
|
" 9 0.91 0.92 0.92 5530\n",
|
|
"\n",
|
|
" accuracy 0.94 56000\n",
|
|
" macro avg 0.94 0.94 0.94 56000\n",
|
|
"weighted avg 0.94 0.94 0.94 56000\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"cv(4)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "2c1c26b0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Pipeline: minmax+pca99+knn (0.9695)\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 0 0.98 0.99 0.99 5499\n",
|
|
" 1 0.95 0.99 0.97 6287\n",
|
|
" 2 0.98 0.96 0.97 5595\n",
|
|
" 3 0.96 0.97 0.96 5679\n",
|
|
" 4 0.98 0.96 0.97 5450\n",
|
|
" 5 0.96 0.96 0.96 5068\n",
|
|
" 6 0.98 0.99 0.98 5542\n",
|
|
" 7 0.96 0.97 0.97 5846\n",
|
|
" 8 0.99 0.93 0.96 5504\n",
|
|
" 9 0.95 0.96 0.96 5530\n",
|
|
"\n",
|
|
" accuracy 0.97 56000\n",
|
|
" macro avg 0.97 0.97 0.97 56000\n",
|
|
"weighted avg 0.97 0.97 0.97 56000\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"cv(5)"
|
|
]
|
|
},
|
|
{
 "cell_type": "markdown",
 "id": "cb6b6d69",
 "metadata": {},
 "source": [
  "## Results\n",
  "\n",
  "Weighted precision from 3-fold cross-validation on the training set (saved outputs above):\n",
  "\n",
  "| Pipeline | Weighted precision |\n",
  "|---|---|\n",
  "| k-NN only | 0.9691 |\n",
  "| StandardScaler + k-NN | 0.9420 |\n",
  "| StandardScaler + PCA (95% variance) + k-NN | 0.9457 |\n",
  "| MinMaxScaler + PCA (95% variance) + k-NN | 0.9706 |\n",
  "| StandardScaler + PCA (99% variance) + k-NN | 0.9424 |\n",
  "| MinMaxScaler + PCA (99% variance) + k-NN | 0.9695 |\n",
  "\n",
  "- MinMaxScaler + PCA (95%) scores highest, only slightly above plain k-NN.\n",
  "- The StandardScaler pipelines score roughly 0.94, versus roughly 0.97 for plain k-NN and the MinMaxScaler pipelines.\n",
  "- **TODO**: the held-out test set (`X_test`, `y_test`) has not been used yet; evaluate the chosen pipeline on it."
 ]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.5"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|