2021-07-05 15:01:40 +02:00
|
|
|
{
|
|
|
|
"cells": [
|
2021-07-17 17:40:05 +02:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "9c890798",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
2021-08-06 22:16:00 +02:00
|
|
|
"# Change Scenario here.\n",
|
|
|
|
"\n",
|
|
|
|
"| | GameType | HeightNorm | ArmNorm |\n",
|
|
|
|
"|:---:|:--------:|:----------:|:-------:|\n",
|
|
|
|
"| SYY | Sorting | ✅ | ✅ |\n",
|
|
|
|
"| SYN | Sorting | ✅ | ❌ |\n",
|
|
|
|
"| SNY | Sorting | ❌ | ✅ |\n",
|
|
|
|
"| SNN | Sorting | ❌ | ❌ |\n",
|
|
|
|
"| JYY | Jenga | ✅ | ✅ |\n",
|
|
|
|
"| JYN | Jenga | ✅ | ❌ |\n",
|
|
|
|
"| JNY | Jenga | ❌ | ✅ |\n",
|
|
|
|
"| JNN | Jenga | ❌ | ❌ |\n",
|
|
|
|
"\n",
|
|
|
|
"Weights for the corresponding scenario are loaded automatically."
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
2021-07-05 15:01:40 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-07-14 10:15:52 +02:00
|
|
|
"execution_count": 1,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "1c9e114c",
|
2021-08-06 22:16:00 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# Possibilities: 'SYY', 'SYN', 'SNY', 'SNN', \n",
|
|
|
|
"# 'JYY', 'JYN', 'JNY', 'JNN'\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"cenario = 'SYN'"
|
2021-08-06 22:16:00 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "a3c8b624",
|
2021-08-06 22:16:00 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Constants"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 2,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "5f120a31",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
2021-07-17 03:32:34 +02:00
|
|
|
"outputs": [],
|
2021-07-05 15:01:40 +02:00
|
|
|
"source": [
|
|
|
|
"import os\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"\n",
|
|
|
|
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"os.environ['CUDA_VISIBLE_DEVICES'] = '2' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
|
2021-07-14 10:15:52 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 3,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "3be386b5",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"import pandas as pd\n",
|
|
|
|
"\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"glob_path = '/opt/iui-datarelease3-sose2021/*.csv'\n",
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"pickle_file = '../data.pickle'\n",
|
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"pd.set_option('display.float_format', lambda x: '%.2f' % x)"
|
2021-07-19 01:21:53 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "375756bc",
|
2021-07-19 01:21:53 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Config"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 4,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "fe73e572",
|
2021-07-19 01:21:53 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2021-08-06 23:19:00 +02:00
|
|
|
"create_new = False\n",
|
2021-08-06 22:16:00 +02:00
|
|
|
"checkpoint_path = f\"training_{cenario}/cp.ckpt\"\n",
|
|
|
|
"checkpoint_dir = os.path.dirname(checkpoint_path)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"win_sz = 5\n",
|
|
|
|
"stride_sz = 1\n",
|
|
|
|
"\n",
|
|
|
|
"epoch = 50\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"\n",
|
|
|
|
"# divisor for neuron count step downs (hard to describe), e.g. dense_step = 3: layer1=900, layer2 = 300, layer3 = 100, layer4 = 33...\n",
|
|
|
|
"dense_steps = 3\n",
|
|
|
|
"# amount of dense/dropout layers\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"layer_count = 3\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"# how much to drop\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"drop_count = 0.1"
|
2021-07-19 01:21:53 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "0173497c",
|
2021-07-19 01:21:53 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Helper Functions"
|
2021-07-14 10:15:52 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 5,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "ef82a419",
|
2021-07-17 03:32:34 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from matplotlib import pyplot as plt\n",
|
|
|
|
"\n",
|
|
|
|
"def pplot(dd):\n",
|
|
|
|
" x = dd.shape[0]\n",
|
|
|
|
" fix = int(x/3)+1\n",
|
|
|
|
" fiy = 3\n",
|
|
|
|
" fig, axs = plt.subplots(fix, fiy, figsize=(3*fiy, 9*fix))\n",
|
|
|
|
" \n",
|
|
|
|
" for i in range(x):\n",
|
|
|
|
" axs[int(i/3)][i%3].plot(dd[i])"
|
|
|
|
]
|
|
|
|
},
|
2021-07-17 17:40:05 +02:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "556c7dde",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Loading Data"
|
|
|
|
]
|
|
|
|
},
|
2021-07-17 03:32:34 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 6,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "51195751",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2021-07-05 15:01:40 +02:00
|
|
|
"from glob import glob\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"from tqdm import tqdm\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"def dl_from_blob(filename, user_filter=None):\n",
|
|
|
|
" \n",
|
2021-07-05 15:01:40 +02:00
|
|
|
" dic_data = []\n",
|
|
|
|
" \n",
|
2021-07-14 10:15:52 +02:00
|
|
|
" for p in tqdm(glob(glob_path)):\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
" path = p\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
" filename = path.split('/')[-1].split('.')[0]\n",
|
|
|
|
" splitname = filename.split('_')\n",
|
|
|
|
" user = int(splitname[0][1:])\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
" if (user_filter):\n",
|
|
|
|
" if (user != user_filter):\n",
|
|
|
|
" continue\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
" scenario = splitname[1][len('Scenario'):]\n",
|
|
|
|
" heightnorm = splitname[2][len('HeightNormalization'):] == 'True'\n",
|
|
|
|
" armnorm = splitname[3][len('ArmNormalization'):] == 'True'\n",
|
|
|
|
" rep = int(splitname[4][len('Repetition'):])\n",
|
|
|
|
" session = int(splitname[5][len('Session'):])\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
" data = pd.read_csv(path)\n",
|
|
|
|
" dic_data.append(\n",
|
|
|
|
" {\n",
|
|
|
|
" 'filename': path,\n",
|
|
|
|
" 'user': user,\n",
|
|
|
|
" 'scenario': scenario,\n",
|
|
|
|
" 'heightnorm': heightnorm,\n",
|
|
|
|
" 'armnorm': armnorm,\n",
|
|
|
|
" 'rep': rep,\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
" 'session': session,\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
" 'data': data \n",
|
|
|
|
" }\n",
|
|
|
|
" )\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
" return dic_data"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 7,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "457bc16f",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import pickle\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"def save_pickle(f, structure):\n",
|
|
|
|
" _p = open(f, 'wb')\n",
|
|
|
|
" pickle.dump(structure, _p)\n",
|
|
|
|
" _p.close()"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 8,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "9482bc78",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def load_pickles(f) -> list:\n",
|
|
|
|
" _p = open(pickle_file, 'rb')\n",
|
|
|
|
" _d = pickle.load(_p)\n",
|
|
|
|
" _p.close()\n",
|
|
|
|
" \n",
|
|
|
|
" return _d"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 9,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "230fb3b8",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
2021-07-14 10:15:52 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Loading data...\n",
|
|
|
|
"../data.pickle found...\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"768\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 572 ms, sys: 2.57 s, total: 3.14 s\n",
|
|
|
|
"Wall time: 3.14 s\n"
|
2021-07-14 10:15:52 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
|
|
|
"def load_data() -> list:\n",
|
|
|
|
" if os.path.isfile(pickle_file):\n",
|
|
|
|
" print(f'{pickle_file} found...')\n",
|
|
|
|
" return load_pickles(pickle_file)\n",
|
|
|
|
" print(f'Didn\\'t find {pickle_file}...')\n",
|
|
|
|
" all_data = dl_from_blob(glob_path)\n",
|
|
|
|
" print(f'Creating {pickle_file}...')\n",
|
|
|
|
" save_pickle(pickle_file, all_data)\n",
|
|
|
|
" return all_data\n",
|
|
|
|
"\n",
|
|
|
|
"print(\"Loading data...\")\n",
|
|
|
|
"dic_data = load_data()\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"print(len(dic_data))"
|
2021-07-14 10:15:52 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 10,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "effa570d",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
2021-07-19 01:21:53 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 393 µs, sys: 0 ns, total: 393 µs\n",
|
|
|
|
"Wall time: 397 µs\n"
|
2021-07-19 01:21:53 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-14 10:15:52 +02:00
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
"# Categorized Data\n",
|
|
|
|
"cdata = dict() \n",
|
|
|
|
"# Sorting, HeightNorm, ArmNorm\n",
|
|
|
|
"cdata['SYY'] = list() \n",
|
|
|
|
"cdata['SYN'] = list() \n",
|
|
|
|
"cdata['SNY'] = list() \n",
|
|
|
|
"cdata['SNN'] = list() \n",
|
|
|
|
"\n",
|
|
|
|
"# Jenga, HeightNorm, ArmNorm\n",
|
|
|
|
"cdata['JYY'] = list() \n",
|
|
|
|
"cdata['JYN'] = list() \n",
|
|
|
|
"cdata['JNY'] = list() \n",
|
|
|
|
"cdata['JNN'] = list() \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"for d in dic_data:\n",
|
|
|
|
" if d['scenario'] == 'Sorting':\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
" if d['heightnorm']:\n",
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
" cdata['SYY'].append(d)\n",
|
|
|
|
" else:\n",
|
|
|
|
" cdata['SYN'].append(d)\n",
|
|
|
|
" else:\n",
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
" cdata['SNY'].append(d)\n",
|
|
|
|
" else:\n",
|
|
|
|
" cdata['SNN'].append(d)\n",
|
|
|
|
" elif d['scenario'] == 'Jenga':\n",
|
|
|
|
" if d['heightnorm']:\n",
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
" cdata['JYY'].append(d)\n",
|
|
|
|
" else:\n",
|
|
|
|
" cdata['JYN'].append(d)\n",
|
|
|
|
" else:\n",
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
" cdata['JNY'].append(d)\n",
|
|
|
|
" else:\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
" cdata['JNN'].append(d)"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "2ad62c63",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Preprocessing"
|
2021-07-14 10:15:52 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 11,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "55619c6e",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
2021-07-17 17:40:05 +02:00
|
|
|
"outputs": [],
|
2021-07-14 10:15:52 +02:00
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
"def drop(entry, data=True) -> pd.DataFrame:\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" droptable = ['participantID', 'FrameID', 'Scenario', 'HeightNormalization', 'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0']\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" if data:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
" else:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
"\n",
|
|
|
|
" return centry.drop(droptable, axis=1)\n",
|
|
|
|
" \n"
|
2021-07-14 10:15:52 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 12,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "d7be5822",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {},
|
2021-07-17 17:40:05 +02:00
|
|
|
"outputs": [],
|
2021-07-14 10:15:52 +02:00
|
|
|
"source": [
|
2021-07-17 03:32:34 +02:00
|
|
|
"import numpy as np\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"left_Hand_ident='left'\n",
|
|
|
|
"right_Hand_ident='right'\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"def rem_low_acc(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
" if data:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
" else:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" centry['LeftHandTrackingAccuracy'] = (centry['LeftHandTrackingAccuracy'] == 'High') * 1.0\n",
|
|
|
|
" centry['RightHandTrackingAccuracy'] = (centry['RightHandTrackingAccuracy'] == 'High') * 1.0\n",
|
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" left_Hand_cols = [c for c in centry if left_Hand_ident in c.lower() and c != 'LeftHandTrackingAccuracy']\n",
|
|
|
|
" right_Hand_cols = [c for c in centry if right_Hand_ident in c.lower() and c != 'RightHandTrackingAccuracy']\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" \n",
|
|
|
|
" centry.loc[centry['LeftHandTrackingAccuracy'] == 0.0, left_Hand_cols] = np.nan\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" centry.loc[centry['RightHandTrackingAccuracy'] == 0.0, right_Hand_cols] = np.nan\n",
|
|
|
|
"\n",
|
|
|
|
" return centry"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 13,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "da77d0a9",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"def pad(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
" if data:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
" else:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
" \n",
|
|
|
|
" cols = centry.columns\n",
|
|
|
|
" pentry = pad_sequences(centry.T.to_numpy(),\n",
|
|
|
|
" maxlen=(int(centry.shape[0]/stride_sz)+1)*stride_sz,\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" dtype='float64',\n",
|
|
|
|
" padding='pre', \n",
|
|
|
|
" truncating='post',\n",
|
|
|
|
" value=np.nan\n",
|
|
|
|
" ) \n",
|
|
|
|
" pdentry = pd.DataFrame(pentry.T, columns=cols)\n",
|
|
|
|
" pdentry.loc[0] = [0 for _ in cols]\n",
|
|
|
|
" return pdentry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 14,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "ac13ea7d",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
"def interpol(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
" if data:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
" else:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" return centry.interpolate(limit_direction='both')"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 22:16:00 +02:00
|
|
|
"execution_count": 15,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "2f6b0535",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from tensorflow.keras.preprocessing import timeseries_dataset_from_array\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"def slicing(entry, label, data=True):\n",
|
|
|
|
" if data:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
" else:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
" \n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" return timeseries_dataset_from_array(\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" data=centry, \n",
|
|
|
|
" targets=[label for _ in range(centry.shape[0])], \n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" sequence_length=win_sz,\n",
|
|
|
|
" sequence_stride=stride_sz, \n",
|
|
|
|
" batch_size=8, \n",
|
|
|
|
" seed=177013\n",
|
|
|
|
" )"
|
|
|
|
]
|
|
|
|
},
|
2021-07-27 16:00:03 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 16,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "be9a3bee",
|
2021-07-27 16:00:03 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
2021-08-06 20:20:52 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"100%|██████████| 96/96 [00:05<00:00, 16.33it/s]\n"
|
2021-08-06 20:20:52 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-27 16:00:03 +02:00
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"acc_data = pd.DataFrame()\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"for e in tqdm(cdata[cenario]):\n",
|
|
|
|
" acc_data = acc_data.append(e['data'], ignore_index=True)\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"ddacc_data = rem_low_acc(drop(acc_data, False),False)\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"eula = ddacc_data[[c for c in ddacc_data if 'euler' in c.lower()]]\n",
|
|
|
|
"posi = ddacc_data[[c for c in ddacc_data if 'pos' in c.lower()]]\n",
|
|
|
|
"eulamin = eula.min()\n",
|
|
|
|
"eulamax = eula.max()\n",
|
|
|
|
"eulamean = eula.mean()\n",
|
|
|
|
"eulastd = eula.std()\n",
|
|
|
|
"posimin = posi.min()\n",
|
|
|
|
"posimax = posi.max()\n",
|
|
|
|
"posimean = posi.mean()\n",
|
|
|
|
"posistd = posi.std()"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 17,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "bf571416",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def minmaxscaler(entry, minimum, maximum):\n",
|
|
|
|
" return (entry-minimum)/(maximum-minimum)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 18,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "dc70c74b",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"euler_ident = 'euler'\n",
|
|
|
|
"pos_ident = 'pos'\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"def norm(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
" if data:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
" else:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
" \n",
|
|
|
|
" euler_cols = [c for c in centry if euler_ident in c.lower()]\n",
|
|
|
|
" pos_cols = [c for c in centry if pos_ident in c.lower()]\n",
|
|
|
|
" \n",
|
|
|
|
" centry[euler_cols] = minmaxscaler(centry[euler_cols], eulamin, eulamax)\n",
|
|
|
|
" centry[pos_cols] = minmaxscaler(centry[pos_cols], posimin, posimax)\n",
|
|
|
|
" return centry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 19,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "45877405",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def drop_acc(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
" droptable = ['LeftHandTrackingAccuracy', 'RightHandTrackingAccuracy']\n",
|
|
|
|
" if data:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
" else:\n",
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" return centry.drop(droptable, axis=1)\n",
|
|
|
|
" \n"
|
2021-07-27 16:00:03 +02:00
|
|
|
]
|
|
|
|
},
|
2021-07-17 17:40:05 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 20,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "d7a30d7b",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"100%|██████████| 96/96 [00:14<00:00, 6.67it/s]"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 13 s, sys: 1.93 s, total: 14.9 s\n",
|
|
|
|
"Wall time: 14.4 s\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-17 17:40:05 +02:00
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"classes = 16 # dynamic\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"def preproc(data):\n",
|
|
|
|
" res_list = list()\n",
|
|
|
|
" \n",
|
|
|
|
" for e in tqdm(data):\n",
|
|
|
|
" res_list.append(preproc_entry(e))\n",
|
|
|
|
" \n",
|
|
|
|
" return res_list\n",
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"def preproc_entry(entry, data = True):\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" entry2 = pickle.loads(pickle.dumps(entry))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" entry2['data'] = drop(entry2, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" entry3 = pickle.loads(pickle.dumps(entry2))\n",
|
|
|
|
" entry3['data'] = rem_low_acc(entry3, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" entry1 = pickle.loads(pickle.dumps(entry3))\n",
|
|
|
|
" entry1['data'] = norm(entry1, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" entry8 = pickle.loads(pickle.dumps(entry1))\n",
|
|
|
|
" entry8['data'] = drop_acc(entry8, data)\n",
|
|
|
|
" \n",
|
|
|
|
"# entry5 = pickle.loads(pickle.dumps(entry4))\n",
|
|
|
|
"# entry5['data'] = pad(entry5, data)\n",
|
|
|
|
" \n",
|
|
|
|
"# entry6 = pickle.loads(pickle.dumps(entry8))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
"# entry6['data'] = interpol(entry6, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" entry7 = pickle.loads(pickle.dumps(entry8))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" entry7['data'] = slicing(entry7, entry7['user'], data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" \n",
|
|
|
|
" return entry7\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"pdata = preproc(cdata[cenario])"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 21,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "c88f53a4",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"[<matplotlib.lines.Line2D at 0x7fb1087bc370>]"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 21,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2021-08-06 23:49:37 +02:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8rklEQVR4nO2deXxU5dX4vyf7DgmEAGEJ+ypliYgbbqC4VOhbW7VVqdWqtdq+9u3b2rr+1Nrt7b7j0mKtu1VREQVcURGCIKtsYQsQEgjZyJ55fn/cO2Gyzkwyk7kzc76fTz6Ze+9z75w8ufc59znnPOeIMQZFURQleokJtQCKoihKaFFFoCiKEuWoIlAURYlyVBEoiqJEOaoIFEVRopy4UAvQHfr372/y8vJCLYaiKEpYsW7duqPGmOy2+8NSEeTl5VFQUBBqMRRFUcIKEdnX0X41DSmKokQ5qggURVGiHFUEiqIoUY4qAkVRlChHFYGiKEqUExBFICKPi0iJiGzu5LiIyB9EZJeIbBSR6R7HForITvtnYSDkURRFUXwnUDOCfwLzujh+MTDG/rkJ+CuAiGQB9wGnATOB+0QkM0AyKYqiKD4QkHUExpj3RSSviybzgSeMlfN6tYj0FZFBwLnAcmNMGYCILMdSKE8HQq6w5thu2Pgs+JomXASmXAn9RgVXrnDHGFjzCJwo9f2c+GQ47WZISA2eXL5SXQrr/gHNjaGWxHdiEyD/m5DaL9SSKJ3QWwvKcoEDHttF9r7O9rdDRG7Cmk0wbNiw4EjpJNY+Bqv/DIiPJxior4J5PwumVOFPxQF443/tDV/61lbEOZNh7IXBksp3tr0C7/zU3vD13ggldv+lDYAZavl1KmGzstgYswhYBJCfnx/51XRMMyT1gTv3+9b+58PB1RxcmSIBdx8t+BtMvdp7+0PrYdG51v/DCbhc1u8f7oGUrNDK4gtVxfDrcc7pP6VDeitq6CAw1GN7iL2vs/2KoihKL9FbimAJcJ0dPTQLqDDGHAbeBC4UkUzbSXyhvU/pVgnRyJ8o9Zxu9pFjSro6RQ4/cUz/KR0RENOQiDyN5fjtLyJFWJFA8QDGmL8BS4FLgF1ADXC9faxMRB4E1tqXesDtOFbALxuwhIO92EH43F/arz1D+y8cCFTUUJfGVjta6DudHHsceDwQckQW3XiD0rcu73S7jxzSt2H7Pw5XuaMDXVmsKIoS5agicDJ+mXt0Cu4fPvaXU01uTpWrLeEiZ5SjisCpqLPYWTjGJOMUOfzEMf2ndIQqAkVRlChHFYGj0aihoBH2UUNOlast4SJndKOKwLFo1FBQ0KghRWmHKgJFUZQoRxWBk9GooSCiUUO9QrjIGeWoInAqGjUUJDTFREhwTP8pHaGKwNGos1gJd/S+DAdUETgWfYMKKuEaNdTyZu0wubyi97OTUUUQSej02zvhHjWkKEFAFYGTUWexEu6oyTIsUEXgVPTt3hk4biCz7wvHyeUFvZ8djSqCiEIfNu+Ee9SQogQeVQSORqOGFEUJPqoIHIu+gQYVjRrqZfR+djKqCCIJNV94R6OGFKUdAVEEIjJPRLaLyC4RubOD478VkQ32zw4RKfc41uxxbEkg5IkYNGooiGiKiV4hXOSMcnpcs1hEYoE/A3OBImCtiCwxxmx1tzHG3OHR/nZgmsclao0xU3sqR8Shb/dKh4TpfaH3s6MJxIxgJrDLGFNojGkAngHmd9H+auDpAHyv0g592LyjUUOK0pZAKIJc4IDHdpG9rx0iMhwYAbztsTtJRApEZLWILOjsS0TkJrtdQWlpaQDEDgc0aihohKuzuAWnytWWcJEzuultZ/FVwAvGmGaPfcONMfnA14Dficiojk40xiwyxuQbY/Kzs7N7Q9YQo4VpgkK4O4vD9n8crnL3PqVV9Ww+WEFNQ1OvfWePfQTAQWCox/YQe19HXAV8x3OHMeag/btQRN7F8h/sDoBc4Y86i5VwR2eqfnHj4gJWbDvSsv3sTbM4bWS/oH9vIGYEa4ExIjJCRBKwBvt20T8iMh7IBD722JcpIon25/7AmcDWtudGJWH75hcuhGvUkKaYiFReXn+wRQn8z9yxAFy5aDUlVXVB/+4eKwJjTBNwG/AmsA14zhizRUQeEJHLPZpeBTxjTKs7YgJQICKfAe8AP/eMNlL8RR8276izWHEmv3pzOwBrfnIBt18whrsvnQDAHc9uCPp3B8I0hDFmKbC0zb5722zf38F5HwGnBEKGyESdxUq4o/elL2wvruJgeS2zx2YzICMJgBvPHsnDS7fx4a5j1DQ0kZIQkOG6Q3RlsWPRN9CgEq5RQ5piIiK57alPAfjOua1jZR5cMBmAl9Z35nYNDKoIIgl91rzTiyYel8tQXBF8+64S3lTVNbKzpJr+aYnMHJHV6thXZlhxOG9uOdLRqQFDFYGT0aihsKGx2cWKrUe4cfFafvq65eb6/cqdzPrZSoqO14RYuhCiJkuv/OVdK0jyhxeNQ9r0V0JcDEOzknl/RymlVfVBk0EVgVPRt/sgE7ioIWMML64r4sYnClixrYRHPtgDwPKt1ltceU1jl+c3Nbtoanb5Jo9GDUUcj62y7pcLJ+V0ePxL04YA8MHO4C2kVUUQUejD5p3ARg19dqCcGxYXcOd/NnVbomkPLmfmwyu7fb4Svhwsr6WhycVX84fQNyWhwzbfOCMPgE/3Hw+aHKoIHI1GDTmdqx9Zzdufl7TbX1hazdbDlT5do6quibITDYEWzSHofdkVi96zzEILpnaYlQeArNQEslITeL6gKGhyqCJwLPp2H1QCFDVU09Dc4f7zf/1ey2e3icgbzS4f/ucaNRQRLN96hE8Kj7H4430AzPKyejivXwr1TS7qGju+33qKKoJIQu2w3glQriFjDLf8a51PZ/5+5U6e+mS/13Zj7lrKwfLabkmnhBffeqKAKxetBuCKGUOIielasV95qhU99J9PgxNGqorAyWjUkGNpbDYs21Lsc/ufvOTdh+AyRF64qZosvXLtrOFe21w2ZTAAb231/Z7zB1UETkXf7oNMz6KGXH7+f+JjBePDOd7baNRQJJHbN5kpQ/p4bZeaaK0q3nLIN7+Tv6gicDT+OItB7bC+EJioIX8VQWOz8SkO3Cc/QVgRZgqrl/nJJRParR3ojBvPGkFpVX1Q1qWoInAskTYgBJeSyjr+/M4uKrzE7LfQwzfq7gzYB3x4gL1etuV4uA2wej+7OW5HiOX1S+HiyQN9Pu+SKYM4a3R/nxWHP6giUCKCf3y0l1+9uZ3XNx0O+nc9V3CAGxYX+H3ejiPVXtv4Yj5SwpvvPrMegFvPHe3VSezJ9GGZPHnjaeT2TQ64TKoInIxfil+i2g57ot6q5tTobYVuAKKGfvjCRtbsKfP7Cj/+zyZcXl75myPtfxhuvowgU1pVzwc7jwLwlfwhIZbmJKoInEqkDQiOo3cL05w7ziqv+s729ovPPPFucVJncTjzkJ2H6q9fnx4UE093UUUQUejD5p3QFKZ55Lp8kuNjWbXraJftvM0YlPDFGMMrGw4BMM8P30BvoIrA0WiKiUghPjaGtKQ4rytD/Y1Gcj56X7px5wq6/sw8R80GIECKQETmich2EdklInd2cPwbIlIqIhvsnxs9ji0UkZ32z8JAyBMZRNqA4DBCUJgmMS6GusaufRheo5E0xUTY8uRqa3W5u8aAk+hx7TMRiQX+DMwFioC1IrKkg9rDzxpjbmtzbhZwH5CPdaess88NXpq9SCbi3iaDQIBSTHSHpPhYrykk1DIUmRhjeGn9QfqnJTJxcEaoxWlHIGYEM4FdxphCY0wD8Aww38dzLwKWG2PK7MF/OTAvADJFBppiIqy4/AuDO9x/ml11KiMpjjV7yrj0Dx/wzw/3dNg24kxDDjOBhIqPC48BcMkpzvINuAlENeRc4IDHdhFwWgftviwis4EdwB3GmAOdnNt5PtZoItIGhCBSXd9EQ5NlcvE9Dj+wUUO/+PIpXHnqMP5w9TRKqurYVFRBRW0jY3PSGZOTBsCvvvIFlmw4xO9X7mTLoa28tvEwT97Y+lHxrgg0aigccWeg/fpp3vM
|
2021-08-06 23:18:21 +02:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
2021-08-06 20:20:52 +02:00
|
|
|
"source": [
|
|
|
|
"a = drop(cdata[cenario][0]['data'], False)\n",
|
|
|
|
"a['left_OVRHandPrefab_pos_X'].plot()\n",
|
|
|
|
"plt.plot((a['LeftHandTrackingAccuracy'] == 'High')*1.0)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 22,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "b9518087",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"[<matplotlib.lines.Line2D at 0x7fb1086d6820>]"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 22,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2021-08-06 23:49:37 +02:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAf80lEQVR4nO3deXQcZ53u8e9Pqy1blhfJ+yI7sUlMyIbIOCSE3MQJjhnicFgmGWYSlmDOHQLMDcM95mbIcDPMAcLOITDXA8wQhokJu2fGkI1skyGJbeI4sR3bshNbkhfJm7xIsrb3/tEtpa3F6laXVG9VPZ9zfNxdXe766e3S47ffeqvKnHOIiEj8FYRdgIiIjA4FvohIQijwRUQSQoEvIpIQCnwRkYQoCruAwVRWVrrq6uqwyxARiZSNGzcecs5VDfSat4FfXV3Nhg0bwi5DRCRSzGzPYK9pSEdEJCEU+CIiCaHAFxFJCAW+iEhCKPBFRBIikMA3sx+aWaOZvTzI62Zm3zazWjPbbGaXBrFdERHJXlA9/H8Blp3l9RuAhek/K4HvBbRdERHJUiDz8J1zT5lZ9VlWWQHc71LXYn7WzCaa2Qzn3P4gth9pW38DBwb8YjSwMRWw5H9CQeHI1RQHx+pg00+guyv7fzPjQjj/XSNXUy5efRpefSrsKnIzeT5c/OdhVyFnMVonXs0C6jKe16eXnRH4ZraS1DcA5s6dO0qlhWztJ6HtGGBZrJy+d8GCt8P0N41gUTGw6SfwxBfJrl0BHJRV+hP4j/4dNGwk+/rDlt43L3gvFJWEW4oMyqszbZ1zq4HVADU1Ncm4M4vrhiV/Bcu+OPS6r6yDNbfk1mtNqu4uwODzx7Jbf91n4KWfjWRFuenugoXvgA88GHYl2Xn6a/DYPfQGv3hptGbpNABzMp7PTi8TEZFRMlqBvxa4NT1bZwnQrPF7EZHRFciQjpk9AFwNVJpZPfB3QDGAc+4fgXXAcqAWaAE+FMR2Y2FY9xTW1+ahDaONvLq/s0+15MCrNpS+gpqlc8sQrzvg40FsK56yPDBnUTmAJ4GI1OcdpVqTS2faSnzlFJgKLIk/Bb6ISEIo8EMX9bFmT0X92EhkP+Oo1p0MCnwfZD30oGGHZInQ5x2p4w3JpcCXGMshhBRYkgAKfBGRhFDghy3qY83eGs6xkeCrGD6visleZI89JIMCP0o07JAskfq8o1RrcinwJb40D1/kDAp8EZGEUOCHLupjzZ6K+rERj0rJTWQLTwQFvg80D18GFKHPO1LHG5JLgS8xpnn4IpkU+CIiCaHAD1vUx5q9FfVrFPlUSw68akPpS4HvBV0PXwYQqc87SrUmlwJf4kvz8EXOoMAXEUkIBX7ooj7W7KmoHxuJ7Gcc1bqTQYHvA83Dl6iL1PGG5FLgS4xpHr5IpkAC38yWmdl2M6s1s1UDvD7XzB43sxfMbLOZLQ9iuyIikr28A9/MCoH7gBuAxcAtZra4z2p/CzzonLsEuBn4br7bjY2ojzV7K+rHRnyqJQdetaH0FUQP/zKg1jm32znXDqwBVvRZxwET0o8rgH0BbDdGsp2HP7JViGciNcwUpVqTqyiA95gF1GU8rwf+pM86nwceNrNPAOOApQFsV+TsIhWYIiNvtA7a3gL8i3NuNrAc+LGZ9du2ma00sw1mtqGpqWmUShMRSYYgAr8BmJPxfHZ6WaaPAA8COOf+AIwBKvu+kXNutXOuxjlXU1VVFUBpURD1sWZPRf3YSGQ/46jWnQxBBP56YKGZzTezElIHZdf2WWcvcC2AmZ1PKvDVhe+hefgyoAh93ho+i4S8A9851wncATwEbCM1G2eLmd1jZjemV/s08FEzexF4APigc5HtwkhkaB6+SKYgDtrinFsHrOuz7O6Mx1uBK4LYloiIDI/OtA1b1MeavRX1YyM+1ZIDr9pQ+lLge0HXw5cBROrzjlKtyaXAl/jS9fBFzqDAD52+AssAIjs0EtW6k0GBLyKSEAp8H+Q6Dz+yvb9RFIuD4REaZorU8YbkUuBLjGkevkgmBX7Y1FuXAUV0v9D+7DUFvohIQijwvZDrPHz1okaEb73TSA0zRanW5FLgS3xpHr7IGRT4ofOsVyl+8O3bhsSCAl9EJCEU+D7QPPzgaR7+6IrU8YbkUuBLjGkevkgmBX7Y1FuXAUV0v9D+7DUFvohIQijwvaB5+MGL+g1QiNgwU5RqTS4FvohIQijwQ+dZrzJOonzilW/fNrIW1bqTQYEvIpIQCnwfaB5+8DQPf3RF6nhDcgUS+Ga2zMy2m1mtma0aZJ33m9lWM9tiZv8WxHZFRCR7Rfm+gZkVAvcB1wH1wHozW+uc25qxzkLgs8AVzrmjZjY13+2KDC3KJ1759m0jS/r26bUgeviXAbXOud3OuXZgDbCizzofBe5zzh0FcM41BrBdERHJQRCBPwuoy3hen16WaRGwyMyeMbNnzWzZQG9kZivNbIOZbWhqagqgtKjQPPzgaR6+SF+jddC2CFgIXA3cAvyTmU3su5JzbrVzrsY5V1NVVTVKpYmIJEMQgd8AzMl4Pju9LFM9sNY51+GcexXYQeo/gGTzrUcZNzl1kD3rTUd234hq3ckQROCvBxaa2XwzKwFuBtb2WefXpHr3mFklqSGe3QFsW0REspR34DvnOoE7gIeAbcCDzrktZnaPmd2YXu0h4LCZbQUeBz7jnDuc77ZjQ/Pwg6d5+KNLxxsiIe9pmQDOuXXAuj7L7s547IA7039ERCQEOtM2TOqpjzDNwx912qe9psAXEUkIBb4XNA8/eJqHP7qiVGtyKfBFRBJCgR8qz3qUcaPr4YcgqnUngwJfRCQhFPg+0Dz84Gke/uiK1PGG5FLgi4gkhAI/TOqpjzDNwx912qe9psAXEUkIBb4XNA8/eDFoI+++dZxNlGpNLgW+iEhCKPBDFYNeqM8iPQ8/7AKGK7KFJ4ICX0QkIRT4Psi6c6l5+Fkbbht51baefes4m0gdb0guBb6ISEIo8MPkVW8yjjQPf9Rpn/aaAl9EJCEU+F7QPPzgxWAM37tvHWcTpVqTS4EvIpIQCvxQedSbjKNIz8OP6r4R1bqTIZDAN7NlZrbdzGrNbNVZ1nuPmTkzqwliuyIikr28A9/MCoH7gBuAxcAtZrZ4gPXKgU8Bz+W7zdjJ9Xr4khAR+rwjdbwhuYLo4V8G1Drndjvn2oE1wIoB1vt74MtAWwDbTDZ9ax7asIdE1LgSX0EE/iygLuN5fXpZLzO7FJjjnPvPs72Rma00sw1mtqGpqSmA0jwX2XHaqNA8/FGnfdprI37Q1swKgK8Dnx5qXefcaudcjXOupqqqaqRLExFJlCACvwGYk/F8dnpZj3LgAuAJM3sNWAKs1YHbTLnOw5dEiNTHHaliEyuIwF8PLDSz+WZWAtwMrO150TnX7JyrdM5VO+eqgWeBG51zGwLYdkLpa/PQYnDilUjA8g5851wncAfwELANeNA5t8XM7jGzG/N9/3hTuIwozcMPQVTrToaiIN7EObcOWNdn2d2DrHt1ENsUEZHc6ExbH2gevgwoQp+3ji9FggI/iiL7dX8UaR6+SD8K/DApuEdYLvPwR66K4YnovqF92msKfBGRhFDgeyHbefgjW4V4JlLj4lGqNbkU+JGkr81Di/Y8/I6ubto6ugN9T+ccD66vo+FYa6DvK9GhwA+VH+EifunqdjS3tvO7LQdoOnE6sPd9ueE4//sXm7niS7/nFxvrRyj4tU/7TIEv8RXRE68KC4yKscW0d3Zz6w+fZ1fTyUDeN/N9Pv2zF3nf9/6b7/x+J1fd+zg7D54IZBviNwW+DzQPX/ooLijgyoVV7D18iptXP0tzS0dO/765pYNvPbqT1vau3mW7m05SYPDkZ67mK++9kH3NbXz14R3sPdLCVx/enl/BkTrekFwK/CjyZJzZazGYhz+zYgz/9tElHDnVzpd+ty2nf/uDZ17lG4/u4DuP7+xdtuvQKeZMLmPelHG8r2YOa1Yu4d73XMgnrzmXh7Yc5KX65qB/BPGMAj9MCm4ZVGrfuGjORD5y5XweeL6OVb/YzGd/uZnG40PfQ+iJ7Y0
|
2021-08-06 23:18:21 +02:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
2021-08-06 20:20:52 +02:00
|
|
|
"source": [
|
|
|
|
"b = rem_low_acc(a, False)\n",
|
|
|
|
"b['left_OVRHandPrefab_pos_X'].plot()\n",
|
|
|
|
"plt.plot(b['LeftHandTrackingAccuracy'])"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 23,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "09687aab",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"[<matplotlib.lines.Line2D at 0x7fb108669be0>]"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 23,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2021-08-06 23:49:37 +02:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaDUlEQVR4nO3de5RcZZnv8e/T3bmSzrU7kEvnQghoAxqgBUZAcVRMokNwCUq8IYMywyLncI7OuFAchoNrnXXUc5zRIeNMGB3UGY0MHjSjYTLKMBOYQzDNLVdCQhJIh4R0rgRy7e7n/FE7odKpTu3q1K7a767fZ62sVO3aXfX0u9/69a5377e2uTsiIhK+umoXICIi5aFAFxHJCAW6iEhGKNBFRDJCgS4ikhEN1XrhpqYmnzJlSrVeXkQkSE8//fROd28u9FjVAn3KlCm0t7dX6+VFRIJkZi/39ZiGXEREMkKBLiKSEQp0EZGMUKCLiGSEAl1EJCOKBrqZ/cDMdpjZqj4eNzP7rpltMLMVZnZx+csUEZFi4uyhPwDMPMXjs4Dp0b9bge+dflkiIlKqouehu/tSM5tyilXmAD/y3PfwLjOzkWY2zt23lavIYB1+A363AI4ejP8zk98N096XXE1Z8eIS6ChhHoPVwTs/AaPPTq6muHp64Km/gYN7ql1JaVqvhbMurHYVcgrlmFg0AdiSd78jWnZSoJvZreT24pk0aVIZXjrlNi2FR/9HdMdi/IDn3jDTnkiyqmxY/Cew9xXitSuAw9EDcM3Xk6wqnl3rYclXojtx6682hz2b4WP3V7sQOYWKzhR19wXAAoC2trbsX1nDu3P///ET8fZsFn4Kdm9Ktqas6OmBGZ+G6+bHW/9/TgDvSbamuHqifnHDD+H866paSmx/dclb/VlSqxxnuWwFWvLuT4yWiYhIBZUj0BcBn43Odrkc2KfxcxGRyis65GJmPwWuBprMrAP4c2AAgLv/DbAYmA1sAA4ANydVrIiI9C3OWS5zizzuwO1lqyhL+nUB7uwfWiiPfrRTai6InpY6SpSa9pO+aKaoiEhGKNArIuapaRbKKWwpUVJzpbBtg9reIdVauxToIiIZoUAXEckIBXqiQj5wl3IhH3AOdhuHWnftUKCLiGSEAr0SYh/80oGn0pTQXqk8AJnGmvqQyvaT3hToIiIZoUAXEckIBXqSgj34JckKtF+oP6eeAj119KaJR2cQifSmQK8IzRRNREntlcK2DWp7h1Rr7VKgi4hkhAJdRCQjFOiJ0pitFBDsWH6oddcOBXraBPtmr7CQp/6LJESBXgmaKSoFBbS9gzqAW7sU6BKwUqb+J1eFSFoo0EVEMkKBLiKSEQr0JOnAXYJCnimaljpKlJr2k74o0CtCM0WlgKC2d0i11i4FuoQr9Kn/ImWmQBcRyQgFuohIRijQE6WDSFJAsAcXQ627dijQ0ybYN3uF6QwikZMo0CtBU/8TootEV0wq2096U6CLiGSEAl1EJCMU6EnSeLgUFGi/UH9OvViBbmYzzWydmW0wszsLPD7JzB4zs2fNbIWZzS5/qbVCb5p4Qp76L5KMooFuZvXAfGAW0ArMNbPWXqt9DXjQ3S8CbgT+utyFhk1T/6WAoLZ3SLXWrjh76JcCG9x9o7sfARYCc3qt48Dw6PYI4NXylSjSB039FzlBnECfAGzJu98RLct3D/BpM+sAFgP/pdATmdmtZtZuZu2dnZ39KFdERPpSroOic4EH3H0iMBv4sZmd9NzuvsDd29y9rbm5uUwvLRIYjeVLQuIE+lagJe/+xGhZvluABwHc/UlgMNBUjgJrjt7s8WimqMhJ4gT6cmC6mU01s4HkDnou6rXOK8D7Aczs7eQCXWMqx2imqBQU0PYO6gBu7Soa6O7eBcwDlgBryZ3NstrM7jWza6PVvgR8wcyeB34KfM5du5qStNCn/ouUV0Ocldx9MbmDnfnL7s67vQa4oryliYhIKTRTNEn6kCIFBdov1J9TT4EuIpIRCvSKKGWmqPaC4snA1P+gxvVDqrV2KdAlXJopKnICBbqISEYo0EVEMkKBnqiUjdlKOgTbLYItvGYo0CuhlJmiaTtwl1aZmPof0Lh+QKXWMgW6iEhGKNAlYJr6L5JPgS4ikhEK9CRpPFwKCrRfqD+nngI9TTRTtASaKVpZIdVauxToIiIZoUCXcGnqv8gJFOgiIhmhQE9UysZsJR3SNpYfW6h11w4FuohIRijQK0FT/8tPU/8rK6gzcmqXAl0CppmiIvkU6CIiGaFAT5KGT6SgQPuF+nPqKdBFRDJCgV4Rukh0+WVh6n+1CyhFUMXWLAW6iEhGKNAlXJr6L3ICBXqiUvYRX9IhbUM/sYVad+1QoIuIZIQCvRJKmimaaCXZ0a92SlvjBjQMpIlZQVCgi4hkRKxAN7OZZrbOzDaY2Z19rPNxM1tjZqvN7CflLVOkEE39F8nXUGwFM6sH5gMfBDqA5Wa2yN3X5K0zHfgKcIW77zGzsUkVHJRgD35JsgLtF+rPqRdnD/1SYIO7b3T3I8BCYE6vdb4AzHf3PQDuvqO8ZYqISDFxAn0CsCXvfke0LN+5wLlm9p9mtszMZhZ6IjO71czazay9s7OzfxUHqZSZolIzgtreIdVau8p1ULQBmA5cDcwF7jezkb1XcvcF7t7m7m3Nzc1leums0cfaeDIw9V+kzOIE+lagJe/+xGhZvg5gkbsfdfdNwIvkAl4kOZopKnKCOIG+HJhuZlPNbCBwI7Co1zq/ILd3jpk1kRuC2Vi+MkVEpJiige7uXcA8YAmwFnjQ3Veb2b1mdm202hJgl5mtAR4D/tTddyVVdDj0EV8KCHboJ9S6a0fR0xYB3H0xsLjXsrvzbjvwxeif9FbKTFGpIQFt76AO4NYuzRRNm2D33iosExeJFikvBbqISEYo0CVgmvovkk+BniQNn0hBgfYL9efUU6CLiGSEAr0i4k79h2D33iquPzNFy1/FaQlqGCikWmuXAl1EJCMU6BIuTf0XOYECPVFp+4wvqRDswcVQ664dCnQRkYxQoFeCpv5LQQFt76AO4NYuBXraBPtxvMI09V/kJAp0qQ3aw5QaoEBPkva2paBA+4X6c+op0EVEMkKBXhG6SLQUENT2DqnW2qVATx19rI1HF4kW6U2BLiKSEQr0RGmPMFGhTv0P9pNCqHXXDgW6iEhGKNArQTNFpaCAtndQB3BrV0O1C5Begv04XmEBzxR943AXw4CHn+3ggqH7OWfsMKwMgbno+Vf5/hObuGLaGMaNHMLlU0cfz+Fzxjae9vNL+inQRSps38EjDAN+1t7Bf//dUqaPHcZN757C2U1nMG7kEKY2nXHKnz/S1cOfL1rF5p0HuOMD07l0ymi63fnGIy+wde9BVnbspafX365PtLXwldlvY+TQgcn9YlJ1CnQJWPy92kPdPdR39zAgwWrimjByKAD3zb2IJQem85OnXuFrv1h1/PGrpjfRuf8wv/+2sdz+vnM4Y9Bbb9MjXT186u+WsXzzHgCeXLCLQQ11uMOR7h6+f1Mb757WxM43DrN0fSeDG+p58bX9/N0Tm/jt2te475MX83vTxlT2F5aKUaAnScMnqXDgSBe79x/mtY59XFLtYoBjQz9NwwbxqQsn88lLJ/F8xz7ePNzFv6zazn+82MmZwwfxvf94if/csJP7PnkxLaOHcuhoN19+aAXLN+/hf9/wTmZdcBa/XfsaKzv2se61/QwfMoD3nTeWujqjZfRQPnXZ5OOvOGfGBP7rwme5+YHf8c/zrmT6mf0YglF/Tj0FekVopmg1rezYx3iH8SMHV7uUE0Xb28yY0TISgCvOaTr+8L+u3s6XHnye2d95nJuvnMqja19jzbbX+dMPncf1l0wEckE9Z8aEoi/VOn44P/n8Zcz6zuPc9o/P8Pefexcto4eWUmwJ60q16CwXybyVW/cBMCqw8eNrzj+LxXdcxblnNfLdR9ez843D3P+ZNm5/3zn9er6xwwfzlzfO4JVdB5j9ncd5Yv3OMlcs1aY99NTRx9p44rfTyq37mF1nDG4Ib/+lZfRQfn7bu9n1xmFGDR1IXd3p7Sl
|
2021-08-06 23:18:21 +02:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
2021-08-06 20:20:52 +02:00
|
|
|
"source": [
|
|
|
|
"c = norm(b, False)\n",
|
|
|
|
"c['left_OVRHandPrefab_pos_X'].plot()\n",
|
|
|
|
"plt.plot(c['LeftHandTrackingAccuracy'])"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 24,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "9ae9b71e",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"[<matplotlib.lines.Line2D at 0x7fb1085d5700>]"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 24,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2021-08-06 23:49:37 +02:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAb8klEQVR4nO3de5RU5Z3u8e+vuqvpbmnAhubeXETQYC6IHWXiPTEGmBwxE42amBhj4oxL5pgVMy5zM44z65yT5JycxCMnOThxNJkkxNFJhpngkMSYoJlgaG8gIIKA0ojSXEUb6dt7/qhNUzTV1K6mdtV+dz2ftVhdtevtqh/v3vX03u+ut7Y55xAREf+lyl2AiIgUhwJdRCQhFOgiIgmhQBcRSQgFuohIQlSX64VHjRrlpkyZUq6XFxHx0lNPPbXLOdeU67GyBfqUKVNobW0t18uLiHjJzF4e6DENuYiIJIQCXUQkIRToIiIJoUAXEUkIBbqISELkDXQzu8/MdprZ8wM8bmZ2t5ltMrPVZja7+GWKiEg+YfbQ7wfmHufxecD04N+NwPdOvCwRESlU3s+hO+dWmNmU4zRZAPzQZb6Hd6WZjTCzcc65HcUq0luH3oQ/LYaug+F/Z/L7YNrF0dWUFC8uh7YC5jFYCt5zFTSeEl1NYfX2wpPfh4N7y11JYWZeBmPfVe4q5DiKMbFoArAt635bsOyYQDezG8nsxTNp0qQivHTMbVkBj/5tcMdC/ILLvGGmPRFlVcmw7Iuw7xXC9SuAg64OuPTvoqwqnN0bYfmXgjth6y83B3u3wkfvLXchchwlnSnqnFsMLAZoaWlJ/pU1XE/m5189EW7PZsknYM+WaGtKit5emHUtXL4oXPv/NgFcb7Q1hdUbbBdXPgBnXF7WUkL7P2cd2Z4ltorxKZftQHPW/YnBMhERKaFiBPpS4FPBp13mAPs1fi4iUnp5h1zM7KfARcAoM2sDvg6kAZxz3weWAfOBTUAHcH1UxYqIyMDCfMrlmjyPO+DmolWUJIO6AHfyTy0UxyD6KTYXRI9LHQWKTf/JQDRTVEQkIRToJRHyo2nmy0fYYqKg7oph33q1vn2qtXIp0EVEEkKBLiKSEAr0SPl84i7mfD7h7O069rXuyqFAFxFJCAV6KYQ++aUTT4UpoL9ieQIyjjUNIJb9J/0p0EVEEkKBLiKSEAr0KHl78kui5el2oe059hTosaM3TTj6BJFIfwr0ktBM0UgU1F8x7Fuv1rdPtVYuBbqISEIo0EVEEkKBHimN2UoO3o7l+1p35VCgx423b/YS83nqv0hEFOiloJmikpNH69urE7iVS4EuHitk6n90VYjEhQJdRCQhFOgiIgmhQI+STtxFyOeZonGpo0Cx6T8ZiAK9JDRTVHLwan37VGvlUqCLv3yf+i9SZAp0EZGEUKCLiCSEAj1SOokkOXh7ctHXuiuHAj1uvH2zl5g+QSRyDAV6KWjqf0R0keiSiWX/SX8KdBGRhFCgi4gkhAI9ShoPl5w83S60PcdeqEA3s7lmtsHMNpnZ7Tken2Rmj5nZM2a22szmF7/USqE3TTg+T/0XiUbeQDezKmARMA+YCVxjZjP7Nfsq8KBz7kzgauD/FrtQv2nqv+Tg1fr2qdbKFWYP/Wxgk3Nus3OuE1gCLOjXxgHDgtvDgVeLV6LIADT1X+QoYQJ9ArAt635bsCzbncC1ZtYGLAP+OtcTmdmNZtZqZq3t7e2DKFdERAZSrJOi1wD3O+cmAvOBH5nZMc/tnFvsnGtxzrU0NTUV6aVFPKOxfIlImEDfDjRn3Z8YLMt2A/AggHPuj0AtMKoYBVYcvdnD0UxRkWOECfRVwHQzm2pmNWROei7t1+YV4AMAZvYOMoGuMZXDNFNUcvJofXt1Ardy5Q1051w3sBBYDqwn82mWtWZ2l5ldFjS7FficmT0H/BT4tHPa1ZSo+T71X6S4qsM0cs4tI3OyM3vZHVm31wHnFrc0EREphGaKRkkHKZKTp9uFtufYU6CLiCSEAr0kCpkpqr2gcBIw9d+rcX2faq1cCnTxl2aKihxFgS4ikhAKdBGRhFCgRypmY7YSD95uFt4WXjEU6KVQyEzRuJ24i6tETP33aFzfo1IrmQJdRCQhFOjiMU39F8mmQBcRSQgFepQ0Hi45ebpdaHuOPQV6nGimaAE0U7S0fKq1cinQRUQSQoEu/tLUf5GjKNBFRBJCgR6pmI3ZSjzEbSw/NF/rrhwKdBGRhFCgl4Km/hefpv6XllefyKlcCnTxmGaKimRToIuIJIQCPUoaPpGcPN0utD3HngJdRCQhFOgloYtEF18Spv6Xu4BCeFVsxVKgi4gkhAJd/KWp/yJHUaBHKmaH+BIPcRv6Cc3XuiuHAl1EJCEU6KVQ0EzRSCtJjkH1U9w616NhIE3M8oICXUQkIUIFupnNNbMNZrbJzG4foM3HzGydma01s58Ut0yRXDT1XyRbdb4GZlYFLAI+CLQBq8xsqXNuXVab6cCXgHOdc3vNbHRUBXvF25NfEi1Ptwttz7EXZg/9bGCTc26zc64TWAIs6Nfmc8Ai59xeAOfczuKWKSIi+YQJ9AnAtqz7bcGybDOAGWb2BzNbaWZzcz2Rmd1oZq1m1tre3j64ir1UyExRqRherW+faq1cxTopWg1MBy4CrgHuNbMR/Rs55xY751qccy1NTU1Feumk0WFtOAmY+i9SZGECfTvQnHV/YrAsWxuw1DnX5ZzbArxIJuBFoqOZoiJHCRPoq4DpZjbVzGqAq4Gl/dr8gszeOWY2iswQzObilSkiIvnkDXTnXDewEFgOrAcedM6tNbO7zOyyoNlyYLeZrQMeA/7GObc7qqL9oUN8ycHboR9f664ceT+2COCcWwYs67fsjqzbDvhC8E/6K2SmqFQQj9a3VydwK5dmisaNt3tvJZaIi0SLFJcCXUQkIRTo4jFN/RfJpkCPkoZPJCdPtwttz7GnQBcRSQgFekmEnfoP3u69ldxgZooWv4oT4tUwkE+1Vi4FuohIQijQxV+a+i9yFAV6pOJ2jC+x4O3JRV/rrhwKdBGRhFCgl4Km/ktOHq1vr07gVi4Fetx4ezheYpr6L3IMBbpUBu1hSgVQoEdJe9uSk6fbhbbn2FOgi4gkhAK9JHSRaMnBq/XtU62VS4EeOzqsDUcXiRbpT4EuIpIQCvRIaY8wUr5O/ff2SMHXuiuHAl1EJCEU6KWgmaKSk0fr26sTuJWrutwFSD/eHo6XWExnijrnOHCom/0dXew/mPm3r6OLfQc7M/c7uhjZvpkbgYeeauPcxoOMG15XlNde+tyr/OCJLZw7bSTjRtQxZ2pjXw6fOrqhKK8h8aZAF8nhUHdPXwAfCeUgpDsy4bwvWH44uA//6+kd+A9HTXWK9w/ZC8BDT7dx21O/5bzpTfzFmRMY3TCEcSPqmDrqpOPW1tndy9eXPs/WXR3ccsl0zp7SSI9zfOORF9i+7yBr2vbRv4SrWpr50vzTGVFfc8J9I/GlQBePhR8G6HaO3W+8zZ+ee/Wo8N3X0XlMKO/r6OJgV8/Ar2owrDbNiPo0I+rSDKtL09xYz/C6akbU1TCiPrNsRF2a4XVpRtRnlg2vS1ObroItQ+EB+M5Vs/jJ65N4+OntfP5nz/Y9//nTR9F+4BDvP300N198KicNOfI27ezu5RP/sJJVWzN/FP64eDdDqlM4B509vfzguhbeN20Uu948xIqN7dRWV/Hi6wf4hye28Jv1r3PPx2fzZ9NGFt7V4gUFepQ0fBILHZ3d7Nn/Nk/u3cWt65/pW16bTjGirobhdWmG16eZ1FgfBHA6WFaTFcrpvrYNtdWkUicyppzZLsYOq+UL7zmNz18yg2e27eXtrl7+4/nX+P2L7YwZNoTv/f4l/rBpF/d8fDbNjfW83dXDbQ+tZtXWvfzPK9/DvHeO5TfrX2dN2342vH6AYXVpLj5tNKmU0dxYzyfOmdz3igtmTeC/LnmG6+//E/+28DymjxnEEIy259hToJeEZoqW05q2/Yx3cPrYYfzqigv69qpr01XlLSxY36mUcdbkRgDOPXVU38O/Wvsatz74HPO/+zjXnze
|
2021-08-06 23:18:21 +02:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
2021-08-06 20:20:52 +02:00
|
|
|
"source": [
|
|
|
|
"d = interpol(c, False)\n",
|
|
|
|
"d['left_OVRHandPrefab_pos_X'].plot()\n",
|
|
|
|
"plt.plot(d['LeftHandTrackingAccuracy'])"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 25,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "29e9063e",
|
2021-07-14 10:15:52 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 234 µs, sys: 0 ns, total: 234 µs\n",
|
|
|
|
"Wall time: 252 µs\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"(48, 48)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 25,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
2021-07-05 15:01:40 +02:00
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"%%time\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"train = np.array([x['data'] for x in pdata if x['session'] == 1])\n",
|
|
|
|
"test = np.array([x['data'] for x in pdata if x['session'] == 2])\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"len(train), len(test)"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 26,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "a52352aa",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"100%|██████████| 96/96 [00:36<00:00, 2.62it/s]\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"(57800, 5, 336) (57800,) (37106, 5, 336) (37106,)\n",
|
|
|
|
"CPU times: user 1min 48s, sys: 14.9 s, total: 2min 3s\n",
|
|
|
|
"Wall time: 37 s\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-05 15:01:40 +02:00
|
|
|
"source": [
|
2021-07-17 17:40:05 +02:00
|
|
|
"%%time\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"X_train = list()\n",
|
|
|
|
"y_train = list()\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
"\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"X_test = list()\n",
|
|
|
|
"y_test = list()\n",
|
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"train = list()\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"test = list()\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"for x in tqdm(pdata):\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" if x['session'] == 1:\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" train.append(\n",
|
|
|
|
" {\n",
|
|
|
|
" 'label': x['user'],\n",
|
|
|
|
" 'data': list()\n",
|
|
|
|
" })\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" for y in x['data'].unbatch().as_numpy_iterator():\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" if not np.isnan(y[0]).any():\n",
|
|
|
|
" X_train.append(y[0])\n",
|
|
|
|
" y_train.append(y[1])\n",
|
|
|
|
" \n",
|
|
|
|
" train[-1]['data'].append(y[0])\n",
|
|
|
|
" if len(train[-1]['data']) == 0:\n",
|
|
|
|
" del train[-1]\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" if x['session'] == 2:\n",
|
|
|
|
" test.append(\n",
|
|
|
|
" {\n",
|
|
|
|
" 'label': x['user'],\n",
|
|
|
|
" 'data': list()\n",
|
|
|
|
" })\n",
|
|
|
|
" for y in x['data'].unbatch().as_numpy_iterator():\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" if not np.isnan(y[0]).any():\n",
|
|
|
|
" X_test.append(y[0])\n",
|
|
|
|
" y_test.append(y[1])\n",
|
|
|
|
" \n",
|
|
|
|
" test[-1]['data'].append(y[0])\n",
|
|
|
|
" \n",
|
|
|
|
" if len(test[-1]['data']) == 0:\n",
|
|
|
|
" del test[-1]\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
" \n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"X_train = np.array(X_train)\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"y_train = np.array(y_train)\n",
|
|
|
|
"X_test = np.array(X_test)\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"y_test = np.array(y_test)\n",
|
|
|
|
"\n",
|
|
|
|
"print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)"
|
2021-07-14 10:15:52 +02:00
|
|
|
]
|
|
|
|
},
|
2021-07-27 16:00:03 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 27,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "8c85c181",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"Key: 1: 1347\n",
|
|
|
|
"Key: 2: 1583\n",
|
|
|
|
"Key: 3: 8568\n",
|
|
|
|
"Key: 4: 3034\n",
|
|
|
|
"Key: 5: 1960\n",
|
|
|
|
"Key: 6: 3311\n",
|
|
|
|
"Key: 7: 3971\n",
|
|
|
|
"Key: 8: 1407\n",
|
|
|
|
"Key: 9: 1135\n",
|
|
|
|
"Key: 10: 7466\n",
|
|
|
|
"Key: 11: 6494\n",
|
|
|
|
"Key: 12: 1813\n",
|
|
|
|
"Key: 13: 3596\n",
|
|
|
|
"Key: 14: 3260\n",
|
|
|
|
"Key: 15: 2825\n",
|
|
|
|
"Key: 16: 6030\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"array([<AxesSubplot:ylabel='0'>], dtype=object)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 27,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2021-08-06 23:49:37 +02:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAADnCAYAAADGrxD1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA7ZElEQVR4nO2deXhU53X/v+feWbSgBSQhiVUgBAhphACx74sdbJE4cWJTbMdrmtgmdZLSpqTpr6FJm6pJ3CZt09Akdi0ntoOT2LGNHduYRYjNrFpAEiCBQBLa932We35/3AGE0GjmztxZdT/PMw9o5r7vezSa79zzvu95zyFmhoaGRugg+NsADQ0NddFEraERYmii1tAIMTRRa2iEGJqoNTRCDE3UGhohhiZqDY0QQxO1hkaIoYlaQyPE0EStoRFiaKLW0AgxNFFraIQYmqg1NEIMTdQaGiGGJmoNjRBD528DNNSBiKoBdAOwAbAyc44fbBABnAZQx8xbfD2+howm6tBiPTO3+HH8bwAoBxDtRxvGPJr7raEKRDQFQC6AX/vblrGOJurQgQF8TERniOirfhj/pwC+DUDyw9gaQ9BEHTqsYuaFAO4DsJ2I1vhqYCLaAqCJmc/4akwNx2iiDhGYuc7+bxOAtwEs8eHwKwF8zr5Y9zsAG4jotz4cX2MIpGUTDX6IKBKAwMzd9v/vA/B9Zv7QD7asA/A32uq3/9Du1F6GiGKJ6A9EVEFE5US03AvDJAI4QkTFAE4CeN8fgtYIDLQ7tZchonwAhcz8ayIyAIhg5g4/m6URwmii9iJEFAOgCMBM1t5oDR+hud/eZQaAZgD/R0TniOjX9jmvhobX0ETtXXQAFgL4BTMvANALYKd/TdIIdbQwUe9SC6CWmT+1//wH+EDUP3/2gAhgOmRPYSqAZPtjEoCJACIAGAAYARiXnvxBY2RfwyTIgSMS5BjyVgePGgCXAdSkV5RrU4oARBO1F2HmBiKqIaI5zHwRwEYAZWqO8fNnD0wHsNj+yAYwC8A0KPjbMgkSgCkKhx4on5teCfn3OQ+gFMCp9IryOoX9aKiMJmrv81cAXrOvfF8B8JS7Hb24dYsOwCIAa0BhGWGxz98H+c7rEXZRKyUMQKb98fDNJ8vnpl8BcBhAAYDD6RXlVzy1T0MZmqi9DDMXAXD7GOSLW7eMB/A5AA9CvtPLC208wCwNdJEQ5rmNIDXd6Jn2x5MAUD43vRayyP8MYG96RXmHimNpjIAm6gDkxa1bEgF8HsAXAazHyH8nkqxXK0VD+iKPB3TvTu0qUwA8Yn9YyuemH/rfzcKr+xcIH5Q+UdrmxXHHLJqoA4QXt26JBfAYZFd2JVzYmbCZL/eIhnSPx2YSfLXgpWdg7fF0WgIgwpRv+hDAawDeLX2itN9HNoQ8mqj9zItbtywF8CyArQDClbSVrHWqJCPwoajRFoXivjBabP/xs/ZHjynf9FsAPyl9orTKV7aEKpqo/YB9weshAN+CvGrtHtw/i5klIvIo3oDhVff7Dj5eIAyO8PQ4yF9sf2nKN/0BQF7pE6VFvrIp1NBE7UNe3LolAsDzkNP+KN1CGokotjVcIl3ybE868dWdmgHzR4vINMolImSPZasp3/QxZHEf9IVtoYQmah9gvzN/BcA/Qg4CUQ2b+XKj4LGoVV39dsgw19sZ9wK415RvOgng3wC8XfpEqRbs4gJamKgXeXHrFnpx65aHIQdo/AIqCxoAJMtVz/+GPrpT7xvZ9XbGEgB/BFBmyjdtU9mkkEQTtZd4ceuWjZDPNu8BkOatcVhqn+pxH97d0pLHAMwfLRzV9XbGXACvm/JN+0z5Jq+9n6GA5n6rzItbt6QC+DmAz/hmRGkaS73NJEQmuNuDL+bUbeNQ3Bvusus9GpsAlJryTXkA/rX0iVJ37v4hjXanVokXt27Rvbh1y7chx0D7SNAyNkvVVU/aM7wv6k/cc70dYQTwPcji3qRivyGBJmoVqN1ZOD8hbOpbkBd0FO01q4FkrvQocMPb7jcDlo8WUaYXuk4DsM+Ub3rNlG9K9EL/QYkmag+o3Vko1u4s/HsAJ9ckfmk2gaz+sEOy1U/wpD17ts3tlPZxKO4Jp1gvDvEIgApTvulZU76J1OiQiMKI6CQRFRPRBSL6JzX69QUhL2oimkNERUMeXUT0TU/7rd1ZmAbgCIB/AWDQCYY5SxNyj3rar1vwYBqzzeJ2cy/PqT/JFga82b+dWMg7DH805ZuiVOhvEMAGZp4P+UjrZiJapkK/XifkRc3MF5k5m5mzIR9b7IOcF9ttancWPgjgHIA7/sjTIuctH29IrPSkbzcJk6x1l9xu7cV9agYsH+Z4xfV2xBcAnDTlm+Z60gnL9Nh/1NsfQbFPHvKiHsZGAFXMfM2dxrU7C6l2Z+H3IGcwuSvXGBEZ1idvGySQzUM7FSNZLrtdGM+bC2U+cL1HYi5kYX/ek06ISCSiIgBNAPYNyWAT0Iw1Uf8FgDfcaVi7szACwJsAdgFwOG/TC8aMnPjPFLplnQdIlmtGd9syiV4TtY9c75GIAvCWKd/0z6Z8k1ufc2a22T28KQCWEPnU43CbMSNqe+aRzwH4vdK2tTsLp0KeP3/JletnjMtaFq2P92ibSSksdU53u62X3G8/uN7DIQDfBfC+Kd803t1O7HnaDwLYrJJdXmXMiBpy4bizzNyopFHtzsLlAE4BWOBqGyIK25j8aDd8WgGSkyVbp1v5wby1UOYn13skNgM4bco3ZbnagIgSiGTbiSgcwD0AKrxjnrqMJVFvg0LXu3Zn4ROQv6EV74EaxLCshRM2HVHazhMkS+V1d9oxiWqbAgDYn02BlPhgJoBjCoJVkgEcJKISyF/q+5h5r9esU5ExIWp7Av17ALzlapvanYV/B+AVyNFLbjEremFOlG68W0JzB5ul0q1tLW+43wxY/7xICLQ5aCSA90z5plxnFzJzCTMvYOYsZs5k5u/7wD5VGBOiZuZeZo5j5k5XrrcHlOR5Oi4RRWyY9FgrfLQVwtYmN+O/1f8YdESiqCeC3J7HepEwAG+b8k0P+tsQbzEmRK2E2p2F/w9yQIkqhIkRC7LGr/WRG25JY7b0KW3ljTDRAHO9h6MHsCdUj3Jqoh5C7c7C7wJQ3c2aG7M0O1IX44sk9zrJcv2y0kZqh4kyYP1zjpChaqfqowPwm1C8Y2uitlO7s/AFAP/sjb6JKGpj8mP13uh7ODbL5Q6lbdQWdUckirsjyKN4dB8hAnjDlG+639+GqIkmagC1OwufBPBTb44RrhuXkxG70utuuGSpiVDaRu0trQPzA9r1Ho4Bcrz4Bn8bohZjXtS1Owu/AODXGCVKTC0yYleawsWoBq8Owt2pipuoeKdmwPrBYmGeah36hjAA75ryTdn+NkQNxrSoa3cWmgD8FrIb5nWIKGbjpMe8vcU1QbK1KIpmUzP2O4hc7+FEQg4rDUbb72DMirp2Z2EM5H1rxe6qJ0TqopfMiVni1SOakrnyhqIGKt6pD8wnxavvAcQMyHnQgloXQW28u9TuLCTId+hZ/hh//vh16WFiZLO3+rdZqhRtUak1p7a73oG+6u2Mz8ALOyC+ZEyKGsA/ANjir8GJaMKG5Ee9du6abS2KUhGrNafuDF7Xezh/b8o3PeBvI9xlzIm6dmfhZsjHJ/1KlH788llRC094p3dbKksDLkXPAQCTOmuEQe56D4UAvGrKN3lUJMFfjClR1+4snAG5ymJA/N4L4zamGoRwb5RzJcla7bInoIb7zYD1g5ygW/UejWjI4aTj/G2IUgLiw+0LancWhkGu9BAw7iGRkLAh+RGvHOezmS/3OL9KhiF6fKvujEBJVyTFedpPgDEPwP/52wiljBlRA/gRFJyJ9hUxhvgVM8aZTqrdr2Stcz35ngru98H51OtxJ4HJl0z5puf9bYQSxoSoa3cWrgSw3d92OCInfvM0vWB0eQ7sEtyXxswurYJ7ulDGgO2DxUK6R50ENj805ZuS/G2EqwS8qInoW/a8y+eJ6A0iClPSvnZnoRFyxFjA/q4CCUnrk7aVqtxtFNsaXJpXezqn7opASWckxXvSR4A
|
2021-08-06 23:18:21 +02:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
2021-07-27 16:00:03 +02:00
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"Xy_train = list(zip(X_train, y_train))\n",
|
|
|
|
"Xy_test = list(zip(X_test, y_test))\n",
|
|
|
|
"train_dict = {\"1\":[], \"2\":[],\"3\":[], \"4\":[], \"5\":[],\"6\":[], \"7\":[], \"8\":[],\"9\":[], \"10\":[], \"11\":[],\"12\":[], \"13\":[], \"14\":[], \"15\": [], \"16\": []}\n",
|
|
|
|
"\n",
|
|
|
|
"[train_dict[str(e[1])].append(e[0]) for e in Xy_train]\n",
|
|
|
|
"[print(f'Key: {k}: {len(v)}') for k, v in train_dict.items()]\n",
|
|
|
|
"pd.DataFrame.from_dict({k: len(v) for k, v in train_dict.items()}, orient='index').plot.pie(subplots=True, legend=False)"
|
2021-07-27 16:00:03 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 28,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "92991de2",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"Key: 1: 790\n",
|
|
|
|
"Key: 2: 59\n",
|
|
|
|
"Key: 3: 4330\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"Key: 4: 0\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"Key: 5: 545\n",
|
|
|
|
"Key: 6: 348\n",
|
|
|
|
"Key: 7: 5245\n",
|
|
|
|
"Key: 8: 3558\n",
|
|
|
|
"Key: 9: 2565\n",
|
|
|
|
"Key: 10: 4163\n",
|
|
|
|
"Key: 11: 3654\n",
|
|
|
|
"Key: 12: 2868\n",
|
|
|
|
"Key: 13: 2130\n",
|
|
|
|
"Key: 14: 2360\n",
|
|
|
|
"Key: 15: 2390\n",
|
|
|
|
"Key: 16: 2101\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"array([<AxesSubplot:ylabel='0'>], dtype=object)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 28,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2021-08-06 23:49:37 +02:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAADnCAYAAADGrxD1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA4j0lEQVR4nO2dd3wU95n/P8/M7K56o0iABBIChEArid6rHduxfJeLncRxcomTOJdGEvvO54tzKd7cxbFsx5fkfucLrjH2OTa4xQV3ihAgEBaqIIEoAgRqqNdt8/z+2MUWQtK22Z3d1bxfL70kzc58vw9iPzvzLc/nIWaGhoZG+CCoHYCGhoayaKLW0AgzNFFraIQZmqg1NMIMTdQaGmGGJmoNjTBDE7WGRpihiVpDI8zQRK2hEWZootbQCDM0UWtohBmaqDU0wgxN1BoaYYYmag2NMEMTtYZGmKGJWsMjiCiLiCqGffUQ0T1qx6XxGaSZJGh4CxGJAC4CWMHM59SOR8OBdqfW8IXrAJzWBB1caKLW8IWvAnhJ7SA0rkZ7/NbwCiLSA7gEYCEzt6gdj8ZnaHdqDW/5PICjmqCDD03UGt5yB7RH76BEe/zW8BgiigZwHsBsZu5WOx6Nq9FEraERZmiP3xoaYYYmag2NMEMTtYZGmCGpHYCGf2i8vzgOQPqIr1kAkgCIw76kEb+LcHzYdwNoA3DZ+b0VQCOAc3BMkl1MLVxnD8y/RsMTtImyEKfx/uJkAOsArASQic8EnODnru0AzgI4AqAUwGEA5amF64b83K+GCzRRhxiN9xdnwiHiK19z1Y3oKqwAqvCZyEsB1KUWrlPlTUZECQCeBpADgAF8h5lL1IglkGiiDnIa7y/OAHAzgPUA1gKYrm5EHtMFYCeAVwC8n1q4zhyojoloG4BiZn7aua01ipm7AtW/WmiiDkIa7y9OBfAVALcDWK5yOErSC+BtfCZwvz2qE1E8gAo4NshMqDe5JuogwWQyxQG4PcM+deV1VuO3AZDaMfmZXgDvwCHw95QWOBHlA3gSwHEAeQDKANzNzP1K9hOMaKJWGZPJtAzA3QBuBRBJjJbvmDdPIdBEWm7sA/AcgMdSC9c1KNEgES0FcAjAGmY+TER/AtDDzL9Sov1gRhP1CIjonwF8F46JlWoA32ZmRe8iJpNJAPD3AO6FY5x8FZstOUdny8mLlewzRLDDced+JLVwXbkvDRFRCoBDzJzu/H0dgPuZuYCIGuB4UrADsDHzUp+iDjI0UQ+DiGYA2A9gATMPEtEOAO8y83NKtG8ymaIAfBuOO/OYs9aJcsyB2ywr1ijRZwjzMRzi/sjbBoioGMB3mfkEEZkARDPzfU5RL2Xmy8qEGlxom0+uRQIQSURWAFFwGAH4hMlk0gP4EYBfAJjs6vxO6suzwj6ggxjla98hzPUArm+8v7gCwKMAtnux2eUnAF50znyfgeMDNezR7tQjIKK7ATwIYBDAh8z8dW/bMplMBOBrAP4TQIYn1y63zjmYa5+12tu+w5AzAH6WWrjuVV8bIqKzADrhGGI9wcxP+tpmMDGRJmNcQkSJAL4AhwCnA4gmon/0pi2TyXQDgKMA/g8eChoAaqQLem/6DWNmA3il8f7iPY33F+f52NZaZl4Mh3vLFiJa73t4wYMm6qu5HsBZZm5jZiuA1wF4dLc0mUwzTSbTTgAfAMj3NpABmPMHYG7z9vowZiOAsvp/fffhx26/JcmbBpj5ovN7K4A3EF57ATRRj+A8gJVEFEVEBIcFbq07F5pMJsFkMv0UwDE4doD5BkGqlBrc6nsCIlZ07N0M4MRjt99ypycXElE0EcVe+RnADQBq/BCjamiiHgYzHwbwKhyPzdVw/H1cjrdMJtMCAAcA/AlAjFLxnBSbXE6qTUT6rF2Hz/cfXwrHpONzj91+yweP3X7LTDcvTwawn4gq4dibvpOZ3/dXrGqgTZT5gMlkkgD8EsDPAfhlDHyrecXZJI7xeEwerjDzwM7GJzr7bd0zRrx0wZCw5Zc/fuLzz6sSWBCh3am9xGQypQMoBvAA/CRoACiTzpz3V9uhyLm+Y0dGETSkiLXniQzbHv/B7tce/8HuRDViCxY0UXuByWS6FUA5HDnMfuW8cHm2v/sIFexsO3vk8nurRh4nIf6QFLn8ymadWwFUPP6D3X7/vwlWtMdvD3A+bj8M4F8C2e+NlvzqNHmSMZB9BiMHW98su9Bft+Tqo9RmiP++QELUpBGnWwHct2Xr5j8FKr5gQbtTu4nJZJoKYDcCLGgAKJfOdAW6z2Cj19p56FpBA1LU9WdHETQA6AD88fEf7H7h8R/sjvB/hMGDJmo3MJlM8+DI+FmnRv+t1JMjQ7aq0XcwwMx9e5tfnjXyOIlT9ksG47hrzAz+WtHsl58xbjMm+y/C4EITtQtMJtNqAAfhxa4wxSAknhSbjqrWv8qc7asqG7D1TLv6KDXqY7/icmfZiSmHi2uTS74G4JBxm3GBn0IMKjRRj4PJZLoNwC4Aoz3eBZRK8dyEnPywy7bTn1z+cOSuPtZF/91lIkPseNe2R13av3fOSxucv6YDOGjcZrzOH3EGE5qox8BkMt0NYAeAoBiP9dLgIjOsE65uVUnbWz0MWTf8mCCl7RP1c/LHu84sDla/Znxs2YjD8QB2GrcZP69wmEGFJupRMJlMvwDwRwTT34dgqJbOV6sdRiDpsbQfvDhQv+jqo+JpXcwXV4x3nQz7xZfzH0yRBZthlJcNAN4IZ2EHz5s2SDCZTPcB+K3acYxGrXhx3MfNcIKZe/c2bx+5Rm/Tx9xmJpLGfHpicP/fcv7UN6jvnTJO82EtbE3Uw3A+cj+idhxjYYY1t5cGfTZtCAVO91YcHbT3pgw/JujmHRB0qWNOdjGY983eXtMaey7LjS7CVtiaqJ2YTKYfwfHIHbwQ6Kh09qTaYfgbm2ytP9r+0Qg7J91xXfTN41o8nZxypKg2uWTcR/MRhKWwNVEDMJlM3wHwP2rH4Q5nhJY0tWPwNwdb3xxg8HCrrSF97O0GImFM+62OyKYDe+a8uNGL7q4Ie7MX1wYlE17UJpPpegBPIER8tu0kZ7ZQ1wm14/AX3Za2A02Dp69afxYNeYcFaWrmWNeYxcGaV3N/74sjqAHAq8ZtxmAqYeQ1E1rUJpMpCw5L2pAyYCyTzjSrHYM/YObuvc3b5111kCIqpMjNY9oNybBf2p7/u+QxZro9IRHA28ZtxgQf21GdCStqk8mUBEeFiASVQ/GYJqFzPoNlteNQmlM9RyuH7P3DZ6179LFfm+x0obkGBve/mfPfvQP6nvFmuj0hC8AO4zajqFB7qjAhRW0ymXRw+I/NUTsWb2BC8lmhtULtOJTEJltPlHfsumoiTIxYWSWICamjnc9gLs7YUd0S2+DOTLcnfA7AHxRuM6BMSFEDeBzABpdnBTHl0tkBtWNQCmbmA61vmBn82R1SiC3VRa6+pnrJFeonlxUdTznor5zpnxi3Gb/np7b9zoQTtclk+iaAf1I7Dl/ppP58K+xhIewuS+uB5sGzucMOXTbEfm3MBJrOyOYDu+e+sNHPYf2PcZtRlaw8X5lQonamUD6udhyKQIipFRsr1A7DV5i5u6h5x/zhx6So606RED3qONkiDh17NffRa/Kq/YAOwAvGbcaQ28U3YUS9a3emfl7W/l8DcqTasShFOBj+n+w5UmmWBz51TSVh0gHJkDfqY7UMuWl73u8m2wVboJJsZsHhEOsxRBRBRKVEVElEx4joNwrHNiYTRtQATMnJZ7++ctUrNRERvRfVDkYJQt3w3ypbais69gwbN1OTPvb2nNHOZfDAWwv/u6vf0B1os4NvG7cZ/96L68wANjNzHhxFHW4iooD4pk0IUe/anbkawL8BgE5nyVu67G8x06fXlagclu+EsOE/M/P+ltdkfPYeZF10QRMJEfHXnAvm/emvVTbHnc0ObJSf8pRxm9GjZTN20Of8Vef8CkhOfNiLetfuzBgAzwP4dGaVCPGZc46sWrT4nf2CYO1XLzrfCVXD/05Ly4HWofMLr/xO0vRiUT9v1JrcpyeVFx2bVnyNi2gAmQp
|
2021-08-06 23:18:21 +02:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
2021-07-19 01:21:53 +02:00
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"Xy_test = list(zip(X_test, y_test))\n",
|
|
|
|
"test_dict = {\"1\":[], \"2\":[],\"3\":[], \"4\":[], \"5\":[],\"6\":[], \"7\":[], \"8\":[],\"9\":[], \"10\":[], \"11\":[],\"12\":[], \"13\":[], \"14\":[], \"15\": [], \"16\": []}\n",
|
|
|
|
"\n",
|
|
|
|
"[test_dict[str(e[1])].append(e[0]) for e in Xy_test]\n",
|
|
|
|
"[print(f'Key: {k}: {len(v)}') for k, v in test_dict.items()]\n",
|
|
|
|
"pd.DataFrame.from_dict({k: len(v) for k, v in test_dict.items()}, orient='index').plot.pie(subplots=True, legend=False)"
|
2021-07-19 01:21:53 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 29,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "419d603a",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 355 ms, sys: 13 ms, total: 368 ms\n",
|
|
|
|
"Wall time: 367 ms\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-27 16:00:03 +02:00
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
|
|
|
"from sklearn.preprocessing import LabelBinarizer\n",
|
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
"lb = LabelBinarizer()\n",
|
|
|
|
"yy_train = lb.fit_transform(y_train)\n",
|
|
|
|
"yy_test = lb.transform(y_test)"
|
2021-07-27 16:00:03 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 30,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "da224750",
|
2021-07-27 16:00:03 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
2021-07-17 17:40:05 +02:00
|
|
|
"source": [
|
|
|
|
"for e in test:\n",
|
|
|
|
" e['label'] = lb.transform([e['label']])\n",
|
|
|
|
" e['data'] = np.array(e['data'])\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"\n",
|
|
|
|
" \n",
|
|
|
|
"for e in train:\n",
|
|
|
|
" e['label'] = lb.transform([e['label']])\n",
|
|
|
|
" e['data'] = np.array(e['data'])"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 31,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "073c2c51",
|
2021-07-27 16:00:03 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"(57800, 5, 336)\n",
|
|
|
|
"(57800, 16)\n",
|
|
|
|
"(37106, 5, 336)\n",
|
|
|
|
"(37106, 16)\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-27 16:00:03 +02:00
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"print(X_train.shape)\n",
|
|
|
|
"print(yy_train.shape)\n",
|
|
|
|
"print(X_test.shape)\n",
|
|
|
|
"print(yy_test.shape)"
|
2021-07-27 16:00:03 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "cee9b1c3",
|
2021-07-19 01:21:53 +02:00
|
|
|
"metadata": {},
|
2021-07-17 17:40:05 +02:00
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
"# Building Model"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 32,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "75c9ba6d",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
2021-07-27 16:00:03 +02:00
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import tensorflow as tf\n",
|
|
|
|
"from tensorflow.keras.regularizers import l2\n",
|
|
|
|
"from tensorflow.keras.models import Sequential\n",
|
|
|
|
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
|
|
|
|
"from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
|
|
|
|
"from tensorflow.keras.optimizers import Adam\n",
|
|
|
|
"\n",
|
|
|
|
"def build_model(shape, classes):\n",
|
|
|
|
" model = Sequential()\n",
|
|
|
|
" \n",
|
|
|
|
" ncount = shape[0]*shape[1]\n",
|
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" model.add(Flatten(input_shape=shape, name='flatten'))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" model.add(Dropout(drop_count, name=f'dropout_{drop_count*100}'))\n",
|
|
|
|
" model.add(BatchNormalization(name='batchNorm'))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" for i in range(2,layer_count+2):\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" neurons = int(ncount/pow(dense_steps,i))\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" if neurons <= classes:\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" break\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" model.add(Dropout(drop_count*i, name=f'HiddenDropout_{drop_count*i*100:.0f}'))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" model.add(Dense(neurons, activation='relu', \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" kernel_regularizer=l2(0.001), name=f'Hidden_{i}')\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" )\n",
|
|
|
|
" \n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" model.add(Dense(classes, activation='softmax', name='Output'))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" \n",
|
|
|
|
" model.compile(\n",
|
|
|
|
" optimizer=Adam(),\n",
|
|
|
|
" loss=\"categorical_crossentropy\", \n",
|
|
|
|
" metrics=[\"acc\"],\n",
|
|
|
|
" )\n",
|
|
|
|
" \n",
|
2021-08-06 22:16:00 +02:00
|
|
|
" model.summary()\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" return model"
|
2021-07-27 16:00:03 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 33,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "8f71c4bf",
|
2021-07-27 16:00:03 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"checkpoint_file = './goat.weights'\n",
|
|
|
|
"\n",
|
|
|
|
"def train_model(X_train, y_train, X_test, y_test):\n",
|
|
|
|
" model = build_model(X_train[0].shape, 16)\n",
|
|
|
|
" \n",
|
|
|
|
" # Create a callback that saves the model's weights\n",
|
|
|
|
" model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss', \n",
|
|
|
|
"\t\t\tsave_best_only=True)\n",
|
|
|
|
" \n",
|
|
|
|
" reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=0.0001)\n",
|
|
|
|
"\n",
|
|
|
|
" callbacks = [model_checkpoint, reduce_lr]\n",
|
|
|
|
" \n",
|
|
|
|
" history = model.fit(X_train, \n",
|
|
|
|
" y_train,\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" epochs=epoch,\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
" batch_size=32,\n",
|
|
|
|
" verbose=2,\n",
|
|
|
|
" validation_data=(X_test, y_test),\n",
|
|
|
|
" callbacks=callbacks\n",
|
|
|
|
" )\n",
|
|
|
|
" \n",
|
|
|
|
" model.load_weights(checkpoint_path)\n",
|
2021-08-06 22:16:00 +02:00
|
|
|
" return model, history"
|
2021-07-27 16:00:03 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 34,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "77e0fc90",
|
2021-07-27 16:00:03 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"Loaded weights...\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"Model: \"sequential\"\n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Layer (type) Output Shape Param # \n",
|
|
|
|
"=================================================================\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"flatten (Flatten) (None, 1680) 0 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"dropout_10.0 (Dropout) (None, 1680) 0 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"batchNorm (BatchNormalizatio (None, 1680) 6720 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"HiddenDropout_20 (Dropout) (None, 1680) 0 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"Hidden_2 (Dense) (None, 186) 312666 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"HiddenDropout_30 (Dropout) (None, 186) 0 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"Hidden_3 (Dense) (None, 62) 11594 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"HiddenDropout_40 (Dropout) (None, 62) 0 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"Hidden_4 (Dense) (None, 20) 1260 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"Output (Dense) (None, 16) 336 \n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"=================================================================\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"Total params: 332,576\n",
|
|
|
|
"Trainable params: 329,216\n",
|
|
|
|
"Non-trainable params: 3,360\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 80.5 ms, sys: 3.3 ms, total: 83.8 ms\n",
|
|
|
|
"Wall time: 79.5 ms\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-17 17:40:05 +02:00
|
|
|
"source": [
|
|
|
|
"%%time\n",
|
2021-08-06 22:16:00 +02:00
|
|
|
"\n",
|
|
|
|
"if not os.path.isdir(checkpoint_dir) or create_new:\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" tf.keras.backend.clear_session()\n",
|
|
|
|
" model, history = train_model(np.array(X_train), np.array(yy_train), np.array(X_test), np.array(yy_test))\n",
|
|
|
|
"else:\n",
|
|
|
|
" print(\"Loaded weights...\")\n",
|
2021-08-06 22:16:00 +02:00
|
|
|
" model = build_model(X_train[0].shape, 16)\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" model.load_weights(checkpoint_path)"
|
2021-07-19 01:21:53 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "f2e6f8ad",
|
2021-07-19 01:21:53 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Eval"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 35,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "b7ede2b1",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def predict(model, entry):\n",
|
|
|
|
" p_dict = dict()\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
" predictions = np.argmax(model.predict(entry), axis=-1)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
" for p in predictions:\n",
|
|
|
|
" if p in p_dict:\n",
|
|
|
|
" p_dict[p] += 1\n",
|
|
|
|
" else:\n",
|
|
|
|
" p_dict[p] = 1\n",
|
|
|
|
" prediction = max(p_dict, key=p_dict.get)\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
" return prediction+1"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 36,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "a71bb247",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 3.36 s, sys: 529 ms, total: 3.89 s\n",
|
|
|
|
"Wall time: 2.95 s\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"(43, 43)"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 36,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
2021-07-05 15:01:40 +02:00
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"ltest = [lb.inverse_transform(e['label'])[0] for e in test]\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"ptest = [predict(model, e['data']) for e in test]\n",
|
|
|
|
"\n",
|
|
|
|
"len(ltest), len(ptest)"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 37,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "ab3ecfc9",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 3.85 s, sys: 448 ms, total: 4.3 s\n",
|
|
|
|
"Wall time: 2.99 s\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"(47, 47)"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 37,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
2021-07-17 17:40:05 +02:00
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"ltrain = [lb.inverse_transform(e['label'])[0] for e in train]\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"ptrain = [predict(model, e['data']) for e in train]\n",
|
|
|
|
"\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"len(ltrain), len(ptrain)"
|
2021-07-17 17:40:05 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 38,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "ac226caa",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},\n",
|
|
|
|
" {1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 38,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
2021-08-06 20:20:52 +02:00
|
|
|
"source": [
|
|
|
|
"set(ltrain), set(ltest)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 39,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "3c3bac5d",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
2021-08-06 23:49:37 +02:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAGtCAYAAADnIyVRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAABZtklEQVR4nO3df3xU933n+9dHo/Az2IzAMyoKwSqGveaWGIEcoo1hXYFrEuTEvYDrBBzTBtNlW6d7e8HXdW3xo9f0h6l7l21LQyJI7Rb7ui3u2ihtqEVtvDZxUA0aAmmNoQKZwFiCUdZrughmvvePGSaSQICkOaM5Z97Px2MezJlz5ry/3yNp5sP3/DLnHCIiIiJ+UzLUDRAREREZCBUxIiIi4ksqYkRERMSXVMSIiIiIL6mIEREREV9SESMiIiK+pCJGREREPGdm28zsQzP7YR/zzcw2m9n7ZhYzs5nXW6eKGBEREcmH7wALrjH/C8CUzGMlsOV6K1QRIyIiIp5zzu0Fzl1jkS8Dz7m07wNjzexnrrXO0lw2MJdWrVqV10sJb9ly3YJPRER62bdvX17zampq8po3BCyvYWa5/K79VdIjKJdtdc5t7cf7K4C2btMfZF473dcbCraIEREREf/IFCz9KVoGTUWMiIhIkTLL68DP9ZwCJnab/lTmtT7pmBgREZEiZWY5e+TAK8DXMmcpfQ74iXOuz11JoJEYERERyQMzewG4GxhvZh8Aa4FPADjn/gz4LvBF4H3gPPDL11unihgREZEilc/dSc65r1xnvgN+rT/rVBEjIiJSpEpK/H1Uib9bLyIiIkXLVyMx06ZN44EHHsDMeOutt9i9e/cVy8ycOZO6ujqcc5w6dYpt27YxdepUFi9enF2mvLychoYGWlpaBtWevXv38vTTT5NKpViyZAkrV668/pt8kKU85SmvePJynRWLxdixYwepVIq5c+dSV1fXY/6ePXvYs2cPZsaIESNYvnw5FRUVHD9+nO3bt2eXu//++5k1a9ag2gLB/tnlQoGdndRvvilizIwHH3yQzZs3k0gkePzxx4nFYpw5cya7zC233MKCBQvYtGkT58+fZ8yYMQC89957bNy4EYBRo0axYcMGjhw5Mqj2JJNJNmzYwPbt24lGoyxevJja2lpuu+22Qa13qLOUpzzlFU9errNSqRTPP/88a9asoaysjPXr11NVVUVFRUV2mZqaGmprawE4cOAAL7zwAqtXr6aiooJ169YRCoXo7OzkqaeeYsaMGYRCoYLpX6Hl5YLfixjf7E669dZbaW9vp6Ojg2QySXNzM3fccUePZe666y7eeOMNzp8/D8BHH310xXpmzpzJ4cOHuXjx4qDaE4vFmDRpEhMnTmTYsGEsXLiQpqamQa2zELKUpzzlFU9errOOHz9ONBolEolQWlrK7NmzOXDgQI9lRo4cmX1+4cKF7Jfo8OHDswXLxYsXc/LlGuSfnaTlvYgxs+ueMnU1Y8eOJZFIZKcTiQRjx47tsUwkEiESibB69Woee+wxpk2bdsV6qqur2b9//0Ca0EM8Hqe8vDw7HY1Gicfjg17vUGcpT3nKK568XGclEgnKysqy0+FwuMfn9mWvvfYaa9as4aWXXmLp0qXZ148dO8YTTzzBk08+ycMPPzyoURgI9s8uVwrsOjH9NhQjMev7mmFmK82s2cyaB7K7JxQKEYlEePbZZ2loaGDp0qU9qv6bbrqJCRMmDHpXkoiIDNz8+fN55plnWLJkCa+++mr29cmTJ7Nx40bWrl3Lrl276OrqGsJWFgcVMVdhZrE+HoeAaF/vc85tdc5VO+eqe4+idHZ2Eg6Hs9PhcJjOzs4eyyQSCWKxGKlUirNnz/Lhhx8SiUSy82fNmsXBgwdJpVKD7mM0Gu1xPE48Hica7bNrvslSnvKUVzx5uc4Kh8OcO/fTmxQnEoken9u9zZ49m3ffffeK1ydMmMCIESM4deqaV5y/riD/7CTNq5GYKPA14L6rPM4OZIUnTpwgEokwbtw4QqEQ1dXVxGKxHsu0tLQwdepUAEaPHk0kEqGjoyM7/84776S5uXkg8VeYPn06ra2ttLW10dXVRWNjY/ZgtVzLZ5bylKe84snLdVZlZSXxeJz29nYuXbrEO++8Q1VVVY9lun/Jt7S0ZL/k29vbSSaTAHR0dHD69GnGjx8/4LZAsH92ueL3kRivzk7aBXzSOXew9wwze30gK0ylUrz44os8+uijlJSU8Pbbb3P69Gnq6uo4efIksViMI0eOcPvtt1NfX08qleLll1/m448/BqCsrIxwOMzRo0cH06+s0tJS6uvrWbFiBclkkkWLFjFlypScrHsos5SnPOUVT16us0KhEMuWLWPTpk2kUinmzJlDRUUFO3fupLKykqqqKpqamjh8+DChUIjRo0fzyCOPAOmzSBsbGwmFQpSUlPDQQw9lzzAtlP4VWl4u+P1id5a+ym/hWbVqVV4btmXLlnzGiYgEwr59+/KaV1NTk9e8IZDXIY0xY8bk7Lv2o48+yvtwjG+uEyMiIiK55ffrxKiIERERKVJ+L2L8vTNMREREipZGYkRERIqU30diVMSIiIgUKb8XMdqdJCIiIr6kkRgREZEi5feRmIK9TgxQsA3LhXz/4hTwz1lERH4qr18O48ePz9mXQ0dHR94rIu1OEhEREV/S7iQREZEi5ffdSSpiREREipTfixjtThIRERFf0kiMiIhIkfL7SIyKGBERkSKlIkZERER8ye9FTKCOidm7dy/33nsv99xzD1u3bg1UXkNDA/F4nEOHDnmac1mQt6XylKe8oclSnuScc65QH/1y6dIlN2/ePHfy5El34cIFd99997mjR4/2dzV5yyN9Mb8bfsyZM8dVVVW5Q4cO9fu96R9z/vrWX8pTnvKGJi/IffNxXl6/aydMmOBy9ch3251z3o3EmNn/ZmbzzOyTvV5f4EVeLBZj0qRJTJw4kWHDhrFw4UKampq8iBqSvDfffJNz5855tv7ugr4tlac85eU/S3mFycxy9hgKnhQxZvYN4L8BjwI/NLMvd5u90YvMeDxOeXl5djoajRKPx72IGpK8fAr6tlSe8pSX/yzliRe8Gol5BJjlnLsfuBt4ysx+IzOvz3LNzFaaWbOZNWtfooiIiLf8PhLj1dlJJc65/wngnGs1s7uBvzazSVyjiHHObQUuVy/9uilVNBrlzJkz2el4PE40Gu1nsws3L5+Cvi2Vpzzl5T9LeYVJZyddXdzMZlyeyBQ0dcB4YLoXgdOnT6e1tZW2tja6urpobGyktrbWi6ghycunoG9L5SlPefnPUp54wauRmK8Bl7q/4Jy7BHzNzL7pRWBpaSn19fWsWLGCZDLJokWLmDJlihdRQ5K3Y8cO7r77bsaPH09bWxtr165l27ZtnmQFfVsqT3nKy3+W8gqT30dizLl+7bXJp4JtWC7k+xengH/OIiLyU3n9crj11ltz9uXQ2tqa94ooUBe7ExERkeKh2w6IiIgUqZISf49lqIgREREpUn4/JsbfJZiIiIgULY3EiIiIFCm/j8SoiBERESlSfi9itDtJREREfEkjMSIiIkXK7yMxKmJERESKlN9PsfZ360VERKRoaSRGRESkSGl3koiIiPiS34sY7U4SERERX9JIjIiISJHy+4G9KmJERESKlHYnFZC9e/dy7733cs8997B169ZA5TU0NBCPxzl06JCnOZcFeVsqT3nKG5os5UnOOecK9dEvly5dcvPmzXMnT550Fy5ccPfdd587evRof1eTtzygX485c+a4qqoqd+jQoX6/N/1jzl/f+kt5ylPe0OQFuW8+zsvrd+1nPvMZl6tHvtvunPNuJMbMPmtmd2aeTzOz3zSzL3qVF4vFmDRpEhMnTmTYsGEsXLiQpqYmr+Lynvfmm29y7tw5z9bfXdC3pfKUp7z8ZymvMJlZzh5DwZMixszWApuBLWb2u8AfA6OBx83st73IjMfjlJeXZ6ej0SjxeNyLqCHJy6egb0vlKU95+c9SnnjBq5GYxcDngbnArwH3O+d+B7gX+KW+3mRmK82s2cyatS9RRETEW34fifHq7KRLzrkkcN7
|
2021-08-06 23:18:21 +02:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 720x504 with 2 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
" precision recall f1-score support\n",
|
|
|
|
"\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
" 1 1.00 0.67 0.80 3\n",
|
|
|
|
" 2 0.00 0.00 0.00 1\n",
|
|
|
|
" 3 0.43 1.00 0.60 3\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
" 4 0.00 0.00 0.00 0\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
" 5 0.00 0.00 0.00 3\n",
|
|
|
|
" 6 0.50 0.33 0.40 3\n",
|
|
|
|
" 7 1.00 1.00 1.00 3\n",
|
|
|
|
" 8 1.00 0.33 0.50 3\n",
|
|
|
|
" 9 0.00 0.00 0.00 3\n",
|
|
|
|
" 10 0.40 0.67 0.50 3\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
" 11 0.00 0.00 0.00 3\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
" 12 0.60 1.00 0.75 3\n",
|
|
|
|
" 13 0.75 1.00 0.86 3\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
" 14 0.00 0.00 0.00 3\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
" 15 0.00 0.00 0.00 3\n",
|
|
|
|
" 16 0.50 1.00 0.67 3\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
" accuracy 0.49 43\n",
|
|
|
|
" macro avg 0.39 0.44 0.38 43\n",
|
|
|
|
"weighted avg 0.43 0.49 0.42 43\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 646 ms, sys: 195 ms, total: 840 ms\n",
|
|
|
|
"Wall time: 610 ms\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-17 17:40:05 +02:00
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"from sklearn.metrics import confusion_matrix\n",
|
|
|
|
"import seaborn as sn\n",
|
|
|
|
"\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"from sklearn.metrics import classification_report\n",
|
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"set_digits = set(ltrain)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"\n",
|
2021-08-06 20:20:52 +02:00
|
|
|
"train_cm = confusion_matrix(ltrain, ptrain, labels=list(set_digits), normalize='true')\n",
|
|
|
|
"test_cm = confusion_matrix(ltest, ptest, labels=list(set_digits), normalize='true')\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
"\n",
|
|
|
|
"df_cm = pd.DataFrame(test_cm, index=set_digits, columns=set_digits)\n",
|
|
|
|
"plt.figure(figsize = (10,7))\n",
|
|
|
|
"sn_plot = sn.heatmap(df_cm, annot=True, cmap=\"Greys\")\n",
|
|
|
|
"plt.ylabel(\"True Label\")\n",
|
|
|
|
"plt.xlabel(\"Predicted Label\")\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"plt.show()\n",
|
|
|
|
"\n",
|
|
|
|
"print(classification_report(ltest, ptest, zero_division=0))"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 40,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "43acba77",
|
2021-07-19 01:21:53 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:49:37 +02:00
|
|
|
"outputs": [],
|
2021-07-19 01:21:53 +02:00
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"def plot_keras_history(history, name='', acc='acc'):\n",
|
|
|
|
" \"\"\"Plots keras history.\"\"\"\n",
|
|
|
|
" import matplotlib.pyplot as plt\n",
|
|
|
|
"\n",
|
|
|
|
" training_acc = history.history[acc]\n",
|
|
|
|
" validation_acc = history.history['val_' + acc]\n",
|
|
|
|
" loss = history.history['loss']\n",
|
|
|
|
" val_loss = history.history['val_loss']\n",
|
|
|
|
"\n",
|
|
|
|
" epochs = range(len(training_acc))\n",
|
|
|
|
"\n",
|
|
|
|
" plt.ylim(0, 1)\n",
|
|
|
|
" plt.plot(epochs, training_acc, 'tab:blue', label='Training acc')\n",
|
|
|
|
" plt.plot(epochs, validation_acc, 'tab:orange', label='Validation acc')\n",
|
|
|
|
" plt.title('Training and validation accuracy ' + name)\n",
|
|
|
|
" plt.legend()\n",
|
|
|
|
"\n",
|
|
|
|
" plt.figure()\n",
|
|
|
|
"\n",
|
|
|
|
" plt.plot(epochs, loss, 'tab:green', label='Training loss')\n",
|
|
|
|
" plt.plot(epochs, val_loss, 'tab:red', label='Validation loss')\n",
|
|
|
|
" plt.title('Training and validation loss ' + name)\n",
|
|
|
|
" plt.legend()\n",
|
|
|
|
" plt.show()\n",
|
|
|
|
" plt.close()\n",
|
2021-08-06 22:16:00 +02:00
|
|
|
"if 'history' in locals():\n",
|
|
|
|
" plot_keras_history(history)"
|
2021-08-06 20:20:52 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 41,
|
2021-08-06 23:49:37 +02:00
|
|
|
"id": "af999e08",
|
2021-08-06 20:20:52 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"Scenario: SYN\n",
|
2021-08-06 23:18:21 +02:00
|
|
|
"Window Size: 5\n",
|
|
|
|
"Strides: 1\n",
|
|
|
|
"Epochs: 50\n",
|
|
|
|
"HiddenL Count: 3\n",
|
|
|
|
"Neuron Factor: 3\n",
|
|
|
|
"Drop Factor: 0.1\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-08-06 20:20:52 +02:00
|
|
|
"source": [
|
|
|
|
"print(f'Scenario: {cenario}')\n",
|
|
|
|
"print(f'Window Size: {win_sz}')\n",
|
|
|
|
"print(f'Strides: {stride_sz}')\n",
|
|
|
|
"print(f'Epochs: {epoch}')\n",
|
|
|
|
"print(f'HiddenL Count: {layer_count}')\n",
|
|
|
|
"print(f'Neuron Factor: {dense_steps}')\n",
|
|
|
|
"print(f'Drop Factor: {drop_count}')"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
2021-07-17 17:40:05 +02:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:49:37 +02:00
|
|
|
"execution_count": 42,
|
|
|
|
"id": "b16af0c6",
|
2021-07-17 17:40:05 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
2021-08-06 23:49:37 +02:00
|
|
|
"source": [
|
|
|
|
"exit()"
|
|
|
|
]
|
2021-07-05 15:01:40 +02:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
2021-07-14 10:15:52 +02:00
|
|
|
"version": "3.8.10"
|
2021-08-06 20:20:52 +02:00
|
|
|
},
|
|
|
|
"toc-showtags": false
|
2021-07-05 15:01:40 +02:00
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 5
|
|
|
|
}
|