iui-group-l-name-zensiert/2-second-project/tdt/DataViz.ipynb

1221 lines
54 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"id": "2bc4ab88",
"metadata": {},
"source": [
"## Constants"
]
},
{
"cell_type": "code",
2021-07-14 10:15:52 +02:00
"execution_count": 1,
"id": "c767cb34",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
2021-07-14 10:15:52 +02:00
"\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
2021-07-21 03:01:19 +02:00
"os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
2021-07-14 10:15:52 +02:00
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f783fc7f",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
"source": [
"glob_path = '/opt/iui-datarelease3-sose2021/*.csv'\n",
"\n",
"pickle_file = '../data.pickle'\n",
"\n",
"checkpoint_path = \"training_1/cp.ckpt\"\n",
"checkpoint_dir = os.path.dirname(checkpoint_path)"
]
},
{
"cell_type": "markdown",
"id": "bb1c9c9b",
"metadata": {},
"source": [
"# Config"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3d812543",
"metadata": {},
"outputs": [],
"source": [
2021-07-19 02:20:02 +02:00
"# Possibilities: 'SYY', 'SYN', 'SNY', 'SNN', \n",
"# 'JYY', 'JYN', 'JNY', 'JNN'\n",
2021-07-21 03:01:19 +02:00
"cenario = 'SYN'\n",
"\n",
"win_sz = 30\n",
"stride_sz = 2\n",
"\n",
"# Divisor applied to the neuron count from one dense layer to the next, e.g. dense_steps = 3: layer1=900, layer2=300, layer3=100, layer4=33, ...\n",
"dense_steps = 3\n",
"# amount of dense/dropout layers\n",
2021-07-21 03:01:19 +02:00
"layer_count = 5\n",
"# how much to drop\n",
2021-07-19 02:20:02 +02:00
"drop_count = 0.2"
]
},
{
"cell_type": "markdown",
"id": "8cef4021",
"metadata": {},
"source": [
"# Helper Functions"
2021-07-14 10:15:52 +02:00
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cde65835",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"def pplot(dd):\n",
"    \"\"\"Plot every row of ``dd`` in its own subplot, three subplots per row.\n",
"\n",
"    Args:\n",
"        dd: 2-D array-like; ``dd.shape[0]`` is the number of series and\n",
"            ``dd[i]`` is drawn as one line.\n",
"    \"\"\"\n",
"    n_series = dd.shape[0]\n",
"    n_cols = 3\n",
"    # Ceiling division with at least one row. The former int(x/3)+1 produced an\n",
"    # empty extra row whenever the count was a multiple of three.\n",
"    n_rows = max(1, -(-n_series // n_cols))\n",
"    # squeeze=False keeps axs two-dimensional even for a single row; without it\n",
"    # axs[row][col] fails as soon as fewer than three series are passed.\n",
"    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 9*n_rows), squeeze=False)\n",
"\n",
"    for i in range(n_series):\n",
"        axs[i // n_cols][i % n_cols].plot(dd[i])\n",
"\n",
"    # Hide the unused axes of the last row.\n",
"    for j in range(n_series, n_rows * n_cols):\n",
"        axs[j // n_cols][j % n_cols].set_visible(False)"
]
},
{
"cell_type": "markdown",
"id": "476851ec",
"metadata": {},
"source": [
"# Loading Data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "199e4435",
2021-07-14 10:15:52 +02:00
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from glob import glob\n",
"import pandas as pd\n",
2021-07-14 10:15:52 +02:00
"from tqdm import tqdm\n",
"\n",
2021-07-14 10:15:52 +02:00
"def dl_from_blob(filename, user_filter=None):\n",
"    \"\"\"Load every CSV matching a glob pattern, with metadata parsed from its name.\n",
"\n",
"    The file stem is split on '_' and read positionally: user (first character\n",
"    dropped), Scenario, HeightNormalization, ArmNormalization, Repetition and\n",
"    Session, each with its prefix stripped.\n",
"\n",
"    Args:\n",
"        filename: glob pattern of the CSV files to read. It used to be ignored in\n",
"            favour of the global glob_path (and was overwritten inside the loop).\n",
"        user_filter: when not None, only files of this user id are loaded.\n",
"\n",
"    Returns:\n",
"        A list with one dict per file: 'filename', 'user', 'scenario',\n",
"        'heightnorm', 'armnorm', 'rep', 'session' and 'data' (the DataFrame).\n",
"    \"\"\"\n",
"    records = []\n",
"\n",
"    for path in tqdm(glob(filename)):\n",
"        stem = path.split('/')[-1].split('.')[0]\n",
"        parts = stem.split('_')\n",
"        user = int(parts[0][1:])\n",
"        # 'is not None' so that a user id of 0 can be filtered for as well.\n",
"        if user_filter is not None and user != user_filter:\n",
"            continue\n",
"        records.append(\n",
"            {\n",
"                'filename': path,\n",
"                'user': user,\n",
"                'scenario': parts[1][len('Scenario'):],\n",
"                'heightnorm': parts[2][len('HeightNormalization'):] == 'True',\n",
"                'armnorm': parts[3][len('ArmNormalization'):] == 'True',\n",
"                'rep': int(parts[4][len('Repetition'):]),\n",
"                'session': int(parts[5][len('Session'):]),\n",
"                'data': pd.read_csv(path),\n",
"            }\n",
"        )\n",
"    return records"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9e2817c1",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
2021-07-14 10:15:52 +02:00
"def save_pickle(f, structure):\n",
"    \"\"\"Pickle ``structure`` into the file at path ``f``, replacing any existing file.\"\"\"\n",
"    # The context manager closes the handle even when pickle.dump raises.\n",
"    with open(f, 'wb') as handle:\n",
"        pickle.dump(structure, handle)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "12c5098e",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
"source": [
"def load_pickles(f) -> list:\n",
"    \"\"\"Unpickle and return the object stored in the file at path ``f``.\n",
"\n",
"    The path argument is honoured now; previously the global pickle_file was\n",
"    opened regardless of ``f``.\n",
"    \"\"\"\n",
"    # NOTE: pickle.load can execute arbitrary code - only open trusted files.\n",
"    with open(f, 'rb') as handle:\n",
"        return pickle.load(handle)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "00ee7490",
"metadata": {},
2021-07-14 10:15:52 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading data...\n",
"../data.pickle found...\n",
"768\n",
"CPU times: user 548 ms, sys: 2.56 s, total: 3.11 s\n",
"Wall time: 3.11 s\n"
2021-07-14 10:15:52 +02:00
]
}
],
"source": [
"%%time\n",
"\n",
"def load_data() -> list:\n",
"    \"\"\"Return the recordings, from the pickle cache when it exists.\n",
"\n",
"    Without a cache file the CSV files are parsed with dl_from_blob and the\n",
"    result is written to pickle_file before it is returned.\n",
"    \"\"\"\n",
"    if not os.path.isfile(pickle_file):\n",
"        print(f'Didn\\'t find {pickle_file}...')\n",
"        parsed = dl_from_blob(glob_path)\n",
"        print(f'Creating {pickle_file}...')\n",
"        save_pickle(pickle_file, parsed)\n",
"        return parsed\n",
"\n",
"    print(f'{pickle_file} found...')\n",
"    return load_pickles(pickle_file)\n",
"\n",
"print(\"Loading data...\")\n",
"dic_data = load_data()\n",
"print(len(dic_data))"
2021-07-14 10:15:52 +02:00
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d1db1537",
2021-07-14 10:15:52 +02:00
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 95 µs, sys: 297 µs, total: 392 µs\n",
"Wall time: 396 µs\n"
]
}
],
2021-07-14 10:15:52 +02:00
"source": [
"%%time\n",
"\n",
"# Categorized data, keyed by <scenario initial><height norm Y/N><arm norm Y/N>,\n",
"# with S = Sorting and J = Jenga (e.g. 'SYN').\n",
"scenario_initial = {'Sorting': 'S', 'Jenga': 'J'}\n",
"cdata = {\n",
"    initial + height + arm: list()\n",
"    for initial in scenario_initial.values()\n",
"    for height in 'YN'\n",
"    for arm in 'YN'\n",
"}\n",
"\n",
"for d in dic_data:\n",
"    initial = scenario_initial.get(d['scenario'])\n",
"    if initial is None:\n",
"        # Recordings of any other scenario are not categorized.\n",
"        continue\n",
"    key = initial + ('Y' if d['heightnorm'] else 'N') + ('Y' if d['armnorm'] else 'N')\n",
"    cdata[key].append(d)"
]
},
{
"cell_type": "markdown",
"id": "46382aad",
"metadata": {},
"source": [
"# Preprocessing"
2021-07-14 10:15:52 +02:00
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f7842338",
2021-07-14 10:15:52 +02:00
"metadata": {
"tags": []
},
"outputs": [],
2021-07-14 10:15:52 +02:00
"source": [
"def drop(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Return a copy of the frame without the metadata/bookkeeping columns.\n",
"\n",
"    Args:\n",
"        entry: a record whose 'data' item is the frame (data=True), or the\n",
"            frame itself (data=False).\n",
"        data: selects how ``entry`` is interpreted, see above.\n",
"    \"\"\"\n",
"    source = entry['data'] if data else entry\n",
"    # Pickle round trip = deep copy, so the caller's frame is never touched.\n",
"    frame = pickle.loads(pickle.dumps(source))\n",
"    unwanted = ['participantID', 'FrameID', 'Scenario', 'HeightNormalization',\n",
"                'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0']\n",
"    return frame.drop(columns=unwanted)"
2021-07-14 10:15:52 +02:00
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b73d9485",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [],
2021-07-14 10:15:52 +02:00
"source": [
"import numpy as np\n",
"right_Hand_ident='right_Hand'\n",
"left_Hand_ident='left_hand'\n",
2021-07-14 10:15:52 +02:00
"\n",
"def rem_low_acc(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Blank out hand columns on rows whose tracking accuracy is not 'High'.\n",
"\n",
"    Both accuracy columns are replaced by 1.0 ('High') / 0.0 (anything else);\n",
"    on every row with 0.0 the columns of that hand are set to NaN. The work is\n",
"    done on a deep copy of the frame.\n",
"\n",
"    Args:\n",
"        entry: a record whose 'data' item is the frame (data=True), or the\n",
"            frame itself (data=False).\n",
"        data: selects how ``entry`` is interpreted, see above.\n",
"    \"\"\"\n",
"    frame = pickle.loads(pickle.dumps(entry['data'] if data else entry))\n",
"\n",
"    for acc_col in ('LeftHandTrackingAccuracy', 'RightHandTrackingAccuracy'):\n",
"        frame[acc_col] = (frame[acc_col] == 'High') * 1.0\n",
"\n",
"    # Hand columns are found by substring match on the column name.\n",
"    # NOTE(review): left_Hand_ident is 'left_hand' while right_Hand_ident is\n",
"    # 'right_Hand' - confirm the lower-case 'h' matches the real column names.\n",
"    for acc_col, ident in (('RightHandTrackingAccuracy', right_Hand_ident),\n",
"                           ('LeftHandTrackingAccuracy', left_Hand_ident)):\n",
"        hand_cols = [c for c in frame if ident in c]\n",
"        frame.loc[frame[acc_col] == 0.0, hand_cols] = np.nan\n",
"\n",
"    return frame"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1a298d6d",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"\n",
"def pad(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Pre-pad the frame with NaN rows up to the next multiple of stride_sz.\n",
"\n",
"    The target length is (rows // stride_sz + 1) * stride_sz, so at least one\n",
"    row is always added in front. Afterwards row 0 is overwritten with zeros.\n",
"\n",
"    Args:\n",
"        entry: a record whose 'data' item is the frame (data=True), or the\n",
"            frame itself (data=False).\n",
"        data: selects how ``entry`` is interpreted, see above.\n",
"    \"\"\"\n",
"    frame = pickle.loads(pickle.dumps(entry['data'] if data else entry))\n",
"\n",
"    names = frame.columns\n",
"    target_len = (int(frame.shape[0]/stride_sz)+1)*stride_sz\n",
"    # pad_sequences pads per sequence, hence one sequence per column (transpose).\n",
"    padded = pad_sequences(\n",
"        frame.T.to_numpy(),\n",
"        maxlen=target_len,\n",
"        dtype='float64',\n",
"        padding='pre',\n",
"        truncating='post',\n",
"        value=np.nan,\n",
"    )\n",
"    result = pd.DataFrame(padded.T, columns=names)\n",
"    result.loc[0] = [0] * len(names)\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3be1bd3f",
"metadata": {},
"outputs": [],
"source": [
"def interpol(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Return a deep copy of the frame with gaps linearly interpolated along the rows.\n",
"\n",
"    Args:\n",
"        entry: a record whose 'data' item is the frame (data=True), or the\n",
"            frame itself (data=False).\n",
"        data: selects how ``entry`` is interpreted, see above.\n",
"    \"\"\"\n",
"    source = entry['data'] if data else entry\n",
"    frame = pickle.loads(pickle.dumps(source))\n",
"    return frame.interpolate(method='linear', axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "2a7f4e26",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.preprocessing import timeseries_dataset_from_array\n",
2021-07-14 10:15:52 +02:00
"\n",
"def slicing(entry, label, data=True):\n",
"    \"\"\"Cut the frame into sliding windows and attach ``label`` to each of them.\n",
"\n",
"    Windows are win_sz rows long, start every stride_sz rows and are delivered\n",
"    in batches of 8 by timeseries_dataset_from_array.\n",
"\n",
"    Args:\n",
"        entry: a record whose 'data' item is the frame (data=True), or the\n",
"            frame itself (data=False).\n",
"        label: target value used for every window.\n",
"        data: selects how ``entry`` is interpreted, see above.\n",
"    \"\"\"\n",
"    frame = pickle.loads(pickle.dumps(entry['data'] if data else entry))\n",
"    # One target per row; every window of this recording carries the same label.\n",
"    labels = [label] * frame.shape[0]\n",
"\n",
"    return timeseries_dataset_from_array(\n",
"        data=frame,\n",
"        targets=labels,\n",
"        sequence_length=win_sz,\n",
"        sequence_stride=stride_sz,\n",
"        batch_size=8,\n",
"        seed=177013,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b012b0f7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# %%time \n",
"\n",
"# acc_data = pd.DataFrame()\n",
"\n",
"# for d in tqdm(cdata['SYY']):\n",
"# acc_data = acc_data.append(d['data'])\n",
"\n",
"\n",
"# dacc_data = drop(acc_data, False)\n",
"# ddacc_data = rem_low_acc(dacc_data, False)\n",
"\n",
"# for c in ddacc_data:\n",
"# print(f\"{c}: {dacc_data[c].min()}, {dacc_data[c].max()}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "a2440d77",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 96/96 [00:09<00:00, 10.55it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 8.07 s, sys: 1.19 s, total: 9.27 s\n",
"Wall time: 9.1 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"%%time\n",
"\n",
"classes = 16 # dynamic\n",
"\n",
"def preproc(data):\n",
"    \"\"\"Run preproc_entry on every record of ``data`` (with a progress bar) and return the results as a list.\"\"\"\n",
"    return [preproc_entry(record) for record in tqdm(data)]\n",
" \n",
"def preproc_entry(entry, data = True):\n",
"    \"\"\"Preprocess one record: drop -> rem_low_acc -> pad -> slicing.\n",
"\n",
"    Args:\n",
"        entry: record dict with at least 'data' and 'user'.\n",
"        data: forwarded unchanged to every stage.\n",
"\n",
"    Returns:\n",
"        A new dict with the same items as ``entry`` except 'data', which holds the\n",
"        windowed dataset labelled with entry['user']. ``entry`` is not modified.\n",
"    \"\"\"\n",
"    # A shallow copy is enough: only 'data' is reassigned and each stage already\n",
"    # works on its own deep copy of the frame. The former per-stage pickle round\n",
"    # trips of the whole record serialized every frame several extra times.\n",
"    result = dict(entry)\n",
"    result['data'] = drop(result, data)\n",
"    result['data'] = rem_low_acc(result, data)\n",
"    result['data'] = pad(result, data)\n",
"    # The interpolation stage is currently disabled:\n",
"    # result['data'] = interpol(result, data)\n",
"    result['data'] = slicing(result, result['user'], data)\n",
"\n",
"    return result\n",
2021-07-14 10:15:52 +02:00
"\n",
"pdata = preproc(cdata[cenario])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "11e96fef",
2021-07-14 10:15:52 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 96 µs, sys: 107 µs, total: 203 µs\n",
"Wall time: 214 µs\n"
]
},
{
"data": {
"text/plain": [
"(48, 48)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"train = np.array([x['data'] for x in pdata if x['session'] == 1])\n",
"test = np.array([x['data'] for x in pdata if x['session'] == 2])\n",
"\n",
"len(train), len(test)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "1807a2f7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min, sys: 13.2 s, total: 1min 13s\n",
"Wall time: 21.1 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# Flat window/label lists for fitting (session 1) and validation (session 2).\n",
"X_train = list()\n",
"y_train = list()\n",
"\n",
"X_test = list()\n",
"y_test = list()\n",
"\n",
"# Session-2 windows grouped per recording; the Eval section predicts one label\n",
"# per recording from these.\n",
"test = list()\n",
"\n",
"for recording in pdata:\n",
"    if recording['session'] == 1:\n",
"        for window, label in recording['data'].unbatch().as_numpy_iterator():\n",
"            X_train.append(window)\n",
"            y_train.append(label)\n",
"    if recording['session'] == 2:\n",
"        test.append(\n",
"            {\n",
"                'label': recording['user'],\n",
"                'data': list()\n",
"            })\n",
"        for window, label in recording['data'].unbatch().as_numpy_iterator():\n",
"            X_test.append(window)\n",
"            y_test.append(label)\n",
"            # Keep the window with its recording as well. This append was\n",
"            # commented out, which left every 'data' list empty for the Eval cells.\n",
"            # NOTE(review): windows containing NaN are kept here although the flat\n",
"            # arrays are NaN-filtered later - confirm that is wanted for voting.\n",
"            test[-1]['data'].append(window)\n",
"\n",
"X_train = np.array(X_train)\n",
"y_train = np.array(y_train)\n",
"X_test = np.array(X_test)\n",
"y_test = np.array(y_test)"
2021-07-14 10:15:52 +02:00
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "44330b34",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(37902, 30, 338) (73692, 30, 338) (37902,) (73692,)\n",
"(32745, 30, 338) (48773, 30, 338) (32745,) (48773,)\n"
]
}
],
"source": [
"def _without_nan_windows(windows, labels):\n",
"    \"\"\"Return the (windows, labels) pairs whose window contains no NaN.\n",
"\n",
"    Args:\n",
"        windows: array with one window per entry along the first axis.\n",
"        labels: array with one label per window.\n",
"    \"\"\"\n",
"    # Reduce over every axis except the first -> one flag per window. This\n",
"    # replaces a Python loop over all windows and keeps the array rank even\n",
"    # when no window survives.\n",
"    has_nan = np.isnan(windows).any(axis=tuple(range(1, windows.ndim)))\n",
"    return windows[~has_nan], labels[~has_nan]\n",
"\n",
"XX_train, yy_train = _without_nan_windows(X_train, y_train)\n",
"XX_test, yy_test = _without_nan_windows(X_test, y_test)\n",
"\n",
"print(XX_train.shape, X_train.shape, yy_train.shape, y_train.shape)\n",
"print(XX_test.shape, X_test.shape, yy_test.shape, y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "bd805e81",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37897</th>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37898</th>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37899</th>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37900</th>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37901</th>\n",
" <td>9</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>37902 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" 0\n",
"0 11\n",
"1 11\n",
"2 11\n",
"3 11\n",
"4 11\n",
"... ..\n",
"37897 9\n",
"37898 9\n",
"37899 9\n",
"37900 9\n",
"37901 9\n",
"\n",
"[37902 rows x 1 columns]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = pd.DataFrame(yy_train)\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "f6416fc1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16 0.113055\n",
"11 0.110680\n",
"10 0.103662\n",
"3 0.102739\n",
"5 0.095404\n",
"13 0.069469\n",
"15 0.061738\n",
"7 0.048625\n",
"9 0.044826\n",
"6 0.043955\n",
"14 0.043375\n",
"1 0.038890\n",
"4 0.036489\n",
"8 0.035433\n",
"12 0.032030\n",
"2 0.019630\n",
"dtype: float64"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b = a.value_counts(normalize=True)\n",
"b"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "1885329b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>0.113055</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>0.110680</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>0.103662</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.102739</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.095404</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>0.069469</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>0.061738</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0.048625</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>0.044826</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0.043955</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>0.043375</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.038890</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.036489</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0.035433</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>0.032030</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.019630</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0\n",
"0 \n",
"16 0.113055\n",
"11 0.110680\n",
"10 0.103662\n",
"3 0.102739\n",
"5 0.095404\n",
"13 0.069469\n",
"15 0.061738\n",
"7 0.048625\n",
"9 0.044826\n",
"6 0.043955\n",
"14 0.043375\n",
"1 0.038890\n",
"4 0.036489\n",
"8 0.035433\n",
"12 0.032030\n",
"2 0.019630"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = pd.DataFrame(b)\n",
"c"
]
},
{
"cell_type": "code",
2021-07-21 03:01:19 +02:00
"execution_count": 23,
"id": "7dfe2339",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:ylabel='0'>"
]
},
2021-07-21 03:01:19 +02:00
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAADnCAYAAAA+T+sCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAABEvklEQVR4nO2deXxU5dXHf+fOkp0kZA8TiKwJENl3QgDXt2jV9lWrqLi1UnFDbEv1bTtvfa1pXVtLtbvaRa27gjtI2EFkCxBkE8gKCSEJWWfm3vP+cW8wJDOZe2fubMn9fj7zIbnz3OceYM69z5znnN8hZoaBgUH/Qgi1AQYGBsHHcHwDg36I4fgGBv0Qw/ENDPohhuMbGPRDDMc3MOiHGI5vYNAPMRzfwKAfYji+gUE/xHB8A4N+iOH4Bgb9EMPxDQz6IYbjGxj0QwzHNzDohxiOb2DQDzEc38CgH2I4fgggohgiKiEiExF9REQNRLSy25gLiGgrER0moteIyOpmniuI6JfBs9ygr2A4fmi4HcBbzCwCeALAzW7G/BrAM8w8HMAZAHe4GbMKwJVEFBswSw36JIbjh4aFAN4FAGZeDeBs1zeJiADMB/CGcuglAFd3n4Rl3bS1AK4InKkGfRHD8YOMsmQfyszHehmWAqCBmV3K7xUABnkYux1AoX4WGvQHDMcPPqkAGnSc7xSAbB3nM+gHGI4ffNoARHsZcxpAEhGZld9tACo9jI1W5jQwUI3h+EGGmc8AMBGRR+dXvrt/DuC/lUOLoMQEiOgaInq8y/CRAPYGyFyDPorh+KHhEwCzAYCI1gN4HcBFRFRBRJcpY34C4EEiOgz5O/9flePDADR1mWse5Oi+gYFqyGioEXyIaCKApczsbhvP27n/VM6tJaIMAP9m5ot0N9KgT2M4foggotsBvKTs5fs6xxQATmbepZthBv0Cw/ENDPohZu9DDCKNgpcKzACGQo4HZABIg7yN2PlKA5AM+f+/887PXX6WIGcL1kLeLuz650kARwEcL11UKgXhr6MaIooB8BHk5KdVAKYD2MDMV3QZcw+AByD/26Qxc52beQoALGPmW4NgdkgwnvgRTsFLBSMBjAdQoLzyITt9oG/q7QAOv1lRvWmk03kCwB4Au2FvPBHg63qEiJYAMDPzb4noIgCxAO7q5vgTIN/U1gKY7M7xlXGfAbidmUP29wkkhuNHEAUvFRCAMQCKlNccyE/00MDs3HGsHBbA0uVoJeStyDUAVgfzRkBEmwDc2JkVSURzATzU1fG7jD2G3h3/fgBRzPybQNkbSoylfphT8FJBBoDvALgYsqOnhtaib7AyTljkJXNXBgG4SXkB9sSjkG8CawB8CnujW0fzF5Wp0FrYDmA5AMPxDYJDwUsF6QC+C+A6yM4elvkWGaLrFHo6fneGKq87AbhgT/wMwL8BvA17Y7OO5hip0BowHD9MKHipIBnA9fjG2U2htcg7eQ6nQ+MpZgCXK69W2BPfh3wT+BD2Rqef5qhJhdZCn06FNhw/xBS8VDAWwH2QS3Ujqq5+YntHD3EQDcRCvtFdD6Ae9sRXADwLe+NhXyZj5jOKsEk0M7drPZ+IpgK4h5lvUQ716VRow/FDQMFLBQKAqwDcCznlNiKZ2N6eptNUAwEsAfBD2BPfBfAE7I2bfZinMxX6MyUVOg9APBFVALiDmT8movsA/BhAJoA9RPQBM98JYDDOf8L36VRoI6ofRApeKogFsBiyw+eG1ho/YXbtOFbO3SL6erIZwJMA3oG9UVW+gJ+p0E8A+Acz7yGiKAAlAGZ30UToUxiOHwQKXiqIBvBDyFHi9BCbowtWiY98ebzcW2BPDw4DKAbwIuyNXtObdUqFHgFgEDOv9XWOcMdw/ABSlpdvAnD7O9Ppyn/PM10Zanv0JMfp3PxBRfWMIF5yP4DlsDe+H8Rr9lnCcpuoL1CWl38F5Gy2P317K0+PdrCeW1chx4eIvr+MBvAe7IlrYU
8cH+Rr9zkMx9eZsrx8W1le/rsA3of8YYXASPvBh9L20FqmL35G9P2hCMCXsCe+AHti2CQzRRqG4+tEWV6+UJaXvwTykvTb3d+ftZ8nJrRyffAtCwwT2jv0iuj7ggDgLgAHYU+8xdtgg54Yjq8DZXn5YwBsAPB7AAnuxhAw4L53pdKgGhYomF0jHI7BoTYDcoXhS7AnvgV7Yp8ImgYLw/H9oCwv31yWl28HsAOA10DXhcd4WmojVwfcsABjZRy3AqFa6rvjGgB7YU+8JtSGRAqG4/tIWV7+IMhVaL+ASicgIHrZW6JPmWnhRIboqg21DW5IA/AW7Ikvw56YGGpjwp2IdPwuveeGENEOItpFRPuIaLGH8a8qe7O6UJaXfymAnVAEM7UwtAYzbbX8tV62hIJRDqfmlNggcjPkp7/m/5v+REQ6PpTecwCqAcxg5vEApgFYTkTuKqqeh5ym6RdKAO9RAB9CfsJohgDTQ2+KNf7aEkomtXdEhdoGL9gArIY98c5QGxKuRKrjLwTwLjM7mLlDORYFz3+f9QAu7tKgQjNlefkZAD4D8D+9XEcVWWcwfVQ5l/kzRygJcURfLVYAf4Y98TnYE42alG5EnON3F1wgohwi2gOgHMCvmbmq+znMLEFO/RznyzXL8vLzAGyFTgU1BNCDb4steswVdMInoq+WewB8BHviwFAbEk5EnOOjm+ACM5cz84UAhgNYpGjNu8MnYYWyvPyZADYCGKLdVM8kt2DypEPSLj3nDAZW4ESYRfTVcBGAbbAnjgm1IeFCJDq+W8EF5Um/F547x2oWVijLy78K8vI+IE+Le96Xwl5sozvpLtepUNvgI8MAbIY9sSjUhoQDEef4XXvPEZFNkVQGESVDjrJ/pfz+siKu0IkmYYWyvPzFAN4EEKOb8d2I60DBvN3StkDNHwjywjui740EAB/AntjvOw9FnOMrdAou5APYSkS7IddPP8nMndlxFwKoAgBl+d/GzKqi6UpSzvMIgvzV7Z9IySTHICKCieEf0fdGLICVsCde5nVkHyZSHX8FgEXM/CkzX8jM45Q//wQARDQAwCFmrlDG3wjgj2omVpz+F4Ew2h1RLoy4ajP7ojYTEia2d/SFwphoAO/CnthDdru/EJGOz8w7AHxORG6fyMzcxMzXdjnUAOAlb/OW5eU/giA6fSfXrZcGm0T2V2wy8MgRfV2DnCEkCnKmX79M841IxwcAZv6bWpUVZv67Nwmlsrz8BwD8nx62acUsIWfh51LYP/UjNKLfGxYA/4E98TuhNiTYRKzj60lZXv4tAJ4OpQ3f2s55UQ4O6739CI7o94YZwL9hT/S0G9Qn6feOryjl/BUAhdIOgZH+/Y+kL0JpgzfyHM4O76MikijI3/nzQ21IsOjXjq/U0b+CMJEZL9zHE+LbuCHUdnhiYntHoBR1w4HkvVLub3KXrwpdL8Ig0m8dvywvPxnAuwDiQ21LJwQk3vuetDvUdnhiYmTk6PvEanFCyRWOxxYAeCd3+apI37L0Sr90fEX99jV47/sWdMYf5akpTeryDYJK34ron4MZrudcV6+/w/mjIoAIwHQAfwm1XYGmXzo+5A6ol4TaCHcQELP0bfFgqO3oTh+M6IMZjUuc9+95ynVd98DeTbnLVy0LiVFBot85flle/s0AHgy1Hb0xogozB9Xx8VDb0ZW+FtF3sVBxheOx2g+kaRM9DPlV7vJVnt6LePqV45fl5Y+Cygy+UEKA+aG3xMpQ29GVvhTRb+bo/TM7novaxxcM72WYFcC/c5eviqhGpmrpN46vfK9/GQEsutGT7NOYMbySvwq1HZ30lYj+CSlty5SO53NPIVlNoHIUgN8G2qZQ0G8cH8DDAKZ6HRUmEEDL3hKbQm1HJ30hor9eHFtS5HhmahuitDzF78xdvqrPZfb1C8cvy8ufCOBnobZDKynNmDL+iLQn1HZEekSfGeIfXQvW3ex8uIgh+PKZ/3Pu8lU23Q0LIX3e8cvy8qMgL/
Ejcql633tSSDMKgciO6DPj7FLn3Tsfdy2c48c0AyF/hvoMfd7xATwKIGIll+LbUVBUGtpU3kiN6LtYqL7a8cvqd6T
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"c.plot.pie(y=0, legend=False)"
]
},
{
"cell_type": "code",
2021-07-21 03:01:19 +02:00
"execution_count": 24,
"id": "fd3f3f1e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "615cccc9",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"from sklearn.preprocessing import LabelBinarizer\n",
"\n",
"\n",
"# Fit the binarizer on the training labels only and reuse that fitted mapping\n",
"# for the test labels, so that a one-hot column means the same user in both\n",
"# sets. Refitting on the test labels could reorder or drop columns.\n",
"lb = LabelBinarizer()\n",
"yyy_train = lb.fit_transform(yy_train)\n",
"yyy_test = lb.transform(yy_test)\n",
"\n",
"# One-hot encode the per-recording labels and freeze the window lists as arrays.\n",
"for e in test:\n",
"    e['label'] = lb.transform([e['label']])\n",
"    e['data'] = np.array(e['data'])\n",
" \n",
"# for e in train:\n",
"# e['label'] = lb.transform([e['label']])\n",
"# e['data'] = np.array(e['data'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f56e5055",
"metadata": {},
"outputs": [],
"source": [
"print(XX_train.shape)\n",
"print(yyy_train.shape)\n",
"print(XX_test.shape)\n",
"print(yyy_test.shape)"
]
},
{
"cell_type": "markdown",
"id": "f046e211",
"metadata": {},
"source": [
"# Building Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53832797",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.regularizers import l2\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
"from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
"from tensorflow.keras.optimizers import Adam\n",
"\n",
"def build_model(shape, classes):\n",
"    \"\"\"Build and compile a dense softmax classifier for flattened windows.\n",
"\n",
"    Args:\n",
"        shape: shape of one input sample; shape[0]*shape[1] is the flattened size.\n",
"        classes: number of units of the softmax output layer.\n",
"\n",
"    Returns:\n",
"        The compiled Sequential model (Adam, categorical cross-entropy, 'acc').\n",
"    \"\"\"\n",
"    flat_size = shape[0]*shape[1]\n",
"\n",
"    net = Sequential()\n",
"    net.add(Flatten(input_shape=shape))\n",
"    net.add(Dropout(drop_count))\n",
"    net.add(BatchNormalization())\n",
"\n",
"    # Hidden stack: layer i gets flat_size / dense_steps**i units, preceded by a\n",
"    # dropout of rate drop_count*i. Stop as soon as a layer would be no wider\n",
"    # than classes*dense_steps.\n",
"    for depth in range(1, layer_count):\n",
"        width = int(flat_size/pow(dense_steps, depth))\n",
"        if width <= classes*dense_steps:\n",
"            break\n",
"        net.add(Dropout(drop_count*depth))\n",
"        net.add(Dense(width, activation='relu', kernel_regularizer=l2(0.001)))\n",
"\n",
"    net.add(Dense(classes, activation='softmax'))\n",
"\n",
"    net.compile(\n",
"        optimizer=Adam(),\n",
"        loss=\"categorical_crossentropy\",\n",
"        metrics=[\"acc\"],\n",
"    )\n",
"\n",
"    return net\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64f2aaa7",
"metadata": {},
"outputs": [],
"source": [
"checkpoint_file = './goat.weights'\n",
"\n",
"def train_model(X_train, y_train, X_test, y_test):\n",
"    \"\"\"Build, fit and return the classifier with its best checkpoint restored.\n",
"\n",
"    Args:\n",
"        X_train: training windows; X_train[0].shape is the model input shape.\n",
"        y_train: one-hot training labels, one column per class.\n",
"        X_test: validation windows passed to fit().\n",
"        y_test: one-hot validation labels passed to fit().\n",
"\n",
"    Returns:\n",
"        (model, history): the model after loading the checkpoint written to\n",
"        checkpoint_path, and the History object returned by fit().\n",
"    \"\"\"\n",
"    # Size the output layer from the label matrix instead of a hard-coded 16.\n",
"    model = build_model(X_train[0].shape, y_train.shape[1])\n",
"\n",
"    model.summary()\n",
"\n",
"    # Checkpoint callback: only written when the training loss improves.\n",
"    model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss',\n",
"                                       save_best_only=True)\n",
"\n",
"    # Halve the learning rate after 5 epochs without a better training loss,\n",
"    # but never go below 1e-4.\n",
"    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=0.0001)\n",
"\n",
"    callbacks = [model_checkpoint, reduce_lr]\n",
"\n",
"    history = model.fit(X_train,\n",
"                        y_train,\n",
"                        epochs=50,\n",
"                        batch_size=32,\n",
"                        verbose=2,\n",
"                        validation_data=(X_test, y_test),\n",
"                        callbacks=callbacks\n",
"                        )\n",
"\n",
"    # Restore the best checkpoint rather than keeping the last epoch's weights.\n",
"    model.load_weights(checkpoint_path)\n",
"    return model, history\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "858670a8",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model, history = train_model(np.array(XX_train), np.array(yyy_train), np.array(XX_test), np.array(yyy_test))"
]
},
{
"cell_type": "markdown",
"id": "6e905067",
"metadata": {},
"source": [
"# Eval"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "196dae1d",
"metadata": {},
"outputs": [],
"source": [
"def predict(model, entry):\n",
"    \"\"\"Return the majority-vote class over all windows of one recording, plus one.\n",
"\n",
"    Args:\n",
"        model: object whose predict() returns one score vector per window.\n",
"        entry: record whose 'data' item holds the windows of the recording.\n",
"\n",
"    Returns:\n",
"        The most frequent argmax index + 1; on a tie the index that was seen\n",
"        first wins.\n",
"    \"\"\"\n",
"    votes = dict()\n",
"    for cls in np.argmax(model.predict(entry['data']), axis=-1):\n",
"        votes[cls] = votes.get(cls, 0) + 1\n",
"    winner = max(votes, key=votes.get)\n",
"    # NOTE(review): +1 presumably maps the 0-based index to user ids starting at 1 - confirm.\n",
"    return winner+1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49aebfaa",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"ltest = [lb.inverse_transform(e['label'])[0] for e in test]\n",
"ptest = [predict(model, e) for e in test]\n",
"# for e in test:\n",
"# print(f\"Label: {lb.inverse_transform(e['label'])[0]:2d}\")\n",
"# print(f\"Prediction: {predict(model, e):2d}\\n_______________\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6577cf5f",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"ltrain = [lb.inverse_transform(e['label'])[0] for e in train]\n",
"ptrain = [predict(model, e) for e in train]\n",
"\n",
"# for e in train:\n",
"# print(f\"Label: {lb.inverse_transform(e['label'])[0]:2d}\")\n",
"# print(f\"Prediction: {predict(model, e):2d}\\n_______________\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d62e2063",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sn\n",
"\n",
"from sklearn.metrics import classification_report\n",
"\n",
"# Axis labels of the matrix: the sorted union of true and predicted labels,\n",
"# which is also passed to confusion_matrix so rows/columns line up. The former\n",
"# set literal 0..15 was unordered and did not match the user ids (1..16).\n",
"set_digits = sorted(set(ltest) | set(ptest))\n",
"\n",
"train_cm = confusion_matrix(ltrain, ptrain, normalize='true')\n",
"test_cm = confusion_matrix(ltest, ptest, labels=set_digits, normalize='true')\n",
"\n",
"df_cm = pd.DataFrame(test_cm, index=set_digits, columns=set_digits)\n",
"plt.figure(figsize = (10,7))\n",
"sn_plot = sn.heatmap(df_cm, annot=True, cmap=\"Greys\")\n",
"plt.ylabel(\"True Label\")\n",
"plt.xlabel(\"Predicted Label\")\n",
"plt.show()\n",
"\n",
"print(classification_report(ltest, ptest, zero_division=0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5041115f",
"metadata": {},
"outputs": [],
"source": [
"print(f'cenario: {cenario}')\n",
"print(f'win_sz: {win_sz}')\n",
"print(f'stride_sz: {stride_sz}')\n",
"print(f'dense_steps: {dense_steps}')\n",
"print(f'layer_count: {layer_count}')\n",
"print(f'drop_count: {drop_count}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d25d662c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-07-14 10:15:52 +02:00
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}