{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9c890798",
   "metadata": {},
   "source": [
    "# Change Scenario here.\n",
    "\n",
    "Set `cenario` in the next cell to one of the codes below.\n",
    "\n",
    "| | GameType | HeightNorm | ArmNorm |\n",
    "|:---:|:--------:|:----------:|:-------:|\n",
    "| SYY | Sorting | ✅ | ✅ |\n",
    "| SYN | Sorting | ✅ | ❌ |\n",
    "| SNY | Sorting | ❌ | ✅ |\n",
    "| SNN | Sorting | ❌ | ❌ |\n",
    "| JYY | Jenga | ✅ | ✅ |\n",
    "| JYN | Jenga | ✅ | ❌ |\n",
    "| JNY | Jenga | ❌ | ✅ |\n",
    "| JNN | Jenga | ❌ | ❌ |\n",
    "\n",
    "Weights for the corresponding scenario are loaded automatically from `training_{cenario}/cp.ckpt` if that checkpoint exists; otherwise a new model is trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1c9e114c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Possibilities: 'SYY', 'SYN', 'SNY', 'SNN',\n",
    "#                'JYY', 'JYN', 'JNY', 'JNN'\n",
    "cenario = 'SYN'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3c8b624",
   "metadata": {},
   "source": [
    "## Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f120a31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'  # this is required\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'  # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3be386b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "glob_path = '/opt/iui-datarelease3-sose2021/*.csv'\n",
    "\n",
    "pickle_file = '../data.pickle'\n",
    "\n",
    "pd.set_option('display.float_format', lambda x: '%.2f' % x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "375756bc",
   "metadata": {},
   "source": [
    "# Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe73e572",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_new = False  # True: train a new model even if a checkpoint already exists\n",
    "checkpoint_path = f\"training_{cenario}/cp.ckpt\"\n",
    "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
    "\n",
    "# sliding-window length (rows per sample) and step between consecutive windows\n",
    "win_sz = 5\n",
    "stride_sz = 1\n",
    "\n",
    "epoch = 50  # number of training epochs\n",
    "\n",
    "# divisor between successive hidden-layer widths: hidden layer i (i = 2, 3, ...) gets\n",
    "# int(n_inputs / dense_steps**i) neurons, e.g. 1680 inputs and dense_steps = 3 -> 186, 62, 20\n",
    "dense_steps = 3\n",
    "# maximum number of hidden Dropout/Dense pairs (fewer if a layer would not exceed the class count)\n",
    "layer_count = 3\n",
    "# base dropout rate: the input dropout uses drop_count, hidden layer i uses drop_count * i\n",
    "drop_count = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0173497c",
   "metadata": {},
   "source": [
    "# Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ef82a419",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import math\n",
    "\n",
    "def pplot(dd, ncols=3):\n",
    "    \"\"\"Plot dd[0], dd[1], ... each into its own subplot of a grid with `ncols` columns.\n",
    "\n",
    "    `dd` must have a `.shape` and be indexable as dd[i] for i in range(dd.shape[0])\n",
    "    (presumably a numpy array whose rows are the series to plot).\n",
    "    \"\"\"\n",
    "    count = dd.shape[0]\n",
    "    # ceil: no trailing empty row when count is a multiple of ncols\n",
    "    nrows = max(1, math.ceil(count / ncols))\n",
    "    # squeeze=False keeps `axs` 2-D even for a single row, so axs[row][col] always works\n",
    "    fig, axs = plt.subplots(nrows, ncols, figsize=(3 * ncols, 9 * nrows), squeeze=False)\n",
    "\n",
    "    for i in range(count):\n",
    "        axs[i // ncols][i % ncols].plot(dd[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "556c7dde",
   "metadata": {},
   "source": [
    "# Loading Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "51195751",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "\n",
    "def dl_from_blob(filename, user_filter=None):\n",
    "    \"\"\"Read every CSV matching the glob pattern `filename` into a list of dicts.\n",
    "\n",
    "    The metadata is parsed from the file name, split at '_' as\n",
    "    {1 char}{user}_Scenario{name}_HeightNormalization{bool}_ArmNormalization{bool}_Repetition{n}_Session{n}.\n",
    "\n",
    "    Args:\n",
    "        filename: glob pattern of the CSV files (load_data passes `glob_path`).\n",
    "        user_filter: if not None, only files of this user id are read.\n",
    "\n",
    "    Returns:\n",
    "        list of dicts with keys filename, user, scenario, heightnorm, armnorm,\n",
    "        rep, session and data (the raw pd.DataFrame).\n",
    "    \"\"\"\n",
    "    dic_data = []\n",
    "\n",
    "    for path in tqdm(glob(filename)):\n",
    "        stem = path.split('/')[-1].split('.')[0]\n",
    "        splitname = stem.split('_')\n",
    "        user = int(splitname[0][1:])\n",
    "        if user_filter is not None and user != user_filter:\n",
    "            continue\n",
    "        scenario = splitname[1][len('Scenario'):]\n",
    "        heightnorm = splitname[2][len('HeightNormalization'):] == 'True'\n",
    "        armnorm = splitname[3][len('ArmNormalization'):] == 'True'\n",
    "        rep = int(splitname[4][len('Repetition'):])\n",
    "        session = int(splitname[5][len('Session'):])\n",
    "        data = pd.read_csv(path)\n",
    "        dic_data.append(\n",
    "            {\n",
    "                'filename': path,\n",
    "                'user': user,\n",
    "                'scenario': scenario,\n",
    "                'heightnorm': heightnorm,\n",
    "                'armnorm': armnorm,\n",
    "                'rep': rep,\n",
    "                'session': session,\n",
    "                'data': data\n",
    "            }\n",
    "        )\n",
    "    return dic_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "457bc16f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "def save_pickle(f, structure):\n",
    "    \"\"\"Pickle `structure` into the file at path `f`; the file is closed even if dumping fails.\"\"\"\n",
    "    with open(f, 'wb') as _p:\n",
    "        pickle.dump(structure, _p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9482bc78",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pickles(f) -> list:\n",
    "    \"\"\"Return the object pickled in the file at path `f`.\n",
    "\n",
    "    Only use on trusted files: unpickling can execute arbitrary code.\n",
    "    \"\"\"\n",
    "    with open(f, 'rb') as _p:\n",
    "        return pickle.load(_p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "230fb3b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "../data.pickle found...\n",
      "768\n",
      "CPU times: user 572 ms, sys: 2.57 s, total: 3.14 s\n",
      "Wall time: 3.14 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "def load_data() -> list:\n",
    "    \"\"\"Return all recordings: from the pickle cache if it exists, else from the CSVs (then cache them).\"\"\"\n",
    "    if os.path.isfile(pickle_file):\n",
    "        print(f'{pickle_file} found...')\n",
    "        return load_pickles(pickle_file)\n",
    "    print(f'Didn\\'t find {pickle_file}...')\n",
    "    all_data = dl_from_blob(glob_path)\n",
    "    print(f'Creating {pickle_file}...')\n",
    "    save_pickle(pickle_file, all_data)\n",
    "    return all_data\n",
    "\n",
    "print(\"Loading data...\")\n",
    "dic_data = load_data()\n",
    "print(len(dic_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "effa570d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 393 µs, sys: 0 ns, total: 393 µs\n",
      "Wall time: 397 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Categorized Data: one list per scenario code {game}{heightnorm}{armnorm},\n",
    "# e.g. 'SYN' = Sorting, height normalization on, arm normalization off.\n",
    "cdata = {f'{game}{height}{arm}': list()\n",
    "         for game in 'SJ' for height in 'YN' for arm in 'YN'}\n",
    "scenario_codes = {'Sorting': 'S', 'Jenga': 'J'}\n",
    "\n",
    "for d in dic_data:\n",
    "    game = scenario_codes.get(d['scenario'])\n",
    "    if game is None:\n",
    "        continue  # recordings of other scenarios are not categorized\n",
    "    key = game + ('Y' if d['heightnorm'] else 'N') + ('Y' if d['armnorm'] else 'N')\n",
    "    cdata[key].append(d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ad62c63",
   "metadata": {},
   "source": [
    "# Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "55619c6e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def drop(entry, data=True) -> pd.DataFrame:\n",
    "    \"\"\"Return a copy of a recording without its metadata columns.\n",
    "\n",
    "    entry: a recording dict whose frame is entry['data'] (data=True) or the frame itself (data=False).\n",
    "    \"\"\"\n",
    "    droptable = ['participantID', 'FrameID', 'Scenario', 'HeightNormalization', 'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0']\n",
    "    if data:\n",
    "        centry = pickle.loads(pickle.dumps(entry['data']))\n",
    "    else:\n",
    "        centry = pickle.loads(pickle.dumps(entry))\n",
    "\n",
    "    return centry.drop(droptable, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d7be5822",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "left_Hand_ident='left'\n",
    "right_Hand_ident='right'\n",
    "\n",
    "def rem_low_acc(entry, data=True) -> pd.DataFrame:\n",
    "    \"\"\"Blank out hand samples whose tracking accuracy is not 'High'.\n",
    "\n",
    "    The two *HandTrackingAccuracy columns become 1.0 ('High') / 0.0 (anything else). On rows\n",
    "    where a hand's accuracy is 0.0, every other column whose lower-cased name contains\n",
    "    'left' / 'right' is set to NaN. Works on a copy; `entry`/`data` as in drop().\n",
    "    \"\"\"\n",
    "    if data:\n",
    "        centry = pickle.loads(pickle.dumps(entry['data']))\n",
    "    else:\n",
    "        centry = pickle.loads(pickle.dumps(entry))\n",
    "\n",
    "    centry['LeftHandTrackingAccuracy'] = (centry['LeftHandTrackingAccuracy'] == 'High') * 1.0\n",
    "    centry['RightHandTrackingAccuracy'] = (centry['RightHandTrackingAccuracy'] == 'High') * 1.0\n",
    "\n",
    "    left_Hand_cols = [c for c in centry if left_Hand_ident in c.lower() and c != 'LeftHandTrackingAccuracy']\n",
    "    right_Hand_cols = [c for c in centry if right_Hand_ident in c.lower() and c != 'RightHandTrackingAccuracy']\n",
    "\n",
    "    centry.loc[centry['LeftHandTrackingAccuracy'] == 0.0, left_Hand_cols] = np.nan\n",
    "    centry.loc[centry['RightHandTrackingAccuracy'] == 0.0, right_Hand_cols] = np.nan\n",
    "\n",
    "    return centry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "da77d0a9",
   "metadata": {},
"outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "\n",
    "def pad(entry, data=True) -> pd.DataFrame:\n",
    "    \"\"\"NaN-pre-pad every column to the smallest multiple of stride_sz larger than the row count.\n",
    "\n",
    "    The first (always padding) row is then overwritten with zeros. `entry`/`data` as in drop().\n",
    "    Not used at the moment: the pad step in preproc_entry is commented out.\n",
    "    \"\"\"\n",
    "    if data:\n",
    "        centry = pickle.loads(pickle.dumps(entry['data']))\n",
    "    else:\n",
    "        centry = pickle.loads(pickle.dumps(entry))\n",
    "\n",
    "    cols = centry.columns\n",
    "    pentry = pad_sequences(centry.T.to_numpy(),\n",
    "                           maxlen=(int(centry.shape[0]/stride_sz)+1)*stride_sz,\n",
    "                           dtype='float64',\n",
    "                           padding='pre',\n",
    "                           truncating='post',\n",
    "                           value=np.nan\n",
    "                           )\n",
    "    pdentry = pd.DataFrame(pentry.T, columns=cols)\n",
    "    pdentry.loc[0] = [0 for _ in cols]\n",
    "    return pdentry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ac13ea7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def interpol(entry, data=True) -> pd.DataFrame:\n",
    "    \"\"\"Fill NaN gaps with DataFrame.interpolate (linear by default), also at the start/end (limit_direction='both').\"\"\"\n",
    "    if data:\n",
    "        centry = pickle.loads(pickle.dumps(entry['data']))\n",
    "    else:\n",
    "        centry = pickle.loads(pickle.dumps(entry))\n",
    "\n",
    "    return centry.interpolate(limit_direction='both')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2f6b0535",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing import timeseries_dataset_from_array\n",
    "\n",
    "def slicing(entry, label, data=True):\n",
    "    \"\"\"Cut the frame into windows of win_sz consecutive rows (step stride_sz), each labelled `label`.\n",
    "\n",
    "    Returns a tf.data dataset yielding (windows, labels) batches of 8.\n",
    "    \"\"\"\n",
    "    if data:\n",
    "        centry = pickle.loads(pickle.dumps(entry['data']))\n",
    "    else:\n",
    "        centry = pickle.loads(pickle.dumps(entry))\n",
    "\n",
    "    return timeseries_dataset_from_array(\n",
    "        data=centry,\n",
    "        targets=[label for _ in range(centry.shape[0])],\n",
    "        sequence_length=win_sz,\n",
    "        sequence_stride=stride_sz,\n",
    "        batch_size=8,\n",
    "        seed=177013  # only used when shuffle=True (default False)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "be9a3bee",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 96/96 [00:05<00:00, 16.33it/s]\n"
     ]
    }
   ],
   "source": [
    "# One frame holding every recording of the chosen scenario, for the global normalisation statistics.\n",
    "# A single pd.concat instead of DataFrame.append in a loop (quadratic; append was removed in pandas 2.0).\n",
    "acc_frames = [e['data'] for e in tqdm(cdata[cenario])]\n",
    "acc_data = pd.concat(acc_frames, ignore_index=True) if acc_frames else pd.DataFrame()\n",
    "\n",
    "ddacc_data = rem_low_acc(drop(acc_data, False), False)\n",
    "\n",
    "# NOTE(review): these statistics cover both sessions, so the test session (2) also shapes the scaling.\n",
    "eula = ddacc_data[[c for c in ddacc_data if 'euler' in c.lower()]]\n",
    "posi = ddacc_data[[c for c in ddacc_data if 'pos' in c.lower()]]\n",
    "eulamin = eula.min()\n",
    "eulamax = eula.max()\n",
    "eulamean = eula.mean()\n",
    "eulastd = eula.std()\n",
    "posimin = posi.min()\n",
    "posimax = posi.max()\n",
    "posimean = posi.mean()\n",
    "posistd = posi.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "bf571416",
   "metadata": {},
   "outputs": [],
   "source": [
    "def minmaxscaler(entry, minimum, maximum):\n",
    "    \"\"\"Scale `entry` linearly so that `minimum` maps to 0 and `maximum` to 1.\n",
    "\n",
    "    `minimum`/`maximum` may be scalars or per-column Series. A zero range (constant column)\n",
    "    is replaced by 1, so such values map to 0 instead of 0/0 = NaN. Windows containing NaN\n",
    "    are discarded later, so a NaN column would otherwise remove every window.\n",
    "    \"\"\"\n",
    "    span = maximum - minimum\n",
    "    if isinstance(span, (pd.Series, pd.DataFrame)):\n",
    "        span = span.mask(span == 0, 1)\n",
    "    else:\n",
    "        span = np.where(span == 0, 1, span)\n",
    "    return (entry - minimum) / span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "dc70c74b",
   "metadata": {},
   "outputs": [],
   "source": [
    "euler_ident = 'euler'\n",
    "pos_ident = 'pos'\n",
    "\n",
    "def norm(entry, data=True) -> pd.DataFrame:\n",
    "    \"\"\"Min-max scale the columns whose lower-cased name contains 'euler' / 'pos'.\n",
    "\n",
    "    Uses the scenario-wide statistics eulamin/eulamax and posimin/posimax. `entry`/`data` as in drop().\n",
    "    \"\"\"\n",
    "    if data:\n",
    "        centry = pickle.loads(pickle.dumps(entry['data']))\n",
    "    else:\n",
    "        centry = pickle.loads(pickle.dumps(entry))\n",
    "\n",
    "    euler_cols = [c for c in centry if euler_ident in c.lower()]\n",
    "    pos_cols = [c for c in centry if pos_ident in c.lower()]\n",
    "\n",
    "    centry[euler_cols] = minmaxscaler(centry[euler_cols], eulamin, eulamax)\n",
    "    centry[pos_cols] = minmaxscaler(centry[pos_cols], posimin, posimax)\n",
    "    return centry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "45877405",
   "metadata": {},
   "outputs": [],
   "source": [
    "def drop_acc(entry, data=True) -> pd.DataFrame:\n",
    "    \"\"\"Return a copy without the two *HandTrackingAccuracy columns. `entry`/`data` as in drop().\"\"\"\n",
    "    droptable = ['LeftHandTrackingAccuracy', 'RightHandTrackingAccuracy']\n",
    "    if data:\n",
    "        centry = pickle.loads(pickle.dumps(entry['data']))\n",
    "    else:\n",
    "        centry = pickle.loads(pickle.dumps(entry))\n",
    "\n",
    "    return centry.drop(droptable, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d7a30d7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 96/96 [00:14<00:00, 6.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 13 s, sys: 1.93 s, 
total: 14.9 s\n", "Wall time: 14.4 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "%%time\n", "\n", "classes = 16 # dynamic\n", "\n", "def preproc(data):\n", " res_list = list()\n", " \n", " for e in tqdm(data):\n", " res_list.append(preproc_entry(e))\n", " \n", " return res_list\n", " \n", "def preproc_entry(entry, data = True):\n", " entry2 = pickle.loads(pickle.dumps(entry))\n", " entry2['data'] = drop(entry2, data)\n", " \n", " entry3 = pickle.loads(pickle.dumps(entry2))\n", " entry3['data'] = rem_low_acc(entry3, data)\n", " \n", " entry1 = pickle.loads(pickle.dumps(entry3))\n", " entry1['data'] = norm(entry1, data)\n", " \n", " entry8 = pickle.loads(pickle.dumps(entry1))\n", " entry8['data'] = drop_acc(entry8, data)\n", " \n", "# entry5 = pickle.loads(pickle.dumps(entry4))\n", "# entry5['data'] = pad(entry5, data)\n", " \n", "# entry6 = pickle.loads(pickle.dumps(entry8))\n", "# entry6['data'] = interpol(entry6, data)\n", " \n", " entry7 = pickle.loads(pickle.dumps(entry8))\n", " entry7['data'] = slicing(entry7, entry7['user'], data)\n", " \n", " return entry7\n", "\n", "pdata = preproc(cdata[cenario])" ] }, { "cell_type": "code", "execution_count": 21, "id": "c88f53a4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "a = drop(cdata[cenario][0]['data'], False)\n", "a['left_OVRHandPrefab_pos_X'].plot()\n", "plt.plot((a['LeftHandTrackingAccuracy'] == 'High')*1.0)" ] }, { "cell_type": "code", "execution_count": 22, "id": "b9518087", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "b = rem_low_acc(a, False)\n", "b['left_OVRHandPrefab_pos_X'].plot()\n", "plt.plot(b['LeftHandTrackingAccuracy'])" ] }, { "cell_type": "code", "execution_count": 23, "id": "09687aab", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaDUlEQVR4nO3de5RcZZnv8e/T3bmSzrU7kEvnQghoAxqgBUZAcVRMokNwCUq8IYMywyLncI7OuFAchoNrnXXUc5zRIeNMGB3UGY0MHjSjYTLKMBOYQzDNLVdCQhJIh4R0rgRy7e7n/FE7odKpTu3q1K7a767fZ62sVO3aXfX0u9/69a5377e2uTsiIhK+umoXICIi5aFAFxHJCAW6iEhGKNBFRDJCgS4ikhEN1XrhpqYmnzJlSrVeXkQkSE8//fROd28u9FjVAn3KlCm0t7dX6+VFRIJkZi/39ZiGXEREMkKBLiKSEQp0EZGMUKCLiGSEAl1EJCOKBrqZ/cDMdpjZqj4eNzP7rpltMLMVZnZx+csUEZFi4uyhPwDMPMXjs4Dp0b9bge+dflkiIlKqouehu/tSM5tyilXmAD/y3PfwLjOzkWY2zt23lavIYB1+A363AI4ejP8zk98N096XXE1Z8eIS6ChhHoPVwTs/AaPPTq6muHp64Km/gYN7ql1JaVqvhbMurHYVcgrlmFg0AdiSd78jWnZSoJvZreT24pk0aVIZXjrlNi2FR/9HdMdi/IDn3jDTnkiyqmxY/Cew9xXitSuAw9EDcM3Xk6wqnl3rYclXojtx6682hz2b4WP3V7sQOYWKzhR19wXAAoC2trbsX1nDu3P///ET8fZsFn4Kdm9Ktqas6OmBGZ+G6+bHW/9/TgDvSbamuHqifnHDD+H866paSmx/dclb/VlSqxxnuWwFWvLuT4yWiYhIBZUj0BcBn43Odrkc2KfxcxGRyis65GJmPwWuBprMrAP4c2AAgLv/DbAYmA1sAA4ANydVrIiI9C3OWS5zizzuwO1lqyhL+nUB7uwfWiiPfrRTai6InpY6SpSa9pO+aKaoiEhGKNArIuapaRbKKWwpUVJzpbBtg9reIdVauxToIiIZoUAXEckIBXqiQj5wl3IhH3AOdhuHWnftUKCLiGSEAr0SYh/80oGn0pTQXqk8AJnGmvqQyvaT3hToIiIZoUAXEckIBXqSgj34JckKtF+oP6eeAj119KaJR2cQifSmQK8IzRRNREntlcK2DWp7h1Rr7VKgi4hkhAJdRCQjFOiJ0pitFBDsWH6oddcOBXraBPtmr7CQp/6LJESBXgmaKSoFBbS9gzqAW7sU6BKwUqb+J1eFSFoo0EVEMkKBLiKSEQr0JOnAXYJCnimaljpKlJr2k74o0CtCM0WlgKC2d0i11i4FuoQr9Kn/ImWmQBcRyQgFuohIRijQE6WDSFJAsAcXQ627dijQ0ybYN3uF6QwikZMo0CtBU/8TootEV0wq2096U6CLiGSEAl1EJCMU6EnSeLgUFGi/UH9OvViBbmYzzWydmW0wszsLPD7JzB4zs2fNbI
WZzS5/qbVCb5p4Qp76L5KMooFuZvXAfGAW0ArMNbPWXqt9DXjQ3S8CbgT+utyFhk1T/6WAoLZ3SLXWrjh76JcCG9x9o7sfARYCc3qt48Dw6PYI4NXylSjSB039FzlBnECfAGzJu98RLct3D/BpM+sAFgP/pdATmdmtZtZuZu2dnZ39KFdERPpSroOic4EH3H0iMBv4sZmd9NzuvsDd29y9rbm5uUwvLRIYjeVLQuIE+lagJe/+xGhZvluABwHc/UlgMNBUjgJrjt7s8WimqMhJ4gT6cmC6mU01s4HkDnou6rXOK8D7Aczs7eQCXWMqx2imqBQU0PYO6gBu7Soa6O7eBcwDlgBryZ3NstrM7jWza6PVvgR8wcyeB34KfM5du5qStNCn/ouUV0Ocldx9MbmDnfnL7s67vQa4oryliYhIKTRTNEn6kCIFBdov1J9TT4EuIpIRCvSKKGWmqPaC4snA1P+gxvVDqrV2KdAlXJopKnICBbqISEYo0EVEMkKBnqiUjdlKOgTbLYItvGYo0CuhlJmiaTtwl1aZmPof0Lh+QKXWMgW6iEhGKNAlYJr6L5JPgS4ikhEK9CRpPFwKCrRfqD+nngI9TTRTtASaKVpZIdVauxToIiIZoUCXcGnqv8gJFOgiIhmhQE9UysZsJR3SNpYfW6h11w4FuohIRijQK0FT/8tPU/8rK6gzcmqXAl0CppmiIvkU6CIiGaFAT5KGT6SgQPuF+nPqKdBFRDJCgV4Rukh0+WVh6n+1CyhFUMXWLAW6iEhGKNAlXJr6L3ICBXqiUvYRX9IhbUM/sYVad+1QoIuIZIQCvRJKmimaaCXZ0a92SlvjBjQMpIlZQVCgi4hkRKxAN7OZZrbOzDaY2Z19rPNxM1tjZqvN7CflLVOkEE39F8nXUGwFM6sH5gMfBDqA5Wa2yN3X5K0zHfgKcIW77zGzsUkVHJRgD35JsgLtF+rPqRdnD/1SYIO7b3T3I8BCYE6vdb4AzHf3PQDuvqO8ZYqISDFxAn0CsCXvfke0LN+5wLlm9p9mtszMZhZ6IjO71czazay9s7OzfxUHqZSZolIzgtreIdVau8p1ULQBmA5cDcwF7jezkb1XcvcF7t7m7m3Nzc1leums0cfaeDIw9V+kzOIE+lagJe/+xGhZvg5gkbsfdfdNwIvkAl4kOZopKnKCOIG+HJhuZlPNbCBwI7Co1zq/ILd3jpk1kRuC2Vi+MkVEpJiige7uXcA8YAmwFnjQ3Veb2b1mdm202hJgl5mtAR4D/tTddyVVdDj0EV8KCHboJ9S6a0fR0xYB3H0xsLjXsrvzbjvwxeif9FbKTFGpIQFt76AO4NYuzRRNm2D33iosExeJFikvBbqISEYo0CVgmvovkk+BniQNn0hBgfYL9efUU6CLiGSEAr0i4k79h2D33iquPzNFy1/FaQlqGCikWmuXAl1EJCMU6BIuTf0XOYECPVFp+4wvqRDswcVQ664dCnQRkYxQoFeCpv5LQQFt76AO4NYuBXraBPtxvMI09V/kJAp0qQ3aw5QaoEBPkva2paBA+4X6c+op0EVEMkKBXhG6SLQUENT2DqnW2qVATx19rI1HF4kW6U2BLiKSEQr0RGmPMFGhTv0P9pNCqHXXDgW6iEhGKNArQTNFpaCAtndQB3BrV0O1C5Begv04XmEBzxR943AXw4CHn+3ggqH7OWfsMKwMgbno+Vf5/hObuGLaGMaNHMLlU0cfz+Fzxjae9vNL+inQRSps38EjDAN+1t7Bf//dUqaPHcZN757C2U1nMG7kEKY2nXHKnz/S1cOfL1rF5p0HuOMD07l0ymi63fnGIy+wde9BVnbspafX365PtLXwldlvY+TQgcn9YlJ1CnQJWPy92kPdPdR39zAgwWrimjByKAD3zb2IJQem85OnXuFrv1h1/PGrpjfRuf8wv/+2sdz+vnM4Y9Bbb9MjXT186u+WsXzzHgCeXLCLQQ11uMOR7h6+f1Mb757WxM43DrN0fS
eDG+p58bX9/N0Tm/jt2te475MX83vTxlT2F5aKUaAnScMnqXDgSBe79x/mtY59XFLtYoBjQz9NwwbxqQsn88lLJ/F8xz7ePNzFv6zazn+82MmZwwfxvf94if/csJP7PnkxLaOHcuhoN19+aAXLN+/hf9/wTmZdcBa/XfsaKzv2se61/QwfMoD3nTeWujqjZfRQPnXZ5OOvOGfGBP7rwme5+YHf8c/zrmT6mf0YglF/Tj0FekVopmg1rezYx3iH8SMHV7uUE0Xb28yY0TISgCvOaTr+8L+u3s6XHnye2d95nJuvnMqja19jzbbX+dMPncf1l0wEckE9Z8aEoi/VOn44P/n8Zcz6zuPc9o/P8Pefexcto4eWUmwJ60q16CwXybyVW/cBMCqw8eNrzj+LxXdcxblnNfLdR9ez843D3P+ZNm5/3zn9er6xwwfzlzfO4JVdB5j9ncd5Yv3OMlcs1aY99NTRx9p44rfTyq37mF1nDG4Ib/+lZfRQfn7bu9n1xmFGDR1IXd3p7SlfNb2ZR7/0Xm5+YDm3/+QZnvrq+xk8oL5M1Uq1hdfDRY6JOUS1smMfAxrCDq0xwwaddpgf0zJ6KPfOOZ99B4+yZPX2sjynpIMCPVHa2662/YeOsnHnmwxM0955Cg4uXj51DONHDOaXz71awk9Vv245tRT1cpHyW7X1dQAG1uugXr66OuMPZoxn6Yud7H7zSLXLkTKJFehmNtPM1pnZBjO78xTrfczM3MzayldiBmjqf9Ws2RYFepr20I+r7vae884JdPU4v14RYy9dZ2AFoWgvN7N6YD4wC2gF5ppZa4H1GoE7gKfKXWRNScHH8SDEbKeXd71J46AG6szQkMGJ3j6ukXPPHMbPn9la7VKkTOLstlwKbHD3je5+BFgIzCmw3teBbwCHylifyGl5edcBJo0Zqs8+BZgZn3jXJJ7bspdV0amdErY4gT4B2JJ3vyNadpyZXQy0uPuvT/VEZnarmbWbWXtnZ2fJxQZHe9sJKx7Tr+w+wOQxQ1M2ZJCefnH9JRMZMqCeH/6/zcVXVn9OvdMeWDSzOuDbwJeKrevuC9y9zd3bmpubT/elRU6pu8fp2HOASaNP/WVXtWzEkAHc0DaRh5/dyqadb1a7HDlNcQJ9K9CSd39itOyYRuAC4N/NbDNwObBIB0bzaep/Nby69yBHuz23h55GKdne837/HAY21PHNf3mhz3Uc6NYeeurFCfTlwHQzm2pmA4EbgUXHHnT3fe7e5O5T3H0KsAy41t3bE6k48/Smiad4O23eldvjnHzsO0sUSAWNbRzM5686m0dWbWfDjjdOenzL7gOs2vo6m3cdqEJ1Uoqige7uXcA8YAmwFnjQ3Veb2b1mdm3SBYr018+Wb6HO4PwJI6pdSup95vLJNA5u4KsPr6Qn78vUO/YcYO79yzjc1cOIIWn48mE5lVjf5eLui4HFvZbd3ce6V59+Wdng7jq7IkGO8U/Lt3DZ2aMZ2ziY7z+xkRe27+ejF02guXEQv1qxjUunjI6CKEVbIoWfFJobB/FnH2nlyw+t4KsPr+RD55/Fyq37+IdlL3PwaDfnnjmM4YF9uVkt0pdzJei5LXu5CPjRk5v5zAfHluUyY/KWTTvf5MuPruBtZzVy6Gj38SGBX63YdnydP/vISVMmpA83XDKRF7fnLoaxcHnuxLbzxw/n2x+fwfDF2jsPgQI9QRdMGAEr4bv/9hJDRo/nhraWIj+hwI/Lgec79gLwwvb9NNQZ93+2jXPPHMbj63fytV+somnYIC6cmObhlnRtbzPjax9p5cPvGMfrh7qY0TLyrWEW7YwEQYGeoAHR94e0jm/kL3+7nlkXjmPYoCJNnsKP42lzuKubuu5utu07xFXTm3h8/U5uaGvhg61nAjB5zBl8vK2FI909vX5SbRvHRZNGVbsE6ac0fsFF5tzx/ulsf/0Qd/58xSnX27LnIIe7enCF+ikNaqinzoyB9XXcc+35PHDzu7jrw28/YZ
2BDXXF/3iKZIwCPUlRMF8yeTS3vXcav1qxrc8p1rvfPMKTG3dx8Gg33b0v2S4nqTfjlqvOZlrzMK4+b2zx8E7VkEGg21c7GqmnQK+QL7znbEYNHcC3lqwr+Pgvn9vK4aPdDBvUQEO9NkscaYpokTRQclSEMWLIAD59+WQeX9/J3gMnfv+0u/PQ0x2MGDqQhjJdlUYCkKpPDcWEVGvtUqBX0NXnjaXHYWmvi/MuXb+T1a++Tuu44VWqTESyQIFeQTNaRtLcOIh/fv7ECwr8w7KXaRo2iKlNZxDs+GrF9aOdNAYsGadAr6D6OuO6GeN57IUdxy/7tWX3Af593Q4+etF46jXcUhuC/cMSat21Q4FeYddf0kJXj/NA9P3T/+uRF6ivM/7wyqnVLSxEJY1B64+lZJ9O1K2EvOA576xGPvyOcSxY+hIHj3Tx65Xb+OIHz2XciCEodGpNQNvbLOBPFrVDe+hVcNfst9M4eAD3P76JK84Zw21XT6t2SSKSAdpDr4LxI4fw4B/9Hss27uK6GRMYkH/eufaC4ulXO6ltJdsU6FUytemM6KwWEZHy0JBLkrS3nbASxqBTNYkn0H6h/px6CvSK0DVFpQBtbykzBbqISEYo0FNHH2vj0UxRkd4U6CIiGaFAT5T2CBMV6kzRYLtFsIXXDAW6iEhGKNArIfaeZIr2IqUCAtreOiMnCAp0EZGMUKCnjc7EiEdT/0VOokBPksJZCgq0X6g/p54CXQIW6tR/kWQo0NNEoVNbgtreIdVauxToIiIZoUBPHY1TxqOp/yK9xQp0M5tpZuvMbIOZ3Vng8S+a2RozW2Fmj5rZ5PKXGiIFiBQQ7B+WUOuuHUUD3czqgfnALKAVmGtmrb1WexZoc/d3AA8B3yx3oSInCXXqv0hC4uyhXwpscPeN7n4EWAjMyV/B3R9z9wPR3WXAxPKWGTjNFJWCAtreQR3ArV1xAn0CsCXvfke0rC+3AI8UesDMbjWzdjNr7+zsjF+liIgUVdaDomb2aaAN+Fahx919gbu3uXtbc3NzOV86OzRMmSA1rmRbnItEbwVa8u5PjJadwMw+ANwFvNfdD5envMAFe/BLkhVov1B/Tr04e+jLgelmNtXMBgI3AovyVzCzi4C/Ba519x3lL1OkEM0UFclXNNDdvQuYBywB1gIPuvtqM7vXzK6NVvsWMAz4JzN7zswW9fF0NUoXiZYCgtrcQRVbs+IMueDui4HFvZbdnXf7A2WuS0RESqSZoiIiGaFATx0deCqqvwfndFBPMk6BnigFSKJCnSka7B+WUOuuHQr0StBMUSkooO2tA/ZBUKCLiGSEAl1EJCMU6GkT7PhqBfW7jdS2km0K9CQpnKWgQPuF+nPqKdAlYJr6L5JPgV4RmvovBQS1vUOqtXYp0EVEMkKBnjoapyxOM0VFClGgJ0oBIgUE+4cl1LprhwJdwhXq1H+RhCjQKyGog19SOQH1C/XhICjQRUQyQoEuIpIRCvQk9efgV7AHzCoo+DYKtP7g2z37FOgSMM0UFcmnQK8IzRSVAoLa3iHVWrsU6CIiGaFAFxHJCAV6ovpzEEkHnooLfOp/WuooWah11w4FuohIRijQK0EXiU5GSc2VxrZNY019COoAbu1SoIuIZIQCXUQkIxToSdJM0WQEf5HotNRRIvXN1FOgi4hkhAK9IjRTNBmlTP1Prop+C2p7h1Rr7VKgi4hkRKxAN7OZZrbOzDaY2Z0FHh9kZj+LHn/KzKaUvVIRETmlooFuZvXAfGAW0ArMNbPWXqvdAuxx93OAvwC+Ue5CRUTk1BpirHMpsMHdNwKY2UJgDrAmb505wD3R7YeA+8zM3BM4LP7Mj+HJ+8r+tIk4sKv0n+k6CPMvK38tWeI9/fu5TUvT0baH91e7gv7ZsTYd7ZcF7/0yXPCxsj9tnECfAGzJu98B9N6qx9dx9y4z2weMAXbmr2RmtwK3AkyaNKl/FQ8dDc3n9e9nq2HUVBgwJN66rdfB3lf6H1
i15KwL4bzZ8de/7DZYvyS5eko1+APQdG61q4jvks/BgMHVriI7Bo9M5Gmt2E60mV0PzHT3z0f3PwNc5u7z8tZZFa3TEd1/KVpnZ6HnBGhra/P29vYy/AoiIrXDzJ5297ZCj8U5KLoVaMm7PzFaVnAdM2sARgD9GG8QEZH+ihPoy4HpZjbVzAYCNwKLeq2zCLgpun098G+JjJ+LiEifio6hR2Pi84AlQD3wA3dfbWb3Au3uvgj4PvBjM9sA7CYX+iIiUkFxDori7ouBxb2W3Z13+xBwQ3lLExGRUmimqIhIRijQRUQyQoEuIpIRCnQRkYwoOrEosRc26wRe7uePN9FrFqoUpHaKR+1UnNoonkq002R3by70QNUC/XSYWXtfM6XkLWqneNROxamN4ql2O2nIRUQkIxToIiIZEWqgL6h2AYFQO8WjdipObRRPVdspyDF0ERE5Wah76CIi0osCXUQkI4IL9GIXrK4lZrbZzFaa2XNm1h4tG21mvzGz9dH/o6LlZmbfjdpthZldXN3qk2NmPzCzHdGFV44tK7ldzOymaP31ZnZTodcKWR/tdI+ZbY361HNmNjvvsa9E7bTOzD6Utzyz70kzazGzx8xsjZmtNrM7ouXp7E/uHsw/cl/f+xJwNjAQeB5orXZdVWyPzUBTr2XfBO6Mbt8JfCO6PRt4BDDgcuCpatefYLu8B7gYWNXfdgFGAxuj/0dFt0dV+3erQDvdA/xJgXVbo/fbIGBq9D6sz/p7EhgHXBzdbgRejNoilf0ptD304xesdvcjwLELVstb5gA/jG7/ELgub/mPPGcZMNLMxlWhvsS5+1Jy38ufr9R2+RDwG3ff7e57gN8AMxMvvoL6aKe+zAEWuvthd98EbCD3fsz0e9Ldt7n7M9Ht/cBactdQTmV/Ci3QC12wekKVakkDB/7VzJ6OLsANcKa7b4tubwfOjG7XetuV2i613F7zouGCHxwbSkDthJlNAS4CniKl/Sm0QJcTXenuFwOzgNvN7D35D3rus57OS+1F7XJK3wOmATOAbcD/qWo1KWFmw4CfA//N3V/PfyxN/Sm0QI9zweqa4e5bo/93AA+T+/j72rGhlOj/HdHqtd52pbZLTbaXu7/m7t3u3gPcT65PQQ23k5kNIBfm/+ju/zdanMr+FFqgx7lgdU0wszPMrPHYbeAaYBUnXrD7JuCX0e1FwGejo/CXA/vyPjLWglLbZQlwjZmNioYdromWZVqv4yofJdenINdON5rZIDObCkwHfkfG35NmZuSumbzW3b+d91A6+1O1jyL346jzbHJHml8C7qp2PVVsh7PJnVHwPLD6WFsAY4BHgfXAb4HR0XID5kftthJoq/bvkGDb/JTccMFRcmOVt/SnXYA/JHfwbwNwc7V/rwq104+jdlhBLpzG5a1/V9RO64BZecsz+54EriQ3nLICeC76Nzut/UlT/0VEMiK0IRcREemDAl1EJCMU6CIiGaFAFxHJCAW6iEhGKNBFRDJCgS4ikhH/HyJEX7uYuMpYAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "c = norm(b, False)\n", "c['left_OVRHandPrefab_pos_X'].plot()\n", "plt.plot(c['LeftHandTrackingAccuracy'])" ] }, { "cell_type": "code", "execution_count": 24, "id": "9ae9b71e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d = interpol(c, False)\n", "d['left_OVRHandPrefab_pos_X'].plot()\n", "plt.plot(d['LeftHandTrackingAccuracy'])" ] }, { "cell_type": "code", "execution_count": 25, "id": "29e9063e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 234 µs, sys: 0 ns, total: 234 µs\n", "Wall time: 252 µs\n" ] }, { "data": { "text/plain": [ "(48, 48)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "train = np.array([x['data'] for x in pdata if x['session'] == 1])\n", "test = np.array([x['data'] for x in pdata if x['session'] == 2])\n", "\n", "len(train), len(test)" ] }, { "cell_type": "code", "execution_count": 26, "id": "a52352aa", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 96/96 [00:36<00:00, 2.62it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(57800, 5, 336) (57800,) (37106, 5, 336) (37106,)\n", "CPU times: user 1min 48s, sys: 14.9 s, total: 2min 3s\n", "Wall time: 37 s\n" ] } ], "source": [ "%%time\n", "\n", "X_train = list()\n", "y_train = list()\n", "\n", "X_test = list()\n", "y_test = list()\n", "\n", "train = list()\n", "test = list()\n", "\n", "for x in tqdm(pdata):\n", " if x['session'] == 1:\n", " train.append(\n", " {\n", " 'label': x['user'],\n", " 'data': list()\n", " })\n", " for y in x['data'].unbatch().as_numpy_iterator():\n", " if not np.isnan(y[0]).any():\n", " X_train.append(y[0])\n", " y_train.append(y[1])\n", " \n", " train[-1]['data'].append(y[0])\n", " if len(train[-1]['data']) == 0:\n", " del train[-1]\n", " if x['session'] == 2:\n", " test.append(\n", " {\n", " 'label': x['user'],\n", " 'data': list()\n", " })\n", " for y in x['data'].unbatch().as_numpy_iterator():\n", " if not np.isnan(y[0]).any():\n", " X_test.append(y[0])\n", " y_test.append(y[1])\n", " \n", " 
test[-1]['data'].append(y[0])\n", " \n", " if len(test[-1]['data']) == 0:\n", " del test[-1]\n", " \n", "X_train = np.array(X_train)\n", "y_train = np.array(y_train)\n", "X_test = np.array(X_test)\n", "y_test = np.array(y_test)\n", "\n", "print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)" ] }, { "cell_type": "code", "execution_count": 27, "id": "8c85c181", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Key: 1: 1347\n", "Key: 2: 1583\n", "Key: 3: 8568\n", "Key: 4: 3034\n", "Key: 5: 1960\n", "Key: 6: 3311\n", "Key: 7: 3971\n", "Key: 8: 1407\n", "Key: 9: 1135\n", "Key: 10: 7466\n", "Key: 11: 6494\n", "Key: 12: 1813\n", "Key: 13: 3596\n", "Key: 14: 3260\n", "Key: 15: 2825\n", "Key: 16: 6030\n" ] }, { "data": { "text/plain": [ "array([], dtype=object)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "Xy_train = list(zip(X_train, y_train))\n", "Xy_test = list(zip(X_test, y_test))\n", "train_dict = {\"1\":[], \"2\":[],\"3\":[], \"4\":[], \"5\":[],\"6\":[], \"7\":[], \"8\":[],\"9\":[], \"10\":[], \"11\":[],\"12\":[], \"13\":[], \"14\":[], \"15\": [], \"16\": []}\n", "\n", "[train_dict[str(e[1])].append(e[0]) for e in Xy_train]\n", "[print(f'Key: {k}: {len(v)}') for k, v in train_dict.items()]\n", "pd.DataFrame.from_dict({k: len(v) for k, v in train_dict.items()}, orient='index').plot.pie(subplots=True, legend=False)" ] }, { "cell_type": "code", "execution_count": 28, "id": "92991de2", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Key: 1: 790\n", "Key: 2: 59\n", "Key: 3: 4330\n", "Key: 4: 0\n", "Key: 5: 545\n", "Key: 6: 348\n", "Key: 7: 5245\n", "Key: 8: 3558\n", "Key: 9: 2565\n", "Key: 10: 4163\n", "Key: 11: 3654\n", "Key: 12: 2868\n", "Key: 13: 2130\n", "Key: 14: 2360\n", "Key: 15: 2390\n", "Key: 16: 2101\n" ] }, { "data": { "text/plain": [ "array([], dtype=object)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "Xy_test = list(zip(X_test, y_test))\n", "test_dict = {\"1\":[], \"2\":[],\"3\":[], \"4\":[], \"5\":[],\"6\":[], \"7\":[], \"8\":[],\"9\":[], \"10\":[], \"11\":[],\"12\":[], \"13\":[], \"14\":[], \"15\": [], \"16\": []}\n", "\n", "[test_dict[str(e[1])].append(e[0]) for e in Xy_test]\n", "[print(f'Key: {k}: {len(v)}') for k, v in test_dict.items()]\n", "pd.DataFrame.from_dict({k: len(v) for k, v in test_dict.items()}, orient='index').plot.pie(subplots=True, legend=False)" ] }, { "cell_type": "code", "execution_count": 29, "id": "419d603a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 355 ms, sys: 13 ms, total: 368 ms\n", "Wall time: 367 ms\n" ] } ], "source": [ "%%time\n", "\n", "from sklearn.preprocessing import LabelBinarizer\n", "\n", "\n", "lb = LabelBinarizer()\n", "yy_train = lb.fit_transform(y_train)\n", "yy_test = lb.transform(y_test)" ] }, { "cell_type": "code", "execution_count": 30, "id": "da224750", "metadata": {}, "outputs": [], "source": [ "for e in test:\n", " e['label'] = lb.transform([e['label']])\n", " e['data'] = np.array(e['data'])\n", "\n", " \n", "for e in train:\n", " e['label'] = lb.transform([e['label']])\n", " e['data'] = np.array(e['data'])" ] }, { "cell_type": "code", "execution_count": 31, "id": "073c2c51", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(57800, 5, 336)\n", "(57800, 16)\n", "(37106, 5, 336)\n", "(37106, 16)\n" ] } ], "source": [ "print(X_train.shape)\n", "print(yy_train.shape)\n", "print(X_test.shape)\n", "print(yy_test.shape)" ] }, { "cell_type": "markdown", "id": "cee9b1c3", "metadata": {}, "source": [ "# Building Model" ] }, { "cell_type": "code", "execution_count": 32, "id": "75c9ba6d", "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow.keras.regularizers import l2\n", "from tensorflow.keras.models import Sequential\n", 
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "\n",
    "def build_model(shape, classes):\n",
    "    \"\"\"Build, compile and summarize the dense classifier.\n",
    "\n",
    "    Layers: Flatten -> Dropout(drop_count) -> BatchNormalization. Then, for i = 2, 3, ...\n",
    "    (at most layer_count times): Dropout(drop_count * i) + Dense(int(n_inputs / dense_steps**i),\n",
    "    relu, L2 0.001), stopping as soon as a layer would have at most `classes` neurons.\n",
    "    Finally Dense(classes, softmax). Loss: categorical cross-entropy (one-hot labels); optimizer: Adam.\n",
    "\n",
    "    Args:\n",
    "        shape: shape of one input sample, e.g. (win_sz, n_features); n_inputs is its product.\n",
    "        classes: number of output classes.\n",
    "\n",
    "    Returns:\n",
    "        The compiled Sequential model.\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "\n",
    "    # equals shape[0]*shape[1] for 2-D samples, and also works for any other rank\n",
    "    ncount = int(np.prod(shape))\n",
    "\n",
    "    model.add(Flatten(input_shape=shape, name='flatten'))\n",
    "\n",
    "    model.add(Dropout(drop_count, name=f'dropout_{drop_count*100}'))\n",
    "    model.add(BatchNormalization(name='batchNorm'))\n",
    "\n",
    "    for i in range(2, layer_count+2):\n",
    "        neurons = int(ncount/pow(dense_steps, i))\n",
    "        if neurons <= classes:\n",
    "            break\n",
    "        model.add(Dropout(drop_count*i, name=f'HiddenDropout_{drop_count*i*100:.0f}'))\n",
    "        model.add(Dense(neurons, activation='relu',\n",
    "                        kernel_regularizer=l2(0.001), name=f'Hidden_{i}'))\n",
    "\n",
    "    model.add(Dense(classes, activation='softmax', name='Output'))\n",
    "\n",
    "    model.compile(\n",
    "        optimizer=Adam(),\n",
    "        loss=\"categorical_crossentropy\",\n",
    "        metrics=[\"acc\"],\n",
    "    )\n",
    "\n",
    "    model.summary()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8f71c4bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_file = './goat.weights'  # NOTE(review): not referenced by train_model\n",
    "\n",
    "def train_model(X_train, y_train, X_test, y_test):\n",
    "    \"\"\"Train a new model on the training data, validating on the test data after each epoch.\n",
    "\n",
    "    The model is checkpointed to `checkpoint_path` whenever the training loss improves,\n",
    "    and the best checkpoint is loaded back before returning.\n",
    "\n",
    "    Args:\n",
    "        X_train, X_test: sample arrays, (n, win_sz, n_features) here.\n",
    "        y_train, y_test: one-hot label arrays, (n, n_classes).\n",
    "\n",
    "    Returns:\n",
    "        (model, history): the model with the best weights and the Keras History object.\n",
    "    \"\"\"\n",
    "    # the output width follows the one-hot labels instead of a hard-coded class count\n",
    "    model = build_model(X_train[0].shape, np.asarray(y_train).shape[-1])\n",
    "\n",
    "    # Checkpoint whenever the training loss improves. save_weights_only keeps its default (False),\n",
    "    # so the whole model is saved; load_weights below restores the weights from it.\n",
    "    model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss',\n",
    "                                       save_best_only=True)\n",
    "\n",
    "    # halve the learning rate after 5 epochs without training-loss improvement (floor 1e-4)\n",
    "    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=0.0001)\n",
    "\n",
    "    callbacks = [model_checkpoint, reduce_lr]\n",
    "\n",
    "    history = model.fit(X_train,\n",
    "                        y_train,\n",
    "                        epochs=epoch,\n",
    "                        batch_size=32,\n",
    "                        verbose=2,\n",
    "                        validation_data=(X_test, y_test),\n",
    "                        callbacks=callbacks\n",
    "                        )\n",
    "\n",
    "    model.load_weights(checkpoint_path)\n",
    "    return model, history"
   ]
  },
  {
   "cell_type": "code",
"execution_count": 34, "id": "77e0fc90", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded weights...\n", "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "flatten (Flatten) (None, 1680) 0 \n", "_________________________________________________________________\n", "dropout_10.0 (Dropout) (None, 1680) 0 \n", "_________________________________________________________________\n", "batchNorm (BatchNormalizatio (None, 1680) 6720 \n", "_________________________________________________________________\n", "HiddenDropout_20 (Dropout) (None, 1680) 0 \n", "_________________________________________________________________\n", "Hidden_2 (Dense) (None, 186) 312666 \n", "_________________________________________________________________\n", "HiddenDropout_30 (Dropout) (None, 186) 0 \n", "_________________________________________________________________\n", "Hidden_3 (Dense) (None, 62) 11594 \n", "_________________________________________________________________\n", "HiddenDropout_40 (Dropout) (None, 62) 0 \n", "_________________________________________________________________\n", "Hidden_4 (Dense) (None, 20) 1260 \n", "_________________________________________________________________\n", "Output (Dense) (None, 16) 336 \n", "=================================================================\n", "Total params: 332,576\n", "Trainable params: 329,216\n", "Non-trainable params: 3,360\n", "_________________________________________________________________\n", "CPU times: user 80.5 ms, sys: 3.3 ms, total: 83.8 ms\n", "Wall time: 79.5 ms\n" ] } ], "source": [
 "%%time\n",
 "\n",
 "if not os.path.isdir(checkpoint_dir) or create_new:\n",
 "    tf.keras.backend.clear_session()\n",
 "    model, history = train_model(np.array(X_train), np.array(yy_train), np.array(X_test), np.array(yy_test))\n",
 "else:\n",
 "    print(\"Loaded weights...\")\n",
 "    # One output unit per column of the one-hot labels.\n",
 "    model = build_model(X_train[0].shape, np.shape(yy_train)[-1])\n",
 "    model.load_weights(checkpoint_path)"
] },
{ "cell_type": "markdown", "id": "f2e6f8ad", "metadata": {}, "source": [ "# Eval" ] },
{ "cell_type": "code", "execution_count": 35, "id": "b7ede2b1", "metadata": {}, "outputs": [], "source": [
 "from collections import Counter\n",
 "\n",
 "def predict(model, entry, classes=None):\n",
 "    \"\"\"Classify one entry by majority vote over the model's per-row predictions.\n",
 "\n",
 "    Args:\n",
 "        model: trained Keras model; `model.predict(entry)` returns one row of\n",
 "            class scores per element along the first axis of `entry`.\n",
 "        entry: the samples of one entry (e.g. the 'data' of a test/train entry).\n",
 "        classes: optional sequence mapping output index -> label, e.g.\n",
 "            `lb.classes_`. If omitted, labels are assumed to be 1..n (index + 1).\n",
 "\n",
 "    Returns:\n",
 "        The most frequently predicted label; ties go to the class predicted first.\n",
 "\n",
 "    Raises:\n",
 "        ValueError: if `entry` yields no predictions.\n",
 "    \"\"\"\n",
 "    predictions = np.argmax(model.predict(entry), axis=-1)\n",
 "    if len(predictions) == 0:\n",
 "        raise ValueError('predict() needs at least one sample to vote on')\n",
 "    # most_common orders equal counts by first occurrence, so ties go to the\n",
 "    # earliest predicted class.\n",
 "    prediction = Counter(predictions).most_common(1)[0][0]\n",
 "    if classes is None:\n",
 "        return prediction + 1\n",
 "    return classes[prediction]"
] },
{ "cell_type": "code", "execution_count": 36, "id": "a71bb247", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.36 s, sys: 529 ms, total: 3.89 s\n", "Wall time: 2.95 s\n" ] }, { "data": { "text/plain": [ "(43, 43)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "%%time\n",
 "\n",
 "ltest = [lb.inverse_transform(e['label'])[0] for e in test]\n",
 "# Map output indices back to labels through the binarizer's own class order.\n",
 "ptest = [predict(model, e['data'], lb.classes_) for e in test]\n",
 "\n",
 "len(ltest), len(ptest)"
] },
{ "cell_type": "code", "execution_count": 37, "id": "ab3ecfc9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.85 s, sys: 448 ms, total: 4.3 s\n", "Wall time: 2.99 s\n" ] }, { "data": { "text/plain": [ "(47, 47)" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "%%time\n",
 "\n",
 "ltrain = [lb.inverse_transform(e['label'])[0] for e in train]\n",
 "ptrain = [predict(model, e['data'], lb.classes_) for e in train]\n",
 "\n",
 "len(ltrain), len(ptrain)"
] },
{ "cell_type": "code", "execution_count": 38, "id": "ac226caa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},\n", " {1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "set(ltrain), set(ltest)"
] },
{ "cell_type": "code", "execution_count": 39, "id": "3c3bac5d", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 1 1.00 0.67 0.80 3\n", " 2 0.00 0.00 0.00 1\n", " 3 0.43 1.00 0.60 3\n", " 4 0.00 0.00 0.00 0\n", " 5 0.00 0.00 0.00 3\n", " 6 0.50 0.33 0.40 3\n", " 7 1.00 1.00 1.00 3\n", " 8 1.00 0.33 0.50 3\n", " 9 0.00 0.00 0.00 3\n", " 10 0.40 0.67 0.50 3\n", " 11 0.00 0.00 0.00 3\n", " 12 0.60 1.00 0.75 3\n", " 13 0.75 1.00 0.86 3\n", " 14 0.00 0.00 0.00 3\n", " 15 0.00 0.00 0.00 3\n", " 16 0.50 1.00 0.67 3\n", "\n", " accuracy 0.49 43\n", " macro avg 0.39 0.44 0.38 43\n", "weighted avg 0.43 0.49 0.42 43\n", "\n", "CPU times: user 646 ms, sys: 195 ms, total: 840 ms\n", "Wall time: 610 ms\n" ] } ], "source": [
 "%%time\n",
 "\n",
 "from sklearn.metrics import confusion_matrix\n",
 "import seaborn as sn\n",
 "\n",
 "from sklearn.metrics import classification_report\n",
 "\n",
 "# A sorted list gives the matrix an explicit, deterministic label order\n",
 "# (recent pandas versions also refuse a set as DataFrame index/columns).\n",
 "set_digits = sorted(set(ltrain))\n",
 "\n",
 "train_cm = confusion_matrix(ltrain, ptrain, labels=set_digits, normalize='true')\n",
 "test_cm = confusion_matrix(ltest, ptest, labels=set_digits, normalize='true')\n",
 "\n",
 "df_cm = pd.DataFrame(test_cm, index=set_digits, columns=set_digits)\n",
 "fig, ax = plt.subplots(figsize=(10, 7))\n",
 "sn_plot = sn.heatmap(df_cm, annot=True, cmap=\"Greys\", ax=ax)\n",
 "ax.set_ylabel(\"True Label\")\n",
 "ax.set_xlabel(\"Predicted Label\")\n",
 "plt.show()\n",
 "\n",
 "print(classification_report(ltest, ptest, zero_division=0))"
] },
{ "cell_type": "code", "execution_count": 40, "id": "43acba77", "metadata": {}, "outputs": [], "source": [
 "def plot_keras_history(history, name='', acc='acc'):\n",
 "    \"\"\"Plot training/validation accuracy and loss curves from a Keras `History`.\n",
 "\n",
 "    Args:\n",
 "        history: object whose `.history` dict holds per-epoch lists under\n",
 "            `acc`, `'val_' + acc`, `'loss'` and `'val_loss'`.\n",
 "        name: suffix appended to both plot titles.\n",
 "        acc: metric key used for the accuracy curves.\n",
 "    \"\"\"\n",
 "    import matplotlib.pyplot as plt\n",
 "\n",
 "    training_acc = history.history[acc]\n",
 "    validation_acc = history.history['val_' + acc]\n",
 "    loss = history.history['loss']\n",
 "    val_loss = history.history['val_loss']\n",
 "\n",
 "    epochs = range(len(training_acc))\n",
 "\n",
 "    # Draw on dedicated figures so nothing lands on a figure left over from an\n",
 "    # earlier cell, and close both figures once they have been shown.\n",
 "    acc_fig, acc_ax = plt.subplots()\n",
 "    acc_ax.set_ylim(0, 1)\n",
 "    acc_ax.plot(epochs, training_acc, 'tab:blue', label='Training acc')\n",
 "    acc_ax.plot(epochs, validation_acc, 'tab:orange', label='Validation acc')\n",
 "    acc_ax.set_title('Training and validation accuracy ' + name)\n",
 "    acc_ax.set_xlabel('Epoch')\n",
 "    acc_ax.legend()\n",
 "\n",
 "    loss_fig, loss_ax = plt.subplots()\n",
 "    loss_ax.plot(epochs, loss, 'tab:green', label='Training loss')\n",
 "    loss_ax.plot(epochs, val_loss, 'tab:red', label='Validation loss')\n",
 "    loss_ax.set_title('Training and validation loss ' + name)\n",
 "    loss_ax.set_xlabel('Epoch')\n",
 "    loss_ax.legend()\n",
 "\n",
 "    plt.show()\n",
 "    plt.close(acc_fig)\n",
 "    plt.close(loss_fig)\n",
 "\n",
 "\n",
 "# `history` only exists when the model was trained in this session (not when weights were loaded).\n",
 "if 'history' in locals():\n",
 "    plot_keras_history(history)"
] },
{ "cell_type": "code", "execution_count": 41, "id": "af999e08", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Scenario: SYN\n", "Window Size: 5\n", "Strides: 1\n", "Epochs: 50\n", "HiddenL Count: 3\n", "Neuron Factor: 3\n", "Drop Factor: 0.1\n" ] } ], "source": [
 "print(f'Scenario: {cenario}')\n",
 "print(f'Window Size: {win_sz}')\n",
 "print(f'Strides: {stride_sz}')\n",
 "print(f'Epochs: {epoch}')\n",
 "print(f'HiddenL Count: {layer_count}')\n",
 "print(f'Neuron Factor: {dense_steps}')\n",
 "print(f'Drop Factor: {drop_count}')"
] },
{ "cell_type": "code", "execution_count": 42, "id": "b16af0c6", "metadata": {}, "outputs": [], "source": [
 "# Terminate the kernel process once the run is finished (this also frees the GPU memory it held).\n",
 "exit()"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "toc-showtags": false }, "nbformat": 4, "nbformat_minor": 5 }