1579 lines
199 KiB
Plaintext
1579 lines
199 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b91d212e",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Constants"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "1bf63c18",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"# GPU setup, done before TensorFlow is imported further down.\n",
"# TF_FORCE_GPU_ALLOW_GROWTH is required; CUDA_VISIBLE_DEVICES selects the card:\n",
"# '0' for GPU0, '1' for GPU1 or '2' for GPU2 (check \"gpustat\" in a terminal).\n",
"os.environ.update({\n",
"    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',\n",
"    'CUDA_VISIBLE_DEVICES': '0',\n",
"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "61e91687",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"# Recording CSVs to load and the pickle file used to cache them.\n",
"glob_path = '/opt/iui-datarelease3-sose2021/*.csv'\n",
"pickle_file = '../data.pickle'\n",
"\n",
"# Location of the model checkpoint written during training.\n",
"checkpoint_path = \"training_1/cp.ckpt\"\n",
"checkpoint_dir = os.path.dirname(checkpoint_path)\n",
"\n",
"# Display floats with two decimals in DataFrame output.\n",
"pd.set_option('display.float_format', '{:.2f}'.format)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "85e6ab7c",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Config"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "3f97f28e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Data subset to use. The code is <task><HeightNorm><ArmNorm>: task S (Sorting)\n",
"# or J (Jenga), followed by Y/N for each of the two normalizations.\n",
"# Possibilities: 'SYY', 'SYN', 'SNY', 'SNN', 'JYY', 'JYN', 'JNY', 'JNN'\n",
"cenario = 'SYN'\n",
"\n",
"# Sliding-window length and stride, both in rows of a recording.\n",
"win_sz = 5\n",
"stride_sz = 1\n",
"\n",
"# Number of training epochs.\n",
"epoch = 50\n",
"\n",
"# Hidden layer i gets (flattened input size) / dense_steps**i neurons,\n",
"# so every hidden layer is dense_steps times smaller than the previous one.\n",
"dense_steps = 3\n",
"# Upper bound for the number of hidden dropout/dense layer pairs.\n",
"layer_count = 3\n",
"# Base dropout rate; hidden layer i uses drop_count * i.\n",
"drop_count = 0.1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "fddb1e58",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Helper Functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "8d1865b9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from matplotlib import pyplot as plt\n",
|
|
"\n",
|
|
"def pplot(dd):\n",
"    \"\"\"Plot every row of ``dd`` in its own subplot, three subplots per grid row.\n",
"\n",
"    ``dd`` must provide ``shape[0]`` and integer indexing; ``dd[i]`` is handed\n",
"    to ``Axes.plot`` unchanged.\n",
"    \"\"\"\n",
"    n_plots = dd.shape[0]\n",
"    n_cols = 3\n",
"    # Ceiling division, with at least one row so an empty input still yields a grid.\n",
"    n_rows = max(1, -(-n_plots // n_cols))\n",
"    # squeeze=False keeps axs two-dimensional even for a single row; with the\n",
"    # default, axs[row][col] fails as soon as fewer than three plots are requested.\n",
"    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 9 * n_rows), squeeze=False)\n",
"\n",
"    for i in range(n_plots):\n",
"        axs[i // n_cols][i % n_cols].plot(dd[i])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "09693b32",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Loading Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "ccbd870e",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from glob import glob\n",
|
|
"from tqdm import tqdm\n",
|
|
"\n",
|
|
"def dl_from_blob(filename, user_filter=None):\n",
"    \"\"\"Read every recording CSV matching a glob pattern into a list of dicts.\n",
"\n",
"    The metadata is parsed from the underscore-separated file name. Its fields\n",
"    are, in order: a one-character prefix followed by the user id,\n",
"    'Scenario<name>', 'HeightNormalization<True|False>',\n",
"    'ArmNormalization<True|False>', 'Repetition<n>' and 'Session<n>'.\n",
"\n",
"    Args:\n",
"        filename: glob pattern selecting the CSV files to load.\n",
"        user_filter: if not None, only recordings of this user id are loaded.\n",
"\n",
"    Returns:\n",
"        A list with one dict per file, with the keys 'filename' (the path),\n",
"        'user', 'scenario', 'heightnorm', 'armnorm', 'rep', 'session' and\n",
"        'data' (the CSV content as a DataFrame).\n",
"    \"\"\"\n",
"    dic_data = []\n",
"\n",
"    # Honour the pattern that was passed in. It used to be ignored in favour of\n",
"    # the global glob_path; the existing caller passes glob_path, so its result\n",
"    # is unchanged.\n",
"    for path in tqdm(glob(filename)):\n",
"        # File name without directory and without everything after the first dot.\n",
"        stem = os.path.basename(path).split('.')[0]\n",
"        fields = stem.split('_')\n",
"\n",
"        user = int(fields[0][1:])\n",
"        # Skip unwanted users before the CSV is read.\n",
"        if user_filter is not None and user != user_filter:\n",
"            continue\n",
"\n",
"        dic_data.append({\n",
"            'filename': path,\n",
"            'user': user,\n",
"            'scenario': fields[1][len('Scenario'):],\n",
"            'heightnorm': fields[2][len('HeightNormalization'):] == 'True',\n",
"            'armnorm': fields[3][len('ArmNormalization'):] == 'True',\n",
"            'rep': int(fields[4][len('Repetition'):]),\n",
"            'session': int(fields[5][len('Session'):]),\n",
"            'data': pd.read_csv(path),\n",
"        })\n",
"\n",
"    return dic_data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "f0f4201c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pickle\n",
|
|
"\n",
|
|
"def save_pickle(f, structure):\n",
"    \"\"\"Serialize ``structure`` with pickle into the file at path ``f`` (overwriting it).\"\"\"\n",
"    # The context manager closes the file even if pickle.dump raises.\n",
"    with open(f, 'wb') as _p:\n",
"        pickle.dump(structure, _p)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "83695ce1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_pickles(f) -> list:\n",
"    \"\"\"Load and return the object pickled in the file at path ``f``.\n",
"\n",
"    NOTE: unpickling can execute arbitrary code, so only load files that were\n",
"    written by this notebook.\n",
"    \"\"\"\n",
"    # Read the file that was asked for; the argument used to be ignored in\n",
"    # favour of the global pickle_file.\n",
"    with open(f, 'rb') as _p:\n",
"        return pickle.load(_p)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "92216c47",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading data...\n",
|
|
"../data.pickle found...\n",
|
|
"768\n",
|
|
"CPU times: user 614 ms, sys: 2.53 s, total: 3.14 s\n",
|
|
"Wall time: 3.14 s\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
|
"def load_data() -> list:\n",
"    \"\"\"Return all recordings, from the pickle cache if it exists, else from the CSVs.\n",
"\n",
"    When the cache file is missing, the CSVs are loaded and the cache is written.\n",
"    \"\"\"\n",
"    if not os.path.isfile(pickle_file):\n",
"        print(f'Didn\\'t find {pickle_file}...')\n",
"        loaded = dl_from_blob(glob_path)\n",
"        print(f'Creating {pickle_file}...')\n",
"        save_pickle(pickle_file, loaded)\n",
"        return loaded\n",
"\n",
"    print(f'{pickle_file} found...')\n",
"    return load_pickles(pickle_file)\n",
"\n",
"\n",
"print(\"Loading data...\")\n",
"dic_data = load_data()\n",
"print(len(dic_data))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "6f337a51",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 97 µs, sys: 302 µs, total: 399 µs\n",
|
|
"Wall time: 402 µs\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
|
"# Categorized data: one list of recordings per configuration code.\n",
"# Code = task letter (S = Sorting, J = Jenga) + HeightNorm (Y/N) + ArmNorm (Y/N).\n",
"task_letters = {'Sorting': 'S', 'Jenga': 'J'}\n",
"cdata = {\n",
"    task + height + arm: list()\n",
"    for task in 'SJ' for height in 'YN' for arm in 'YN'\n",
"}\n",
"\n",
"for d in dic_data:\n",
"    if d['scenario'] not in task_letters:\n",
"        continue  # recordings of any other scenario are not categorized\n",
"    code = (task_letters[d['scenario']]\n",
"            + ('Y' if d['heightnorm'] else 'N')\n",
"            + ('Y' if d['armnorm'] else 'N'))\n",
"    cdata[code].append(d)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b9500a64",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Preprocessing"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "c7c9d655",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def drop(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Return a copy of a recording without its metadata columns.\n",
"\n",
"    With data=True ``entry`` is a recording dict and its 'data' frame is used;\n",
"    with data=False ``entry`` is the DataFrame itself. The input is not modified.\n",
"    \"\"\"\n",
"    source = entry['data'] if data else entry\n",
"    # A pickle round trip yields a deep copy of the frame.\n",
"    frame = pickle.loads(pickle.dumps(source))\n",
"    droptable = [\n",
"        'participantID', 'FrameID', 'Scenario', 'HeightNormalization',\n",
"        'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0',\n",
"    ]\n",
"    return frame.drop(droptable, axis=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "f5b437a2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"# Substrings (matched case-insensitively) that mark a column as belonging to a hand.\n",
"left_Hand_ident = 'left'\n",
"right_Hand_ident = 'right'\n",
"\n",
"def rem_low_acc(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Set the columns of a hand to NaN wherever its tracking accuracy is not 'High'.\n",
"\n",
"    The two accuracy columns are converted to 1.0 ('High') / 0.0 (anything else).\n",
"    With data=True the 'data' frame of the recording dict is used, otherwise\n",
"    ``entry`` itself. The work happens on a deep copy; the input is not modified.\n",
"    \"\"\"\n",
"    frame = pickle.loads(pickle.dumps(entry['data'] if data else entry))\n",
"\n",
"    sides = (\n",
"        (left_Hand_ident, 'LeftHandTrackingAccuracy'),\n",
"        (right_Hand_ident, 'RightHandTrackingAccuracy'),\n",
"    )\n",
"\n",
"    # Turn the accuracy labels into numeric flags first ...\n",
"    for _, acc_col in sides:\n",
"        frame[acc_col] = (frame[acc_col] == 'High') * 1.0\n",
"\n",
"    # ... then blank every column of a hand (except the flag itself) where the flag is 0.\n",
"    for ident, acc_col in sides:\n",
"        hand_cols = [c for c in frame if ident in c.lower() and c != acc_col]\n",
"        frame.loc[frame[acc_col] == 0.0, hand_cols] = np.nan\n",
"\n",
"    return frame"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "f126154b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
"\n",
|
|
"def pad(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Pre-pad a recording with NaN rows and zero its first row.\n",
"\n",
"    The padded length is the smallest multiple of stride_sz that is strictly\n",
"    greater than the current number of rows. With data=True the 'data' frame of\n",
"    the recording dict is used, otherwise ``entry`` itself (deep-copied either way).\n",
"    \"\"\"\n",
"    frame = pickle.loads(pickle.dumps(entry['data'] if data else entry))\n",
"\n",
"    columns = frame.columns\n",
"    target_len = (int(frame.shape[0] / stride_sz) + 1) * stride_sz\n",
"    # pad_sequences pads each sequence separately, so every column is passed as one sequence.\n",
"    padded = pad_sequences(\n",
"        frame.T.to_numpy(),\n",
"        maxlen=target_len,\n",
"        dtype='float64',\n",
"        padding='pre',\n",
"        truncating='post',\n",
"        value=np.nan,\n",
"    )\n",
"\n",
"    result = pd.DataFrame(padded.T, columns=columns)\n",
"    result.loc[0] = [0] * len(columns)\n",
"    return result"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "d50c3391",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def interpol(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Return a copy of a recording with NaN gaps interpolated in both directions.\n",
"\n",
"    With data=True the 'data' frame of the recording dict is used, otherwise\n",
"    ``entry`` itself. The input is not modified.\n",
"    \"\"\"\n",
"    source = entry['data'] if data else entry\n",
"    frame = pickle.loads(pickle.dumps(source))\n",
"    return frame.interpolate(limit_direction='both')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "60629469",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tensorflow.keras.preprocessing import timeseries_dataset_from_array\n",
|
|
"\n",
|
|
"def slicing(entry, label, data=True):\n",
"    \"\"\"Cut a recording into sliding windows, every window labelled with ``label``.\n",
"\n",
"    Window length and stride come from the globals win_sz and stride_sz; the\n",
"    result is what timeseries_dataset_from_array returns, in batches of 8.\n",
"    With data=True the 'data' frame of the recording dict is used, otherwise\n",
"    ``entry`` itself.\n",
"    \"\"\"\n",
"    frame = pickle.loads(pickle.dumps(entry['data'] if data else entry))\n",
"\n",
"    # One identical target per row, so every window gets the same label.\n",
"    targets = [label] * frame.shape[0]\n",
"\n",
"    return timeseries_dataset_from_array(\n",
"        data=frame,\n",
"        targets=targets,\n",
"        sequence_length=win_sz,\n",
"        sequence_stride=stride_sz,\n",
"        batch_size=8,\n",
"        seed=177013,\n",
"    )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "4de2497d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 96/96 [00:05<00:00, 16.44it/s] \n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Stack all recordings of the selected configuration into one frame. A single\n",
"# pd.concat replaces DataFrame.append inside the loop, which copied the growing\n",
"# frame on every iteration and no longer exists in pandas >= 2.0.\n",
"acc_data = pd.concat([e['data'] for e in tqdm(cdata[cenario])], ignore_index=True)\n",
"\n",
"ddacc_data = rem_low_acc(drop(acc_data, False), False)\n",
"\n",
"# Column-wise statistics of the euler and position columns; the minima and\n",
"# maxima are the scaling bounds used by norm().\n",
"# NOTE(review): they are computed over every session of the configuration,\n",
"# including the one later used as test set - confirm this is intended.\n",
"eula = ddacc_data[[c for c in ddacc_data if 'euler' in c.lower()]]\n",
"posi = ddacc_data[[c for c in ddacc_data if 'pos' in c.lower()]]\n",
"\n",
"eulamin, eulamax, eulamean, eulastd = eula.min(), eula.max(), eula.mean(), eula.std()\n",
"posimin, posimax, posimean, posistd = posi.min(), posi.max(), posi.mean(), posi.std()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "7b167a3c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def minmaxscaler(entry, minimum, maximum):\n",
"    \"\"\"Scale ``entry`` linearly so that ``minimum`` maps to 0 and ``maximum`` to 1.\n",
"\n",
"    Works for scalars, numpy arrays and pandas objects. Where maximum equals\n",
"    minimum the range is treated as 1, so a constant feature yields 0 instead\n",
"    of NaN/inf from a division by zero.\n",
"    \"\"\"\n",
"    span = maximum - minimum\n",
"    if hasattr(span, 'where'):\n",
"        # pandas object: stay in pandas so the division keeps aligning on labels.\n",
"        span = span.where(span != 0, 1)\n",
"    else:\n",
"        span = np.where(np.equal(span, 0), 1, span)\n",
"    return (entry - minimum) / span"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "044a73e4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Substrings (matched case-insensitively) identifying euler and position columns.\n",
"euler_ident = 'euler'\n",
"pos_ident = 'pos'\n",
"\n",
"def norm(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Min-max scale the euler and position columns of a recording.\n",
"\n",
"    The bounds are the global eulamin/eulamax and posimin/posimax. With\n",
"    data=True the 'data' frame of the recording dict is used, otherwise\n",
"    ``entry`` itself. The work happens on a deep copy; the input is not modified.\n",
"    \"\"\"\n",
"    frame = pickle.loads(pickle.dumps(entry['data'] if data else entry))\n",
"\n",
"    scaling = (\n",
"        (euler_ident, eulamin, eulamax),\n",
"        (pos_ident, posimin, posimax),\n",
"    )\n",
"    for ident, lower, upper in scaling:\n",
"        cols = [c for c in frame if ident in c.lower()]\n",
"        frame[cols] = minmaxscaler(frame[cols], lower, upper)\n",
"\n",
"    return frame"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "7e7526e9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def drop_acc(entry, data=True) -> pd.DataFrame:\n",
"    \"\"\"Return a copy of a recording without the two tracking-accuracy columns.\n",
"\n",
"    With data=True the 'data' frame of the recording dict is used, otherwise\n",
"    ``entry`` itself. The input is not modified.\n",
"    \"\"\"\n",
"    source = entry['data'] if data else entry\n",
"    frame = pickle.loads(pickle.dumps(source))\n",
"    return frame.drop(['LeftHandTrackingAccuracy', 'RightHandTrackingAccuracy'], axis=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "e9ace1d0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 96/96 [00:14<00:00, 6.79it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 12.8 s, sys: 1.54 s, total: 14.3 s\n",
|
|
"Wall time: 14.1 s\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
|
"classes = 16  # hard-coded class count (the data set contains user ids 1..16)\n",
"\n",
"def preproc(data):\n",
"    \"\"\"Run preproc_entry on every recording dict in ``data`` and return the results as a list.\"\"\"\n",
"    return [preproc_entry(e) for e in tqdm(data)]\n",
"\n",
"\n",
"def preproc_entry(entry, data = True):\n",
"    \"\"\"Return a copy of a recording dict whose 'data' is the windowed dataset.\n",
"\n",
"    Pipeline: drop metadata columns -> blank low-accuracy samples -> min-max\n",
"    scale -> drop the accuracy columns -> cut into windows labelled with the\n",
"    recording's 'user'. ``data`` is forwarded unchanged to every step.\n",
"    \"\"\"\n",
"    # A shallow copy is enough: every step deep-copies the frame it works on and\n",
"    # the remaining values are never modified, so the caller's dict stays\n",
"    # untouched. This replaces five pickle round trips of the whole recording.\n",
"    result = dict(entry)\n",
"\n",
"    for step in (drop, rem_low_acc, norm, drop_acc):\n",
"        result['data'] = step(result, data)\n",
"\n",
"    # pad() and interpol() are deliberately not part of the pipeline; windows\n",
"    # that still contain NaN are filtered out when the train/test arrays are built.\n",
"    result['data'] = slicing(result, result['user'], data)\n",
"\n",
"    return result\n",
"\n",
"pdata = preproc(cdata[cenario])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "a067392e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7f37f07b3fa0>]"
|
|
]
|
|
},
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# First recording after dropping the metadata columns: left-hand X position\n",
"# plotted together with the 'High'-accuracy flag (1.0 = 'High').\n",
"a = drop(cdata[cenario][0]['data'], False)\n",
"a.loc[:, 'left_OVRHandPrefab_pos_X'].plot()\n",
"high_left = (a['LeftHandTrackingAccuracy'] == 'High') * 1.0\n",
"plt.plot(high_left)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "b1567410",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7f37f0654970>]"
|
|
]
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Same signal after rem_low_acc: samples without 'High' accuracy are now NaN.\n",
"b = rem_low_acc(a, False)\n",
"b.loc[:, 'left_OVRHandPrefab_pos_X'].plot()\n",
"plt.plot(b.loc[:, 'LeftHandTrackingAccuracy'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "b04ce014",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7f37f05d9310>]"
|
|
]
|
|
},
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Same signal after min-max scaling with the data-set-wide bounds.\n",
"c = norm(b, False)\n",
"c.loc[:, 'left_OVRHandPrefab_pos_X'].plot()\n",
"plt.plot(c.loc[:, 'LeftHandTrackingAccuracy'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "62e4eb23",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7f37f05524c0>]"
|
|
]
|
|
},
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Same signal with the NaN gaps interpolated (for inspection only; interpol is\n",
"# not part of preproc_entry).\n",
"d = interpol(c, False)\n",
"d.loc[:, 'left_OVRHandPrefab_pos_X'].plot()\n",
"plt.plot(d.loc[:, 'LeftHandTrackingAccuracy'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "cc70c742",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 123 µs, sys: 105 µs, total: 228 µs\n",
|
|
"Wall time: 238 µs\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(48, 48)"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"# Windowed datasets per session: session 1 is used for training, session 2 for testing.\n",
"train, test = (\n",
"    np.array([rec['data'] for rec in pdata if rec['session'] == sess])\n",
"    for sess in (1, 2)\n",
")\n",
"\n",
"len(train), len(test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "7b44ef39",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 96/96 [00:36<00:00, 2.61it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(57800, 5, 336) (57800,) (37106, 5, 336) (37106,)\n",
|
|
"CPU times: user 1min 48s, sys: 15.1 s, total: 2min 3s\n",
|
|
"Wall time: 37.1 s\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
|
"def _clean_windows(dataset):\n",
"    \"\"\"Return (windows, labels) for all windows of ``dataset`` that contain no NaN.\"\"\"\n",
"    windows, labels = list(), list()\n",
"    for window, label in dataset.unbatch().as_numpy_iterator():\n",
"        if not np.isnan(window).any():\n",
"            windows.append(window)\n",
"            labels.append(label)\n",
"    return windows, labels\n",
"\n",
"\n",
"X_train, y_train, train = list(), list(), list()\n",
"X_test, y_test, test = list(), list(), list()\n",
"\n",
"# session number -> (flat window list, flat label list, per-recording list)\n",
"splits = {1: (X_train, y_train, train), 2: (X_test, y_test, test)}\n",
"\n",
"# The session 1 and session 2 branches were copies of each other; both now go\n",
"# through the same code path.\n",
"# NOTE(review): the exported notebook lost its indentation. This assumes the\n",
"# per-recording lists collected NaN-free windows only, like X_*/y_* - confirm.\n",
"for x in tqdm(pdata):\n",
"    if x['session'] not in splits:\n",
"        continue\n",
"    X_split, y_split, recordings = splits[x['session']]\n",
"\n",
"    windows, labels = _clean_windows(x['data'])\n",
"    X_split.extend(windows)\n",
"    y_split.extend(labels)\n",
"\n",
"    # Recordings without a single NaN-free window are left out.\n",
"    if windows:\n",
"        recordings.append({'label': x['user'], 'data': windows})\n",
"\n",
"X_train = np.array(X_train)\n",
"y_train = np.array(y_train)\n",
"X_test = np.array(X_test)\n",
"y_test = np.array(y_test)\n",
"\n",
"print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "27f1c824",
|
|
"metadata": {
|
|
"jupyter": {
|
|
"source_hidden": true
|
|
},
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Key: 1: 1347\n",
|
|
"Key: 2: 1583\n",
|
|
"Key: 3: 8568\n",
|
|
"Key: 4: 3034\n",
|
|
"Key: 5: 1960\n",
|
|
"Key: 6: 3311\n",
|
|
"Key: 7: 3971\n",
|
|
"Key: 8: 1407\n",
|
|
"Key: 9: 1135\n",
|
|
"Key: 10: 7466\n",
|
|
"Key: 11: 6494\n",
|
|
"Key: 12: 1813\n",
|
|
"Key: 13: 3596\n",
|
|
"Key: 14: 3260\n",
|
|
"Key: 15: 2825\n",
|
|
"Key: 16: 6030\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([<AxesSubplot:ylabel='0'>], dtype=object)"
|
|
]
|
|
},
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"Xy_train = list(zip(X_train, y_train))\n",
"Xy_test = list(zip(X_test, y_test))\n",
"\n",
"# Group the training windows by user id; the keys are the ids 1..16 as strings.\n",
"train_dict = {str(user): [] for user in range(1, 17)}\n",
"for window, label in Xy_train:\n",
"    train_dict[str(label)].append(window)\n",
"\n",
"for k, v in train_dict.items():\n",
"    print(f'Key: {k}: {len(v)}')\n",
"\n",
"# Pie chart of the number of training windows per user.\n",
"window_counts = {k: len(v) for k, v in train_dict.items()}\n",
"pd.DataFrame.from_dict(window_counts, orient='index').plot.pie(subplots=True, legend=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "34e0af8d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Key: 1: 790\n",
|
|
"Key: 2: 59\n",
|
|
"Key: 3: 4330\n",
|
|
"Key: 4: 0\n",
|
|
"Key: 5: 545\n",
|
|
"Key: 6: 348\n",
|
|
"Key: 7: 5245\n",
|
|
"Key: 8: 3558\n",
|
|
"Key: 9: 2565\n",
|
|
"Key: 10: 4163\n",
|
|
"Key: 11: 3654\n",
|
|
"Key: 12: 2868\n",
|
|
"Key: 13: 2130\n",
|
|
"Key: 14: 2360\n",
|
|
"Key: 15: 2390\n",
|
|
"Key: 16: 2101\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([<AxesSubplot:ylabel='0'>], dtype=object)"
|
|
]
|
|
},
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"Xy_test = list(zip(X_test, y_test))\n",
"\n",
"# Group the test windows by user id; the keys are the ids 1..16 as strings.\n",
"test_dict = {str(user): [] for user in range(1, 17)}\n",
"for window, label in Xy_test:\n",
"    test_dict[str(label)].append(window)\n",
"\n",
"for k, v in test_dict.items():\n",
"    print(f'Key: {k}: {len(v)}')\n",
"\n",
"# Pie chart of the number of test windows per user.\n",
"window_counts = {k: len(v) for k, v in test_dict.items()}\n",
"pd.DataFrame.from_dict(window_counts, orient='index').plot.pie(subplots=True, legend=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"id": "4b6d7b97",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 348 ms, sys: 22.5 ms, total: 371 ms\n",
|
|
"Wall time: 370 ms\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
|
"from sklearn.preprocessing import LabelBinarizer\n",
|
|
"\n",
|
|
"\n",
|
|
"# One-hot encode the user ids; the encoder is fitted on the training labels only.\n",
"lb = LabelBinarizer().fit(y_train)\n",
"yy_train = lb.transform(y_train)\n",
"yy_test = lb.transform(y_test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "90634662",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Per recording: replace the user id by its one-hot row and stack the windows\n",
"# into one array. The dicts are rewritten in place, so re-running this cell\n",
"# would feed already encoded labels to the encoder.\n",
"for e in test + train:\n",
"    e['label'] = lb.transform([e['label']])\n",
"    e['data'] = np.array(e['data'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "02c58b6d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(57800, 5, 336)\n",
|
|
"(57800, 16)\n",
|
|
"(37106, 5, 336)\n",
|
|
"(37106, 16)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Shapes of the model inputs and of the one-hot targets.\n",
"for arr in (X_train, yy_train, X_test, yy_test):\n",
"    print(arr.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ff2da104",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Building Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"id": "a5ef1b0f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import tensorflow as tf\n",
|
|
"from tensorflow.keras.regularizers import l2\n",
|
|
"from tensorflow.keras.models import Sequential\n",
|
|
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
|
|
"from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
|
|
"from tensorflow.keras.optimizers import Adam\n",
|
|
"\n",
|
|
"def build_model(shape, classes):\n",
"    \"\"\"Build and compile the fully connected softmax classifier.\n",
"\n",
"    Args:\n",
"        shape: shape of one input sample; the product of its first two\n",
"            dimensions is the base for the hidden layer sizes.\n",
"        classes: number of output units.\n",
"\n",
"    The globals drop_count, layer_count and dense_steps shape the hidden stack:\n",
"    for i = 2 .. layer_count + 1 a dropout of drop_count * i is followed by a\n",
"    dense layer of shape[0] * shape[1] / dense_steps**i neurons. The stack ends\n",
"    early as soon as a layer would have no more neurons than ``classes``.\n",
"    \"\"\"\n",
"    ncount = shape[0] * shape[1]\n",
"\n",
"    layers = [\n",
"        Flatten(input_shape=shape, name='flatten'),\n",
"        Dropout(drop_count, name=f'dropout_{drop_count*100}'),\n",
"        BatchNormalization(name='batchNorm'),\n",
"    ]\n",
"\n",
"    for i in range(2, layer_count + 2):\n",
"        neurons = int(ncount / pow(dense_steps, i))\n",
"        if neurons <= classes:\n",
"            break\n",
"        layers.append(Dropout(drop_count * i, name=f'HiddenDropout_{drop_count*i*100:.0f}'))\n",
"        layers.append(Dense(neurons, activation='relu',\n",
"                            kernel_regularizer=l2(0.001), name=f'Hidden_{i}'))\n",
"\n",
"    layers.append(Dense(classes, activation='softmax', name='Output'))\n",
"\n",
"    model = Sequential(layers)\n",
"    model.compile(\n",
"        optimizer=Adam(),\n",
"        loss='categorical_crossentropy',\n",
"        metrics=['acc'],\n",
"    )\n",
"\n",
"    return model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"id": "ccbb5d69",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"checkpoint_file = './goat.weights'\n",
|
|
"\n",
|
|
"def train_model(X_train, y_train, X_test, y_test):\n",
"    \"\"\"Build the classifier, train it and return it with the best checkpoint restored.\n",
"\n",
"    Args:\n",
"        X_train, y_train: training windows and their one-hot labels.\n",
"        X_test, y_test: data passed to ``fit`` as validation data.\n",
"\n",
"    Returns:\n",
"        (model, history): the model after loading the checkpoint at\n",
"        checkpoint_path, and the History object returned by ``fit``.\n",
"    \"\"\"\n",
"    # One output unit per one-hot label column instead of a hard-coded 16.\n",
"    n_classes = np.asarray(y_train).shape[1]\n",
"    model = build_model(X_train[0].shape, n_classes)\n",
"\n",
"    model.summary()\n",
"\n",
"    # Keep only the best model seen so far.\n",
"    # NOTE(review): 'best' is judged by the training loss, not the validation\n",
"    # loss, so the restored checkpoint is the best-fitted one - confirm intended.\n",
"    model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss',\n",
"                                       save_best_only=True)\n",
"\n",
"    # Halve the learning rate after 5 epochs without training-loss improvement.\n",
"    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=0.0001)\n",
"\n",
"    callbacks = [model_checkpoint, reduce_lr]\n",
"\n",
"    history = model.fit(X_train,\n",
"                        y_train,\n",
"                        epochs=epoch,\n",
"                        batch_size=32,\n",
"                        verbose=2,\n",
"                        validation_data=(X_test, y_test),\n",
"                        callbacks=callbacks\n",
"                        )\n",
"\n",
"    model.load_weights(checkpoint_path)\n",
"    return model, history"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"id": "2032334e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Model: \"sequential\"\n",
|
|
"_________________________________________________________________\n",
|
|
"Layer (type) Output Shape Param # \n",
|
|
"=================================================================\n",
|
|
"flatten (Flatten) (None, 1680) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"dropout_10.0 (Dropout) (None, 1680) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"batchNorm (BatchNormalizatio (None, 1680) 6720 \n",
|
|
"_________________________________________________________________\n",
|
|
"HiddenDropout_20 (Dropout) (None, 1680) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"Hidden_2 (Dense) (None, 186) 312666 \n",
|
|
"_________________________________________________________________\n",
|
|
"HiddenDropout_30 (Dropout) (None, 186) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"Hidden_3 (Dense) (None, 62) 11594 \n",
|
|
"_________________________________________________________________\n",
|
|
"Output (Dense) (None, 16) 1008 \n",
|
|
"=================================================================\n",
|
|
"Total params: 331,988\n",
|
|
"Trainable params: 328,628\n",
|
|
"Non-trainable params: 3,360\n",
|
|
"_________________________________________________________________\n",
|
|
"Epoch 1/50\n",
|
|
"1807/1807 - 7s - loss: 0.9538 - acc: 0.8336 - val_loss: 3.6780 - val_acc: 0.3465\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 2/50\n",
|
|
"1807/1807 - 6s - loss: 0.5806 - acc: 0.9273 - val_loss: 3.6350 - val_acc: 0.3550\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 3/50\n",
|
|
"1807/1807 - 6s - loss: 0.5180 - acc: 0.9391 - val_loss: 3.9988 - val_acc: 0.3526\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 4/50\n",
|
|
"1807/1807 - 6s - loss: 0.4804 - acc: 0.9477 - val_loss: 3.9750 - val_acc: 0.3467\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 5/50\n",
|
|
"1807/1807 - 6s - loss: 0.4737 - acc: 0.9476 - val_loss: 4.1912 - val_acc: 0.3614\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 6/50\n",
|
|
"1807/1807 - 6s - loss: 0.4542 - acc: 0.9515 - val_loss: 3.9345 - val_acc: 0.3631\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 7/50\n",
|
|
"1807/1807 - 6s - loss: 0.4476 - acc: 0.9516 - val_loss: 3.8092 - val_acc: 0.3848\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 8/50\n",
|
|
"1807/1807 - 7s - loss: 0.4408 - acc: 0.9528 - val_loss: 3.8813 - val_acc: 0.3970\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 9/50\n",
|
|
"1807/1807 - 7s - loss: 0.4283 - acc: 0.9537 - val_loss: 4.0705 - val_acc: 0.3612\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 10/50\n",
|
|
"1807/1807 - 6s - loss: 0.4218 - acc: 0.9543 - val_loss: 4.2684 - val_acc: 0.3690\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 11/50\n",
|
|
"1807/1807 - 6s - loss: 0.4198 - acc: 0.9563 - val_loss: 4.1950 - val_acc: 0.3870\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 12/50\n",
|
|
"1807/1807 - 6s - loss: 0.4149 - acc: 0.9553 - val_loss: 4.5124 - val_acc: 0.3501\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 13/50\n",
|
|
"1807/1807 - 6s - loss: 0.4078 - acc: 0.9571 - val_loss: 4.2129 - val_acc: 0.3756\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 14/50\n",
|
|
"1807/1807 - 6s - loss: 0.4025 - acc: 0.9578 - val_loss: 4.4079 - val_acc: 0.3655\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 15/50\n",
|
|
"1807/1807 - 6s - loss: 0.4055 - acc: 0.9575 - val_loss: 4.1757 - val_acc: 0.3840\n",
|
|
"Epoch 16/50\n",
|
|
"1807/1807 - 6s - loss: 0.4043 - acc: 0.9566 - val_loss: 4.3999 - val_acc: 0.3444\n",
|
|
"Epoch 17/50\n",
|
|
"1807/1807 - 6s - loss: 0.4010 - acc: 0.9566 - val_loss: 4.3559 - val_acc: 0.3768\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 18/50\n",
|
|
"1807/1807 - 6s - loss: 0.3933 - acc: 0.9594 - val_loss: 4.2061 - val_acc: 0.3845\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 19/50\n",
|
|
"1807/1807 - 6s - loss: 0.3968 - acc: 0.9579 - val_loss: 4.3063 - val_acc: 0.3800\n",
|
|
"Epoch 20/50\n",
|
|
"1807/1807 - 6s - loss: 0.3919 - acc: 0.9581 - val_loss: 4.3823 - val_acc: 0.3755\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 21/50\n",
|
|
"1807/1807 - 6s - loss: 0.3901 - acc: 0.9586 - val_loss: 4.3927 - val_acc: 0.3830\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 22/50\n",
|
|
"1807/1807 - 6s - loss: 0.3915 - acc: 0.9584 - val_loss: 4.2102 - val_acc: 0.3625\n",
|
|
"Epoch 23/50\n",
|
|
"1807/1807 - 6s - loss: 0.3854 - acc: 0.9585 - val_loss: 4.0813 - val_acc: 0.3962\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 24/50\n",
|
|
"1807/1807 - 6s - loss: 0.3871 - acc: 0.9587 - val_loss: 4.3345 - val_acc: 0.3638\n",
|
|
"Epoch 25/50\n",
|
|
"1807/1807 - 6s - loss: 0.3815 - acc: 0.9594 - val_loss: 4.2709 - val_acc: 0.3893\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 26/50\n",
|
|
"1807/1807 - 6s - loss: 0.3767 - acc: 0.9606 - val_loss: 4.4886 - val_acc: 0.3754\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 27/50\n",
|
|
"1807/1807 - 6s - loss: 0.3803 - acc: 0.9604 - val_loss: 4.2969 - val_acc: 0.3838\n",
|
|
"Epoch 28/50\n",
|
|
"1807/1807 - 6s - loss: 0.3804 - acc: 0.9595 - val_loss: 4.5919 - val_acc: 0.3615\n",
|
|
"Epoch 29/50\n",
|
|
"1807/1807 - 6s - loss: 0.3783 - acc: 0.9598 - val_loss: 4.5624 - val_acc: 0.3457\n",
|
|
"Epoch 30/50\n",
|
|
"1807/1807 - 6s - loss: 0.3762 - acc: 0.9595 - val_loss: 4.2288 - val_acc: 0.3515\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 31/50\n",
|
|
"1807/1807 - 6s - loss: 0.3771 - acc: 0.9609 - val_loss: 4.1085 - val_acc: 0.3797\n",
|
|
"Epoch 32/50\n",
|
|
"1807/1807 - 6s - loss: 0.3698 - acc: 0.9619 - val_loss: 4.1579 - val_acc: 0.3847\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 33/50\n",
|
|
"1807/1807 - 6s - loss: 0.3763 - acc: 0.9601 - val_loss: 4.4391 - val_acc: 0.3858\n",
|
|
"Epoch 34/50\n",
|
|
"1807/1807 - 6s - loss: 0.3725 - acc: 0.9607 - val_loss: 4.1958 - val_acc: 0.3683\n",
|
|
"Epoch 35/50\n",
|
|
"1807/1807 - 6s - loss: 0.3709 - acc: 0.9604 - val_loss: 4.1139 - val_acc: 0.3646\n",
|
|
"Epoch 36/50\n",
|
|
"1807/1807 - 6s - loss: 0.3673 - acc: 0.9618 - val_loss: 4.2390 - val_acc: 0.3969\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 37/50\n",
|
|
"1807/1807 - 6s - loss: 0.3686 - acc: 0.9616 - val_loss: 4.3510 - val_acc: 0.3647\n",
|
|
"Epoch 38/50\n",
|
|
"1807/1807 - 6s - loss: 0.3633 - acc: 0.9623 - val_loss: 4.4228 - val_acc: 0.3852\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 39/50\n",
|
|
"1807/1807 - 6s - loss: 0.3719 - acc: 0.9609 - val_loss: 4.2608 - val_acc: 0.3747\n",
|
|
"Epoch 40/50\n",
|
|
"1807/1807 - 6s - loss: 0.3668 - acc: 0.9608 - val_loss: 4.3543 - val_acc: 0.3529\n",
|
|
"Epoch 41/50\n",
|
|
"1807/1807 - 6s - loss: 0.3612 - acc: 0.9620 - val_loss: 4.0086 - val_acc: 0.3695\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 42/50\n",
|
|
"1807/1807 - 6s - loss: 0.3680 - acc: 0.9613 - val_loss: 4.2026 - val_acc: 0.3929\n",
|
|
"Epoch 43/50\n",
|
|
"1807/1807 - 6s - loss: 0.3576 - acc: 0.9633 - val_loss: 4.3106 - val_acc: 0.3515\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 44/50\n",
|
|
"1807/1807 - 6s - loss: 0.3651 - acc: 0.9605 - val_loss: 4.2105 - val_acc: 0.3962\n",
|
|
"Epoch 45/50\n",
|
|
"1807/1807 - 6s - loss: 0.3610 - acc: 0.9618 - val_loss: 4.4887 - val_acc: 0.3635\n",
|
|
"Epoch 46/50\n",
|
|
"1807/1807 - 6s - loss: 0.3681 - acc: 0.9606 - val_loss: 4.1679 - val_acc: 0.3470\n",
|
|
"Epoch 47/50\n",
|
|
"1807/1807 - 6s - loss: 0.3635 - acc: 0.9610 - val_loss: 4.1243 - val_acc: 0.4065\n",
|
|
"Epoch 48/50\n",
|
|
"1807/1807 - 6s - loss: 0.3568 - acc: 0.9624 - val_loss: 4.1263 - val_acc: 0.3691\n",
|
|
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
|
|
"Epoch 49/50\n",
|
|
"1807/1807 - 6s - loss: 0.3585 - acc: 0.9622 - val_loss: 3.9652 - val_acc: 0.4077\n",
|
|
"Epoch 50/50\n",
|
|
"1807/1807 - 6s - loss: 0.3595 - acc: 0.9624 - val_loss: 4.0055 - val_acc: 0.3855\n",
|
|
"CPU times: user 8min 47s, sys: 46.3 s, total: 9min 34s\n",
|
|
"Wall time: 5min 33s\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
"# When a `model` is already present in the namespace, restore its weights from\n",
"# the checkpoint; otherwise reset the Keras session and train a new model.\n",
"if 'model' in locals():\n",
"    print(\"Loaded weights...\")\n",
"    model.load_weights(checkpoint_path)\n",
"else:\n",
"    tf.keras.backend.clear_session()\n",
"    model, history = train_model(\n",
"        np.array(X_train),\n",
"        np.array(yy_train),\n",
"        np.array(X_test),\n",
"        np.array(yy_test),\n",
"    )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3125ff7f",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Eval"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"id": "03c2ed28",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def predict(model, entry):\n",
"    \"\"\"Return the majority-vote label for one entry.\n",
"\n",
"    `model.predict(entry)` is reduced with argmax over the last axis and the\n",
"    most frequent resulting class index wins. On a tie, the index that shows\n",
"    up first in the predictions is kept.\n",
"\n",
"    Args:\n",
"        model: object whose `predict(entry)` returns per-class scores.\n",
"        entry: input handed unchanged to `model.predict`.\n",
"\n",
"    Returns:\n",
"        The winning class index plus one (presumably because labels are\n",
"        1-based -- TODO confirm against the label encoder).\n",
"\n",
"    Raises:\n",
"        ValueError: if there are no predictions to vote on.\n",
"    \"\"\"\n",
"    from collections import Counter\n",
"\n",
"    predictions = np.argmax(model.predict(entry), axis=-1)\n",
"    if predictions.size == 0:\n",
"        raise ValueError('no predictions to vote on for this entry')\n",
"\n",
"    # Counter preserves first-seen order and max() keeps the first maximum,\n",
"    # so ties resolve exactly as with a hand-rolled counting dict.\n",
"    votes = Counter(predictions)\n",
"    prediction = max(votes, key=votes.get)\n",
"    return prediction + 1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"id": "3a3df1c6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 2.86 s, sys: 310 ms, total: 3.17 s\n",
|
|
"Wall time: 2.36 s\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(43, 43)"
|
|
]
|
|
},
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
"\n",
"# For every test entry: the label decoded through `lb`, and the label\n",
"# returned by `predict` for that entry's data.\n",
"ltest = [lb.inverse_transform(sample['label'])[0] for sample in test]\n",
"ptest = [predict(model, sample['data']) for sample in test]\n",
"\n",
"len(ltest), len(ptest)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "ac38f7f4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 3.85 s, sys: 350 ms, total: 4.2 s\n",
|
|
"Wall time: 2.93 s\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(47, 47)"
|
|
]
|
|
},
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
"\n",
"# For every training entry: the label decoded through `lb`, and the label\n",
"# returned by `predict` for that entry's data.\n",
"ltrain = [lb.inverse_transform(sample['label'])[0] for sample in train]\n",
"ptrain = [predict(model, sample['data']) for sample in train]\n",
"\n",
"\n",
"len(ltrain), len(ptrain)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "626fa67e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},\n",
|
|
" {1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})"
|
|
]
|
|
},
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Distinct labels found in the train and test label lists.\n",
"{*ltrain}, {*ltest}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "3b643815",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 720x504 with 2 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 1 1.00 0.67 0.80 3\n",
|
|
" 2 0.00 0.00 0.00 1\n",
|
|
" 3 0.29 0.67 0.40 3\n",
|
|
" 4 0.00 0.00 0.00 0\n",
|
|
" 5 0.00 0.00 0.00 3\n",
|
|
" 6 1.00 0.33 0.50 3\n",
|
|
" 7 0.75 1.00 0.86 3\n",
|
|
" 8 0.00 0.00 0.00 3\n",
|
|
" 9 1.00 0.33 0.50 3\n",
|
|
" 10 0.00 0.00 0.00 3\n",
|
|
" 11 0.20 0.33 0.25 3\n",
|
|
" 12 0.60 1.00 0.75 3\n",
|
|
" 13 0.38 1.00 0.55 3\n",
|
|
" 14 0.00 0.00 0.00 3\n",
|
|
" 15 0.00 0.00 0.00 3\n",
|
|
" 16 0.75 1.00 0.86 3\n",
|
|
"\n",
|
|
" accuracy 0.44 43\n",
|
|
" macro avg 0.37 0.40 0.34 43\n",
|
|
"weighted avg 0.42 0.44 0.38 43\n",
|
|
"\n",
|
|
"CPU times: user 676 ms, sys: 176 ms, total: 852 ms\n",
|
|
"Wall time: 622 ms\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
"\n",
"import seaborn as sn\n",
"from sklearn.metrics import classification_report, confusion_matrix\n",
"\n",
"# Fix the class order once. A bare set has no guaranteed iteration order, yet\n",
"# the same ordering must drive the matrix rows/columns and the tick labels.\n",
"set_digits = set(ltrain)\n",
"class_labels = sorted(set_digits)\n",
"\n",
"# normalize='true' normalizes each row, i.e. over the true labels.\n",
"train_cm = confusion_matrix(ltrain, ptrain, labels=class_labels, normalize='true')\n",
"test_cm = confusion_matrix(ltest, ptest, labels=class_labels, normalize='true')\n",
"\n",
"# Only the test-split matrix is plotted.\n",
"df_cm = pd.DataFrame(test_cm, index=class_labels, columns=class_labels)\n",
"plt.figure(figsize=(10, 7))\n",
"sn_plot = sn.heatmap(df_cm, annot=True, cmap=\"Greys\")\n",
"plt.title(\"Confusion matrix (test split, normalized over true labels)\")\n",
"plt.ylabel(\"True Label\")\n",
"plt.xlabel(\"Predicted Label\")\n",
"plt.show()\n",
"\n",
"print(classification_report(ltest, ptest, zero_division=0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "645ca873",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def plot_keras_history(history, name='', acc='acc'):\n",
"    \"\"\"Plot accuracy and loss curves from a Keras training history.\n",
"\n",
"    Two figures are drawn: training vs. validation accuracy (y-axis fixed to\n",
"    the range 0..1) and training vs. validation loss.\n",
"\n",
"    Args:\n",
"        history: object whose `.history` mapping holds the per-epoch series\n",
"            under the keys `acc`, 'val_' + `acc`, 'loss' and 'val_loss'.\n",
"        name: text appended to both figure titles.\n",
"        acc: key of the accuracy series inside `history.history`.\n",
"    \"\"\"\n",
"    import matplotlib.pyplot as plt\n",
"\n",
"    training_acc = history.history[acc]\n",
"    validation_acc = history.history['val_' + acc]\n",
"    loss = history.history['loss']\n",
"    val_loss = history.history['val_loss']\n",
"\n",
"    epochs = range(len(training_acc))\n",
"\n",
"    # Each pair of curves gets its own explicitly created figure, so nothing\n",
"    # is drawn onto a figure left open by earlier code.\n",
"    acc_fig, acc_ax = plt.subplots()\n",
"    acc_ax.set_ylim(0, 1)\n",
"    acc_ax.plot(epochs, training_acc, 'tab:blue', label='Training acc')\n",
"    acc_ax.plot(epochs, validation_acc, 'tab:orange', label='Validation acc')\n",
"    acc_ax.set_title('Training and validation accuracy ' + name)\n",
"    acc_ax.set_xlabel('Epoch')\n",
"    acc_ax.set_ylabel('Accuracy')\n",
"    acc_ax.legend()\n",
"\n",
"    loss_fig, loss_ax = plt.subplots()\n",
"    loss_ax.plot(epochs, loss, 'tab:green', label='Training loss')\n",
"    loss_ax.plot(epochs, val_loss, 'tab:red', label='Validation loss')\n",
"    loss_ax.set_title('Training and validation loss ' + name)\n",
"    loss_ax.set_xlabel('Epoch')\n",
"    loss_ax.set_ylabel('Loss')\n",
"    loss_ax.legend()\n",
"\n",
"    plt.show()\n",
"    # Release both figures, not just the one that happens to be current.\n",
"    plt.close(acc_fig)\n",
"    plt.close(loss_fig)\n",
"\n",
"\n",
"plot_keras_history(history)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"id": "39e60eaa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Scenario: SYN\n",
|
|
"Window Size: 5\n",
|
|
"Strides: 1\n",
|
|
"Epochs: 50\n",
|
|
"HiddenL Count: 3\n",
|
|
"Neuron Factor: 3\n",
|
|
"Drop Factor: 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Recap of the run configuration, one setting per output line.\n",
"print('\\n'.join([\n",
"    f'Scenario: {cenario}',\n",
"    f'Window Size: {win_sz}',\n",
"    f'Strides: {stride_sz}',\n",
"    f'Epochs: {epoch}',\n",
"    f'HiddenL Count: {layer_count}',\n",
"    f'Neuron Factor: {dense_steps}',\n",
"    f'Drop Factor: {drop_count}',\n",
"]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7676ea0c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
},
|
|
"toc-showtags": false
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|