iui-group-l-name-zensiert/2-second-project/tdt/DataViz.ipynb

824 lines
58 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"id": "ae397d48",
"metadata": {},
"source": [
"# Setup: GPU selection, constants and a plotting helper"
]
},
{
"cell_type": "code",
2021-07-14 10:15:52 +02:00
"execution_count": 1,
"id": "3827a09b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Set before TensorFlow is imported in the cells below.\n",
"# TF_FORCE_GPU_ALLOW_GROWTH: let TF allocate GPU memory on demand (marked as\n",
"#   required by the original author).\n",
"# CUDA_VISIBLE_DEVICES: which GPU to use -- '0', '1' or '2'; run `gpustat` in a\n",
"#   terminal to find a free one.\n",
"os.environ.update({\n",
"    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',\n",
"    'CUDA_VISIBLE_DEVICES': '1',\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "654f2682",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Source CSVs (one per recording) and the pickle cache of the parsed data.\n",
"glob_path = '/opt/iui-datarelease3-sose2021/*.csv'\n",
"pickle_file = '../data.pickle'\n",
"\n",
"# Recording group to model, a key of `cdata`: scenario initial ('S' = Sorting,\n",
"# 'J' = Jenga), then HeightNormalization and ArmNormalization as Y/N.\n",
"cenario = 'SYY'\n",
"\n",
"# Sliding-window length and step, in rows (frames), used by `slicing`.\n",
"win_sz = 50\n",
"stride_sz = 25"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6cc88c90",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"def pplot(dd, cols=3):\n",
"    \"\"\"Plot each row of `dd` in its own subplot on a grid `cols` wide.\n",
"\n",
"    `dd` must support `.shape[0]` and row indexing (e.g. a 2-D numpy array).\n",
"    Row i is drawn in grid cell (i // cols, i % cols); each grid row is 9 inches\n",
"    tall and each column 3 inches wide. Nothing is drawn for empty input.\n",
"    \"\"\"\n",
"    n = dd.shape[0]\n",
"    if n == 0:\n",
"        return\n",
"    rows = math.ceil(n / cols)\n",
"    # squeeze=False keeps `axs` 2-D even for a single grid row, so axs[r][c]\n",
"    # always works (a 1-D array would make axs[0][c] fail).\n",
"    fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 9 * rows), squeeze=False)\n",
"    for i in range(n):\n",
"        axs[i // cols][i % cols].plot(dd[i])\n",
"    # Hide the unused cells of the last grid row.\n",
"    for j in range(n, rows * cols):\n",
"        axs[j // cols][j % cols].set_visible(False)"
]
},
{
"cell_type": "markdown",
"id": "3c47f127",
"metadata": {},
"source": [
"# Loading Data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9dc8d47e",
2021-07-14 10:15:52 +02:00
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"from glob import glob\n",
"\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"\n",
"def _parse_filename(path):\n",
"    \"\"\"Return the recording metadata encoded in a CSV file name.\n",
"\n",
"    The base name (up to its first '.') is split on '_' into six tokens:\n",
"    <c><user>, Scenario<name>, HeightNormalization<True|False>,\n",
"    ArmNormalization<True|False>, Repetition<int>, Session<int>. The first\n",
"    character of the user token is skipped and each prefix is sliced off by\n",
"    length without being checked.\n",
"    \"\"\"\n",
"    stem = os.path.basename(path).split('.')[0]\n",
"    parts = stem.split('_')\n",
"    return {\n",
"        'user': int(parts[0][1:]),\n",
"        'scenario': parts[1][len('Scenario'):],\n",
"        'heightnorm': parts[2][len('HeightNormalization'):] == 'True',\n",
"        'armnorm': parts[3][len('ArmNormalization'):] == 'True',\n",
"        'rep': int(parts[4][len('Repetition'):]),\n",
"        'session': int(parts[5][len('Session'):]),\n",
"    }\n",
"\n",
"def dl_from_blob(filename, user_filter=None):\n",
"    \"\"\"Load every CSV matching the glob pattern `filename` as a recording dict.\n",
"\n",
"    Each dict has keys 'filename' (the CSV path), 'user', 'scenario',\n",
"    'heightnorm', 'armnorm', 'rep', 'session' (parsed from the file name) and\n",
"    'data' (the CSV as a DataFrame). When `user_filter` is not None only that\n",
"    user's files are read. Paths are visited in sorted order so the result is\n",
"    reproducible (plain glob order is arbitrary).\n",
"    \"\"\"\n",
"    records = []\n",
"    for path in tqdm(sorted(glob(filename))):\n",
"        meta = _parse_filename(path)\n",
"        if user_filter is not None and meta['user'] != user_filter:\n",
"            continue\n",
"        records.append({'filename': path, **meta, 'data': pd.read_csv(path)})\n",
"    return records"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1294685f",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"def save_pickle(f, structure):\n",
"    \"\"\"Pickle `structure` into the file at path `f`, overwriting it.\"\"\"\n",
"    # `with` closes the file even if pickling fails part-way.\n",
"    with open(f, 'wb') as fh:\n",
"        pickle.dump(structure, fh)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5e418dc4",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"def load_pickles(f) -> list:\n",
"    \"\"\"Load and return the object pickled in the file at path `f`.\n",
"\n",
"    Only use on trusted files: unpickling can execute arbitrary code.\n",
"    \"\"\"\n",
"    # Read the path passed in as `f`, not the global cache path.\n",
"    with open(f, 'rb') as fh:\n",
"        return pickle.load(fh)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7938c466",
"metadata": {},
2021-07-14 10:15:52 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading data...\n",
"../data.pickle found...\n",
"768\n",
"CPU times: user 615 ms, sys: 2.24 s, total: 2.85 s\n",
"Wall time: 2.85 s\n"
2021-07-14 10:15:52 +02:00
]
}
],
"source": [
"%%time\n",
"\n",
"def load_data() -> list:\n",
"    \"\"\"Return the list of recording dicts, caching them in `pickle_file`.\n",
"\n",
"    Without a cache file the CSVs matching `glob_path` are parsed with\n",
"    `dl_from_blob` and pickled to `pickle_file`; otherwise the cache is loaded.\n",
"    \"\"\"\n",
"    if not os.path.isfile(pickle_file):\n",
"        print(f'Didn\\'t find {pickle_file}...')\n",
"        records = dl_from_blob(glob_path)\n",
"        print(f'Creating {pickle_file}...')\n",
"        save_pickle(pickle_file, records)\n",
"        return records\n",
"    print(f'{pickle_file} found...')\n",
"    return load_pickles(pickle_file)\n",
"\n",
"print(\"Loading data...\")\n",
"dic_data = load_data()\n",
"print(len(dic_data))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e3f38b64",
2021-07-14 10:15:52 +02:00
"metadata": {
"tags": []
},
"outputs": [],
2021-07-14 10:15:52 +02:00
"source": [
"# Recordings grouped by scenario and normalisation flags. Key = scenario\n",
"# initial ('S' = Sorting, 'J' = Jenga) + HeightNorm (Y/N) + ArmNorm (Y/N),\n",
"# e.g. 'SNY' = Sorting, not height-normalised, arm-normalised.\n",
"# Recordings of any other scenario are not categorised.\n",
"cdata = {s + h + a: list() for s in 'SJ' for h in 'YN' for a in 'YN'}\n",
"\n",
"scenario_initial = {'Sorting': 'S', 'Jenga': 'J'}\n",
"\n",
"def _yn(flag):\n",
"    \"\"\"Return 'Y' for a truthy flag, 'N' otherwise.\"\"\"\n",
"    return 'Y' if flag else 'N'\n",
"\n",
"for d in dic_data:\n",
"    initial = scenario_initial.get(d['scenario'])\n",
"    if initial is not None:\n",
"        cdata[initial + _yn(d['heightnorm']) + _yn(d['armnorm'])].append(d)"
]
},
{
"cell_type": "markdown",
"id": "83953c92",
"metadata": {},
"source": [
"# Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "583e8c34",
2021-07-14 10:15:52 +02:00
"metadata": {
"tags": []
},
"outputs": [],
2021-07-14 10:15:52 +02:00
"source": [
"def drop(entry) -> pd.DataFrame:\n",
"    \"\"\"Return entry['data'] without its metadata/index columns.\n",
"\n",
"    Raises KeyError if any of those columns is missing. `entry` is left\n",
"    untouched because DataFrame.drop returns a new frame.\n",
"    \"\"\"\n",
"    metadata_cols = [\n",
"        'participantID', 'FrameID', 'Scenario', 'HeightNormalization',\n",
"        'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0',\n",
"    ]\n",
"    return entry['data'].drop(columns=metadata_cols)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b8a05286",
2021-07-14 10:15:52 +02:00
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def floatize(entry) -> pd.DataFrame:\n",
"    \"\"\"Return a copy of entry['data'] with both tracking-accuracy columns\n",
"    encoded as floats: 1.0 where the value is 'High', 0.0 otherwise.\"\"\"\n",
"    data = entry['data'].copy()\n",
"    for col in ('LeftHandTrackingAccuracy', 'RightHandTrackingAccuracy'):\n",
"        data[col] = (entry['data'][col] == 'High') * 1.0\n",
"    return data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fbe90e8d",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
2021-07-14 10:15:52 +02:00
"source": [
"import numpy as np\n",
"\n",
"right_Hand_ident = 'right_Hand'\n",
"left_Hand_ident = 'left_hand'\n",
"\n",
"def rem_low_acc(entry) -> pd.DataFrame:\n",
"    \"\"\"Set hand columns to NaN on frames where that hand's tracking accuracy is low.\n",
"\n",
"    Returns a copy of entry['data'] in which every column whose name contains\n",
"    `right_Hand_ident` (`left_Hand_ident`), compared case-insensitively, is NaN\n",
"    on rows where RightHandTrackingAccuracy (LeftHandTrackingAccuracy) == 0.0.\n",
"    The accuracy columns must already be numeric (see `floatize`).\n",
"    \"\"\"\n",
"    data = entry['data'].copy()\n",
"\n",
"    # The two identifiers differ in case ('right_Hand' vs 'left_hand'). A\n",
"    # case-sensitive test silently selects no columns for one hand whenever the\n",
"    # CSV header uses a single casing, so compare lower-cased names.\n",
"    # NOTE(review): confirm against the actual column names.\n",
"    def cols_matching(ident):\n",
"        return [c for c in data.columns if ident.lower() in str(c).lower()]\n",
"\n",
"    data.loc[data['RightHandTrackingAccuracy'] == 0.0, cols_matching(right_Hand_ident)] = np.nan\n",
"    data.loc[data['LeftHandTrackingAccuracy'] == 0.0, cols_matching(left_Hand_ident)] = np.nan\n",
"    return data"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "26059dd4",
"metadata": {},
"outputs": [],
"source": [
"# Padding granularity in rows (unrelated to the window stride `stride_sz`).\n",
"stride = 150\n",
"\n",
"def pad(entry) -> pd.DataFrame:\n",
"    \"\"\"Front-pad entry['data'] with NaN rows up to the next multiple of `stride`.\n",
"\n",
"    The padded length is (n // stride + 1) * stride for n input rows, so at least\n",
"    one padding row is always added (a full `stride` when n is already a\n",
"    multiple). Row 0 is then set to 0 in every column, presumably so the later\n",
"    forward linear interpolation has a value to start from. Values become float64.\n",
"    \"\"\"\n",
"    frame = entry['data']\n",
"    n_rows = frame.shape[0]\n",
"    target_len = (n_rows // stride + 1) * stride\n",
"    padding = np.full((target_len - n_rows, frame.shape[1]), np.nan)\n",
"    padded = np.vstack([padding, frame.to_numpy(dtype='float64')])\n",
"    result = pd.DataFrame(padded, columns=frame.columns)\n",
"    result.loc[0] = [0 for _ in frame.columns]\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2f2181f0",
"metadata": {},
"outputs": [],
"source": [
"def interpol(entry) -> pd.DataFrame:\n",
"    \"\"\"Linearly interpolate the NaN gaps of entry['data'] down each column.\n",
"\n",
"    Returns a new frame; `entry` is not modified. With pandas' default forward\n",
"    direction, NaNs before a column's first valid value are left unfilled.\n",
"    \"\"\"\n",
"    frame = entry['data']\n",
"    return frame.interpolate(method='linear', axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "276ecf82",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.preprocessing import timeseries_dataset_from_array\n",
"\n",
"def slicing(entry):\n",
"    \"\"\"Cut entry['data'] into overlapping windows labelled with the entry's user.\n",
"\n",
"    Returns a tf.data.Dataset of (windows, targets) batches of 8: windows are\n",
"    `win_sz` rows long and start every `stride_sz` rows; every target is\n",
"    entry['user'].\n",
"    \"\"\"\n",
"    frame = entry['data']\n",
"    n_rows = frame.shape[0]\n",
"    return timeseries_dataset_from_array(\n",
"        data=frame,\n",
"        targets=[entry['user']] * n_rows,\n",
"        sequence_length=win_sz,\n",
"        sequence_stride=stride_sz,\n",
"        batch_size=8,\n",
"        seed=177013,  # only used for shuffling, which is off by default\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "dab70ad9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 96/96 [00:15<00:00, 6.14it/s]\n"
]
}
],
"source": [
"# Number of model output classes: one per user in the selected scenario.\n",
"# NOTE(review): must match len(lb.classes_) from the training labels further\n",
"# down, i.e. it assumes every user here also has session-1 recordings.\n",
"classes = len({rec['user'] for rec in cdata[cenario]})\n",
"\n",
"def preproc(data):\n",
"    \"\"\"Apply `preproc_entry` to every recording dict in `data`, with a progress bar.\"\"\"\n",
"    return [preproc_entry(e) for e in tqdm(data)]\n",
"\n",
"def preproc_entry(entry):\n",
"    \"\"\"Return a copy of `entry` whose 'data' has been run through the pipeline\n",
"    drop -> floatize -> rem_low_acc -> pad -> interpol -> slicing.\n",
"\n",
"    Each step receives the current entry dict and returns the new 'data'; the\n",
"    last step yields a windowed tf.data.Dataset. `entry` itself is not changed.\n",
"    \"\"\"\n",
"    # A shallow copy is enough: every step returns a new 'data' object instead\n",
"    # of mutating its input, and the other fields are plain scalars/strings.\n",
"    # This avoids a full pickle round-trip of the entry before every step.\n",
"    result = dict(entry)\n",
"    for step in (drop, floatize, rem_low_acc, pad, interpol, slicing):\n",
"        result['data'] = step(result)\n",
"    return result\n",
"\n",
"pdata = preproc(cdata[cenario])"
]
},
{
"cell_type": "markdown",
"id": "ddba89b9",
"metadata": {},
"source": [
"# Model: building, training and evaluation"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "61c34fed",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout, Conv2D, MaxPooling2D\n",
"\n",
"def build_model(train):\n",
"    \"\"\"Build and compile an MLP classifier for samples shaped like `train[0]`.\n",
"\n",
"    `train[0]` is assumed 2-D (window_length, n_features); n is their product.\n",
"    The input is flattened, batch-normalised and passed through five ReLU\n",
"    layers of width n/3, n/9, ..., n/243, each followed by 10% dropout, then a\n",
"    softmax over the global `classes`. Compiled with Adam(0.001) and\n",
"    categorical cross-entropy, so labels must be one-hot.\n",
"    \"\"\"\n",
"    sample_shape = train[0].shape\n",
"    n_inputs = sample_shape[0] * sample_shape[1]\n",
"\n",
"    layers = [\n",
"        Flatten(input_shape=sample_shape),\n",
"        BatchNormalization(),\n",
"        Dropout(0.1),\n",
"    ]\n",
"    for depth in range(1, 6):\n",
"        layers.append(Dense(int(n_inputs / pow(3, depth)), activation='relu'))\n",
"        layers.append(Dropout(0.1))\n",
"    layers.append(Dense(classes, activation='softmax'))\n",
"\n",
"    model = Sequential(layers)\n",
"    model.compile(\n",
"        optimizer=tf.keras.optimizers.Adam(0.001),\n",
"        loss=\"categorical_crossentropy\",\n",
"        metrics=[\"acc\"],\n",
"    )\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "47058299",
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): not referenced anywhere else in this notebook.\n",
"checkpoint_file = './goat.weights'\n",
"\n",
"def train_model(X_train, y_train):\n",
"    \"\"\"Build a model shaped for `X_train`, print its summary and train it.\n",
"\n",
"    Fits silently for 30 epochs with batch size 128 and shuffling; `y_train`\n",
"    must be one-hot (the model uses categorical cross-entropy).\n",
"    Returns `(model, history)`.\n",
"    \"\"\"\n",
"    fit_options = dict(epochs=30, batch_size=128, shuffle=True, verbose=0)\n",
"    model = build_model(X_train)\n",
"    model.summary()\n",
"    history = model.fit(X_train, y_train, **fit_options)\n",
"    return model, history"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "6c99e0bc",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(48, 48)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
"\n",
"# The train/test split is by recording session: session-1 recordings are used\n",
"# for training, session-2 recordings for testing. Show how many of each exist.\n",
"n_train_recordings = sum(x['session'] == 1 for x in pdata)\n",
"n_test_recordings = sum(x['session'] == 2 for x in pdata)\n",
"\n",
"n_train_recordings, n_test_recordings"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "727b89e0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 8.86 s, sys: 3.63 s, total: 12.5 s\n",
"Wall time: 4.7 s\n"
]
}
],
"source": [
"%%time\n",
"# Unroll each recording's windowed dataset into numpy windows.\n",
"# X_train / y_train: every session-1 window and its user id (model input).\n",
"# train / test: one {'label': user, 'data': [windows]} per session-1 / session-2\n",
"# recording, used for per-recording prediction below.\n",
"X_train, y_train = list(), list()\n",
"train, test = list(), list()\n",
"\n",
"for x in pdata:\n",
"    if x['session'] not in (1, 2):\n",
"        continue\n",
"    windows = list()\n",
"    for window, label in x['data'].unbatch().as_numpy_iterator():\n",
"        windows.append(window)\n",
"        if x['session'] == 1:\n",
"            X_train.append(window)\n",
"            y_train.append(label)\n",
"    record = {'label': x['user'], 'data': windows}\n",
"    (train if x['session'] == 1 else test).append(record)\n",
"\n",
"X_train = np.array(X_train)\n",
"y_train = np.array(y_train)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "ba64dca4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(5832, 50, 338)\n",
"(5832, 16)\n"
]
}
],
"source": [
"# One-hot encode the user ids. The encoder is fitted on the training windows;\n",
"# each recording's label in `test` and `train` becomes a (1, n_classes) row,\n",
"# and its window list becomes one numpy array.\n",
"lb = LabelBinarizer()\n",
"yy_train = lb.fit_transform(y_train)\n",
"\n",
"for rec in [*test, *train]:\n",
"    rec['label'] = lb.transform([rec['label']])\n",
"    rec['data'] = np.array(rec['data'])\n",
"\n",
"print(X_train.shape)\n",
"print(yy_train.shape)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "399176de",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 4: 53 (53, 50, 338)\n",
"14: 35 (35, 50, 338)\n",
"12: 65 (65, 50, 338)\n",
" 8: 149 (149, 50, 338)\n",
" 1: 53 (53, 50, 338)\n",
" 3: 107 (107, 50, 338)\n",
"11: 53 (53, 50, 338)\n",
" 3: 125 (125, 50, 338)\n",
" 1: 41 (41, 50, 338)\n",
"13: 71 (71, 50, 338)\n",
"15: 59 (59, 50, 338)\n",
" 3: 77 (77, 50, 338)\n",
"10: 119 (119, 50, 338)\n",
" 6: 47 (47, 50, 338)\n",
"14: 41 (41, 50, 338)\n",
" 5: 167 (167, 50, 338)\n",
" 8: 89 (89, 50, 338)\n",
"14: 41 (41, 50, 338)\n",
" 9: 71 (71, 50, 338)\n",
"10: 77 (77, 50, 338)\n",
" 8: 77 (77, 50, 338)\n",
"16: 77 (77, 50, 338)\n",
"16: 77 (77, 50, 338)\n",
" 2: 59 (59, 50, 338)\n",
" 9: 77 (77, 50, 338)\n",
"15: 77 (77, 50, 338)\n",
" 5: 101 (101, 50, 338)\n",
"16: 71 (71, 50, 338)\n",
"15: 71 (71, 50, 338)\n",
"12: 95 (95, 50, 338)\n",
" 6: 71 (71, 50, 338)\n",
" 2: 53 (53, 50, 338)\n",
"12: 845 (845, 50, 338)\n",
" 7: 65 (65, 50, 338)\n",
" 2: 65 (65, 50, 338)\n",
"13: 95 (95, 50, 338)\n",
" 5: 125 (125, 50, 338)\n",
"11: 65 (65, 50, 338)\n",
" 7: 59 (59, 50, 338)\n",
"10: 77 (77, 50, 338)\n",
" 6: 59 (59, 50, 338)\n",
" 7: 53 (53, 50, 338)\n",
" 1: 101 (101, 50, 338)\n",
"13: 71 (71, 50, 338)\n",
"11: 59 (59, 50, 338)\n",
" 4: 77 (77, 50, 338)\n",
" 9: 29 (29, 50, 338)\n",
" 4: 107 (107, 50, 338)\n"
]
}
],
"source": [
"# One line per test recording: user id, window count, window-array shape.\n",
"for rec in test:\n",
"    user = lb.inverse_transform(rec['label'])[0]\n",
"    n_windows = len(rec['data'])\n",
"    print(f\"{user:2d}: {n_windows:3d} {rec['data'].shape}\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "75af2444",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten (Flatten) (None, 16900) 0 \n",
"_________________________________________________________________\n",
"batch_normalization (BatchNo (None, 16900) 67600 \n",
"_________________________________________________________________\n",
"dropout (Dropout) (None, 16900) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 5633) 95203333 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 5633) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1877) 10575018 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 1877) 0 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 625) 1173750 \n",
"_________________________________________________________________\n",
"dropout_3 (Dropout) (None, 625) 0 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 208) 130208 \n",
"_________________________________________________________________\n",
"dropout_4 (Dropout) (None, 208) 0 \n",
"_________________________________________________________________\n",
"dense_4 (Dense) (None, 69) 14421 \n",
"_________________________________________________________________\n",
"dropout_5 (Dropout) (None, 69) 0 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 16) 1120 \n",
"=================================================================\n",
"Total params: 107,165,450\n",
"Trainable params: 107,131,650\n",
"Non-trainable params: 33,800\n",
"_________________________________________________________________\n",
"CPU times: user 32.2 s, sys: 9.61 s, total: 41.8 s\n",
"Wall time: 18 s\n"
]
}
],
"source": [
"%%time\n",
"# X_train and yy_train are already numpy arrays, so pass them without copying.\n",
"model, history = train_model(X_train, yy_train)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "1a63ecda",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def predict(model, entry, classes=None):\n",
"    \"\"\"Majority-vote prediction for one recording.\n",
"\n",
"    Classifies every window in entry['data'] and returns the most frequent\n",
"    predicted class index (ties go to the class that was predicted first).\n",
"    If `classes` is given (e.g. `lb.classes_`), the index is mapped to that label.\n",
"\n",
"    Raises ValueError if the recording has no windows.\n",
"    \"\"\"\n",
"    windows = entry['data']\n",
"    if len(windows) == 0:\n",
"        raise ValueError('cannot predict a recording with no windows')\n",
"    # argmax over the softmax outputs replaces Sequential.predict_classes(),\n",
"    # which is deprecated and removed in later TensorFlow releases.\n",
"    window_preds = np.argmax(model.predict(windows, verbose=0), axis=-1)\n",
"    votes = Counter(window_preds.tolist())\n",
"    winner = max(votes, key=votes.get)\n",
"    return winner if classes is None else classes[winner]\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "aae03bc6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/jupyterhub/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py:455: UserWarning: `model.predict_classes()` is deprecated and will be removed after 2021-01-01. Please use instead:* `np.argmax(model.predict(x), axis=-1)`, if your model does multi-class classification (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype(\"int32\")`, if your model does binary classification (e.g. if it uses a `sigmoid` last-layer activation).\n",
" warnings.warn('`model.predict_classes()` is deprecated and '\n"
]
}
],
"source": [
"# True user id and majority-vote prediction for each test recording.\n",
"# `predict` returns an index into `lb.classes_` (0..n-1), not a user id, so map\n",
"# it back before comparing; otherwise predictions are shifted against the ids\n",
"# (which start at 1 here).\n",
"ltest = [lb.inverse_transform(e['label'])[0] for e in test]\n",
"ptest = [lb.classes_[predict(model, e)] for e in test]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "888494f1",
"metadata": {},
"outputs": [],
"source": [
"# Same as for the test set: map predicted class indices back to user ids.\n",
"ltrain = [lb.inverse_transform(e['label'])[0] for e in train]\n",
"ptrain = [lb.classes_[predict(model, e)] for e in train]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "03dfed1a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAGtCAYAAADnIyVRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAABmO0lEQVR4nO3df3xU933n+9dHo/DDBIeR8QxFoViLYddsWVsgR2FtvKnAxik4cQ3cOgFq1wGt2cTtrhcolyLxowu3KdTdcuu6ISW08db42i32GuStU4Rj7BUlyAhNItJAUDGEiLEEo9YxWf2Y+d4/NEwkQCCJmaM5o/fz8TgP5vyY8z7fOWc0X77f88Occ4iIiIj4Td5gb4CIiIjIQKgSIyIiIr6kSoyIiIj4kioxIiIi4kuqxIiIiIgvqRIjIiIivqRKjIiIiGScmX3LzD40sx/0Mt/MbLuZ/djMImY2/UbrVCVGREREvPCXwMPXmf95YHJyKAdeuNEKVYkRERGRjHPOHQQuXmeRLwLfdl3+ARhjZr90vXXmp3MD00y3EhYRkaHGPA0zS+dv7X+kqwXlsh3OuR39eH8hcLbb+E+S05p6e0M2V2JERETEJ5IVlv5UWm6aKjEiIiJDlJmnDT83cg6Y0G3808lpvdI5MSIiIkOUmaVtSIM3gN9MXqX0WeCfnXO9diWBWmJERETEA2a2G/gcMNbMfgKsBz4B4Jz7c+BN4NeAHwOXgN+64Tqdy9rzZ7N2w0RERDLE0/6d/Pz8tP3WdnZ2et43pZYYERGRISovz99nlfh767s5ePAgc+fO5cEHH2THjsydHJ1rOV5mKSf7s5ST/VnKyf4sL8s05DnnsnXos87OTjd79mx35swZ19bW5h555BF38uTJ/qxiSOZ4maWc7M9STvZnKSf7s9KQ4+lv7bBhw1y6Bq+33TmXGy0xkUiEiRMnMmHCBIYNG8a8efOorq5WThZlKSf7s5ST/VnKyf4sL8uUDll2dVK/ZawSY2b/xsx+N/kwp+3J13dlIisajTJu3LjUeDgcJhqNKieLspST/VnKyf4s5WR/lpdlkgxVYszsd4GX6TrL+nvJwYDdZrbmOu8rN7NaM6tVP6KIiEhm+b0lJlNXJ30F+LfOuY7uE83sOaAB+INrvcn1vGVxny/7CofDnD9/PjUejUYJh8P93eYhl+NllnKyP0s52Z+lnOzP8rJM6ZBld+ztt0x1JyWA8deY/kvJeWk1bdo0Tp8+zdmzZ2lvb6eqqoqysrJ0x+RcjpdZysn+LOVkf5Zysj/LyzJJ5lpi/jNQbWYn+cUTKX8ZuBP4WrrD8vPzqaysZNmyZcTjcRYsWMDkyZPTHZNzOV5mKSf7s5ST/VnKyf4sL8uUDn5vicnYHXvNLA/4DF2P0Yauhzgdcc7F+7gK3bFXRESGGk9rFaNHj07bb+1HH32UO3fsdc4lgH/I1PpFRERkaNNjB0RERIYov3cnqRIjIiIyRPm9EpMTd+wVERGRoUctMSIiIkOU31tiVIkREREZovxeiVF3koiIiPiSWmJkwA4dOuRJzsyZMz3JkZuj4yH7aR/JlfzeEqNKjIiIyBCVl+fvDhl/b72IiIgMWWqJERERGaLUnSQiIiK+5PdKjLqTRERExJfUEiMiIjJE+b0lJmcqMQcPHmTz5s0kEgkWLVpEeXm5cjzOikQivPTSSyQSCR544AHmz5/fY/6BAwc4cOAAZsaIESN48sknKSwspLGxkV27dqWWe/TRR5kxY8aglydbcrzM0rEwcNpHg1uebMryskw3y++VGJxz2Tr0WWdnp5s9e7Y7c+aMa2trc4888og7efJkf1YxJHNuNqumpiY1vPfee+6+++5zr7/+ujt48KCbPXu2e/XVV3ssU11dnXr9/PPPu8cee8zV1NS4t99+27377ruupqbGvfnmm27GjBmp8ZqaGs/Kk405XmbdbI4Xx0J/jgfto6tpH2X/PnIe/9aOGzfOpW
vwetudc7lxTkwkEmHixIlMmDCBYcOGMW/ePKqrq5XjYVZjYyPhcJhQKER+fj6lpaXU1dX1WGbkyJGp121tban/AQwfPpxAIABAR0fHTf3PQPto8HOG2rHgZZb20cD5bR9J33jenWRmv+Wc23XjJfsuGo0ybty41Hg4HCYSiaQzIidz0pkVi8UoKChIjQeDQRobG69abv/+/bz11lvE43FWr16dmn7q1Cl27tzJhQsXKC8vT/2R7C/to8HPGWrHgpdZ2kcD57d95BXd7K7/NvY2w8zKzazWzGp37Njh5TaJR+bMmcPWrVtZtGgRe/fuTU2fNGkSW7ZsYf369ezbt4/29vZB3Erxgo6F7Kd9lPvMLG3DYMhIJcbMIr0M3wfCvb3PObfDOVfinCvpz4lQ4XCY8+fPp8aj0SjhcK8xA5ZrOenMCgaDXLx4MTUei8UIBoO9Ll9aWsrRo0evmj5+/HhGjBjBuXPn+r0NoH2UDTlD7VjwMkv7aOD8to+kbzLVEhMGfhN45BrDhXSHTZs2jdOnT3P27Fna29upqqqirKws3TE5l5POrKKiIqLRKM3NzXR2dnL48GGKi4t7LNP9i11fX5/6Yjc3NxOPxwFoaWmhqamJsWPHDmp5siXHyywdCwOnfaR95HVOuvi9JSZT58TsAz7pnDt25Qwz+266w/Lz86msrGTZsmXE43EWLFjA5MmT0x2TcznpzAoEAixZsoRt27aRSCSYNWsWhYWF7Nmzh6KiIoqLi6murqahoYFAIMCoUaNYvnw5ACdOnKCqqopAIEBeXh5Lly5l9OjRg1qebMnxMkvHwsBpH2kfeZ2TLn6/xNqcc4O9Db3J2g2TLocOHfIkZ+bMmZ7kyM3R8ZD9tI98wdNaxS//8i+n7bf2zJkznteIcuZmdyIiItI/fm+JUSVGRERkiPJ7JcbfF4iLiIjIkKWWGBERkSHK7ze7UyVGRERkiFJ3koiIiMggUEuMSA7z6pJa0GW1fuDVPtKl3P7h95YYVWJERESGKL9XYtSdJCIiIr6klhgREZEhyu8tMarEiIiIDFF+v8Ta31svIiIiQ5ZaYkRERIYodSdliYMHD7J582YSiQSLFi2ivLxcOR5nRSIRXnrpJRKJBA888ADz58/vMf/AgQMcOHAAM2PEiBE8+eSTFBYW0tjYyK5du1LLPfroo8yYMWPQy5MtOenM0j7K/n2Uizk67jJ33N0sv1dicM5l69BnnZ2dbvbs2e7MmTOura3NPfLII+7kyZP9WcWQzLnZrJqamtTw3nvvufvuu8+9/vrr7uDBg2727Nnu1Vdf7bFMdXV16vXzzz/vHnvsMVdTU+Pefvtt9+6777qamhr35ptvuhkzZqTGa2pqPCtPNubcbJZX+6g/+0n7KPdz9LfhpnI8/a296667XLoGr7fdOZcb58REIhEmTpzIhAkTGDZsGPPmzaO6ulo5HmY1NjYSDocJhULk5+dTWlpKXV1dj2VGjhyZet3W1pb6H8Dw4cMJBAIAdHR03NT/DLSPeqd9lP37KBdzdNxl7rhLh7y8vLQNgyFj3Ulm9m+AQuCwc+5n3aY/7Jz7u3RmRaNRxo0blxoPh8NEIpF0RuRkTjqzYrEYBQUFqfFgMEhjY+NVy+3fv5+33nqLeDzO6tWrU9NPnTrFzp07uXDhAuXl5ak/XP2lfdQ77aPs30e5mKPjLnPHXTr4vTspI1UnM/tt4H8CzwA/MLMvdpu95TrvKzezWjOr3bFjRyY2TQbZnDlz2Lp1K4sWLWLv3r2p6ZMmTWLLli2sX7+effv20d7ePohbObRpH8lg0HEnA5Gp9p/lwAzn3KPA54AKM/ud5Lxeq33OuR3OuRLnXEl/ToQKh8OcP38+NR6NRgmHwwPZ7iGVk86sYDDIxYsXU+OxWIxgMNjr8qWlpRw9evSq6ePHj2fEiBGcO3eu39sA2kfXo32U/fsoF3N03GXuuEsHv3cnZSo173IXknPuNF0Vmc+b2XNcpxIzUN
OmTeP06dOcPXuW9vZ2qqqqKCsrS3dMzuWkM6uoqIhoNEpzczOdnZ0cPnyY4uLiHst0/2LX19envtjNzc3E43EAWlp
"text/plain": [
"<Figure size 720x504 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sn\n",
"\n",
"# Axis labels: every user id occurring as a true or predicted label, sorted, so\n",
"# the matrix shape always matches its row/column labels.\n",
"train_cm_labels = sorted(set(ltrain) | set(ptrain))\n",
"cm_labels = sorted(set(ltest) | set(ptest))\n",
"\n",
"# normalize='true': each row shows how one true user's recordings were classified.\n",
"train_cm = confusion_matrix(ltrain, ptrain, labels=train_cm_labels, normalize='true')\n",
"test_cm = confusion_matrix(ltest, ptest, labels=cm_labels, normalize='true')\n",
"\n",
"df_cm = pd.DataFrame(test_cm, index=cm_labels, columns=cm_labels)\n",
"fig, ax = plt.subplots(figsize=(10, 7))\n",
"sn.heatmap(df_cm, annot=True, cmap=\"Greys\", ax=ax)\n",
"ax.set(title='Test set (session 2) confusion matrix, row-normalised',\n",
"       ylabel=\"True Label\", xlabel=\"Predicted Label\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ad253a7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-07-14 10:15:52 +02:00
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}