1462 lines
158 KiB
Plaintext
1462 lines
158 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9c890798",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Change Scenario here.\n",
|
|
"\n",
|
|
"| | GameType | HeightNorm | ArmNorm |\n",
|
|
"|:---:|:--------:|:----------:|:-------:|\n",
|
|
"| SYY | Sorting | ✅ | ✅ |\n",
|
|
"| SYN | Sorting | ✅ | ❌ |\n",
|
|
"| SNY | Sorting | ❌ | ✅ |\n",
|
|
"| SNN | Sorting | ❌ | ❌ |\n",
|
|
"| JYY | Jenga | ✅ | ✅ |\n",
|
|
"| JYN | Jenga | ✅ | ❌ |\n",
|
|
"| JNY | Jenga | ❌ | ✅ |\n",
|
|
"| JNN | Jenga | ❌ | ❌ |\n",
|
|
"\n",
|
|
"Weights for the corresponding scenario are loaded automatically."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "1c9e114c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Scenario key: first letter = game type (S = Sorting, J = Jenga),
# second = height normalisation (Y/N), third = arm normalisation (Y/N).
# Possibilities: 'SYY', 'SYN', 'SNY', 'SNN',
#                'JYY', 'JYN', 'JNY', 'JNN'
# NOTE: the variable is spelled `cenario` (sic); later cells reference it by
# this exact name (checkpoint path, data selection), so do not rename it here alone.
cenario = 'SYN'
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a3c8b624",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Constants"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "5f120a31",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
import os

# Both variables are set before TensorFlow is imported further down the notebook.
# NOTE(review): presumably they must be in place before TensorFlow initialises the GPU — confirm.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required
# Select the GPU to use: '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check "gpustat" in a terminal.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "3be386b5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
import pandas as pd

# Glob pattern matching the raw per-recording CSV files.
glob_path = '/opt/iui-datarelease3-sose2021/*.csv'

# Cache file: load_data() reads it when present and creates it otherwise.
pickle_file = '../data.pickle'

# Render floats with two decimals when pandas displays frames.
pd.set_option('display.float_format', lambda x: '%.2f' % x)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "375756bc",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Config"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "fe73e572",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# When True the model is retrained even if a checkpoint directory already exists.
create_new = False
# One checkpoint directory per scenario key.
checkpoint_path = f"training_{cenario}/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Sliding-window length (frames per sample) and step between window starts,
# passed to timeseries_dataset_from_array in slicing().
win_sz = 5
stride_sz = 1

# Number of training epochs passed to model.fit.
epoch = 50

# Divisor controlling how fast the hidden layers shrink: build_model() sizes
# hidden layer i (i = 2, 3, ...) as int(n_inputs / dense_steps**i), where
# n_inputs is the flattened window size.
dense_steps = 3
# Upper bound on the number of dropout+dense pairs; build_model() stops early
# once a layer would have no more neurons than there are classes.
layer_count = 3
# Base dropout rate: the input dropout uses drop_count, hidden layer i uses drop_count*i.
drop_count = 0.1
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0173497c",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Helper Functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "ef82a419",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
from matplotlib import pyplot as plt

def pplot(dd):
    """Plot every row of ``dd`` in its own subplot on a three-column grid.

    Parameters
    ----------
    dd : array-like
        Object exposing ``shape[0]`` and integer indexing; ``dd[i]`` is drawn
        as a line in subplot ``i`` (row-major order).
    """
    count = dd.shape[0]
    fiy = 3
    # Ceiling division gives exactly the rows needed (no trailing empty row);
    # at least one row so plt.subplots always receives a valid grid.
    fix = max(1, -(-count // fiy))
    # squeeze=False keeps `axs` two-dimensional even for a single row, so the
    # [row][col] indexing below also works when there are fewer than 3 plots.
    fig, axs = plt.subplots(fix, fiy, figsize=(3*fiy, 9*fix), squeeze=False)

    for i in range(count):
        axs[i // fiy][i % fiy].plot(dd[i])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "556c7dde",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Loading Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "51195751",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
from glob import glob
from tqdm import tqdm

def dl_from_blob(filename, user_filter=None):
    """Read every CSV matching a glob pattern into a list of record dicts.

    Metadata is parsed positionally from the underscore-separated file stem:
    field 0 is the user id after a one-character prefix, followed by fields
    prefixed ``Scenario``, ``HeightNormalization``, ``ArmNormalization``,
    ``Repetition`` and ``Session``.

    Parameters
    ----------
    filename : str
        Glob pattern of the CSV files to load. (Previously this argument was
        ignored in favour of the global ``glob_path``; the only visible
        caller passes ``glob_path``, so results are unchanged for it.)
    user_filter : int, optional
        When given, only recordings of this user id are loaded.

    Returns
    -------
    list of dict
        One dict per file with keys ``filename``, ``user``, ``scenario``,
        ``heightnorm``, ``armnorm``, ``rep``, ``session`` and ``data``
        (the CSV as a DataFrame).
    """
    dic_data = []

    for path in tqdm(glob(filename)):
        # Text before the first dot of the last path component.
        stem = path.split('/')[-1].split('.')[0]
        splitname = stem.split('_')
        user = int(splitname[0][1:])
        # `is not None` so that a filter value of 0 is honoured as well.
        if user_filter is not None and user != user_filter:
            continue
        dic_data.append(
            {
                'filename': path,
                'user': user,
                'scenario': splitname[1][len('Scenario'):],
                'heightnorm': splitname[2][len('HeightNormalization'):] == 'True',
                'armnorm': splitname[3][len('ArmNormalization'):] == 'True',
                'rep': int(splitname[4][len('Repetition'):]),
                'session': int(splitname[5][len('Session'):]),
                'data': pd.read_csv(path)
            }
        )
    return dic_data
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "457bc16f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
import pickle

def save_pickle(f, structure):
    """Serialise ``structure`` with pickle into the file at path ``f``.

    Parameters
    ----------
    f : str
        Destination path; the file is created or overwritten.
    structure : object
        Any picklable object.
    """
    # The context manager closes the handle even when pickle.dump raises.
    with open(f, 'wb') as _p:
        pickle.dump(structure, _p)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "9482bc78",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def load_pickles(f) -> list:
    """Load and return the pickled object stored in the file at path ``f``.

    Parameters
    ----------
    f : str
        Path of the pickle file to read. (The previous version ignored this
        argument and always opened the global ``pickle_file``.)

    Returns
    -------
    list
        The unpickled object (a list of record dicts when written by
        ``save_pickle`` from ``load_data``).
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(f, 'rb') as _p:
        return pickle.load(_p)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "230fb3b8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading data...\n",
|
|
"../data.pickle found...\n",
|
|
"768\n",
|
|
"CPU times: user 572 ms, sys: 2.57 s, total: 3.14 s\n",
|
|
"Wall time: 3.14 s\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
def load_data() -> list:
    """Return all recordings, preferring the pickle cache over the raw CSVs.

    If ``pickle_file`` exists it is loaded directly; otherwise the CSVs
    matching ``glob_path`` are parsed and the result is written to
    ``pickle_file`` before being returned.
    """
    if not os.path.isfile(pickle_file):
        print(f'Didn\'t find {pickle_file}...')
        all_data = dl_from_blob(glob_path)
        print(f'Creating {pickle_file}...')
        save_pickle(pickle_file, all_data)
        return all_data

    print(f'{pickle_file} found...')
    return load_pickles(pickle_file)

print("Loading data...")
dic_data = load_data()
print(len(dic_data))
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "effa570d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 393 µs, sys: 0 ns, total: 393 µs\n",
|
|
"Wall time: 397 µs\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
# Categorized Data
# Keys are <game><height norm><arm norm>: S/J for Sorting/Jenga, Y/N per flag.
# The comprehension yields SYY, SYN, SNY, SNN, JYY, JYN, JNY, JNN in that order.
cdata = {game + height + arm: list()
         for game in 'SJ' for height in 'YN' for arm in 'YN'}

for d in dic_data:
    if d['scenario'] == 'Sorting':
        game_key = 'S'
    elif d['scenario'] == 'Jenga':
        game_key = 'J'
    else:
        # Recordings of any other scenario are not categorised.
        continue
    height_key = 'Y' if d['heightnorm'] else 'N'
    arm_key = 'Y' if d['armnorm'] else 'N'
    cdata[game_key + height_key + arm_key].append(d)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2ad62c63",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Preprocessing"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "55619c6e",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
def drop(entry, data=True) -> pd.DataFrame:
    """Return a copy of a recording without its metadata/bookkeeping columns.

    Parameters
    ----------
    entry : dict or pd.DataFrame
        A record dict holding the frame under ``'data'`` (``data=True``) or
        the frame itself (``data=False``).
    data : bool, default True
        Selects how ``entry`` is interpreted, see above.

    Returns
    -------
    pd.DataFrame
        New frame without the listed columns; the input is not modified.

    Raises
    ------
    KeyError
        If any of the listed columns is missing from the frame.
    """
    droptable = ['participantID', 'FrameID', 'Scenario', 'HeightNormalization', 'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0']
    frame = entry['data'] if data else entry

    # DataFrame.drop already returns a new frame, so the former pickle
    # round-trip deep copy of the input was redundant work.
    return frame.drop(droptable, axis=1)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "d7be5822",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
import numpy as np
left_Hand_ident='left'
right_Hand_ident='right'

def rem_low_acc(entry, data=True) -> pd.DataFrame:
    """Blank out hand columns on frames where hand tracking was not 'High'.

    The two ``*HandTrackingAccuracy`` columns are converted to 1.0 ('High')
    or 0.0 (anything else). For rows with 0.0, every other column whose
    lower-cased name contains the matching hand identifier is set to NaN.

    Parameters
    ----------
    entry : dict or pd.DataFrame
        A record dict holding the frame under ``'data'`` (``data=True``) or
        the frame itself (``data=False``). The accuracy columns must still
        hold the raw strings; applying this to an already converted frame
        would mark every row as low accuracy.
    data : bool, default True
        Selects how ``entry`` is interpreted, see above.

    Returns
    -------
    pd.DataFrame
        A modified deep copy; the input is left untouched.
    """
    # DataFrame.copy() is a deep copy by default and replaces the former
    # pickle round-trip.
    centry = (entry['data'] if data else entry).copy()

    hands = (
        ('LeftHandTrackingAccuracy', left_Hand_ident),
        ('RightHandTrackingAccuracy', right_Hand_ident),
    )
    for accuracy_col, ident in hands:
        centry[accuracy_col] = (centry[accuracy_col] == 'High') * 1.0
        hand_cols = [c for c in centry if ident in c.lower() and c != accuracy_col]
        centry.loc[centry[accuracy_col] == 0.0, hand_cols] = np.nan

    return centry
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "da77d0a9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
from tensorflow.keras.preprocessing.sequence import pad_sequences

def pad(entry, data=True) -> pd.DataFrame:
    """Front-pad a recording with NaN rows and zero its first row.

    The target length is the next multiple of ``stride_sz`` strictly above
    the current row count. After padding, row 0 is overwritten with zeros.

    Parameters
    ----------
    entry : dict or pd.DataFrame
        A record dict holding the frame under ``'data'`` (``data=True``) or
        the frame itself (``data=False``).
    data : bool, default True
        Selects how ``entry`` is interpreted, see above.

    Returns
    -------
    pd.DataFrame
        New float64 frame with the original column labels.
    """
    source = entry['data'] if data else entry
    centry = pickle.loads(pickle.dumps(source))

    columns = centry.columns
    target_len = (int(centry.shape[0]/stride_sz)+1)*stride_sz
    # pad_sequences works per sequence, hence the transpose: one sequence per column.
    padded = pad_sequences(
        centry.T.to_numpy(),
        maxlen=target_len,
        dtype='float64',
        padding='pre',
        truncating='post',
        value=np.nan,
    )
    pdentry = pd.DataFrame(padded.T, columns=columns)
    pdentry.loc[0] = [0 for _ in columns]
    return pdentry
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "ac13ea7d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def interpol(entry, data=True) -> pd.DataFrame:
    """Fill NaN gaps of a recording by interpolation in both directions.

    Parameters
    ----------
    entry : dict or pd.DataFrame
        A record dict holding the frame under ``'data'`` (``data=True``) or
        the frame itself (``data=False``).
    data : bool, default True
        Selects how ``entry`` is interpreted, see above.

    Returns
    -------
    pd.DataFrame
        New interpolated frame; the input is not modified.
    """
    frame = entry['data'] if data else entry

    # interpolate() returns a new frame (inplace defaults to False), so the
    # former pickle round-trip deep copy of the input was redundant.
    return frame.interpolate(limit_direction='both')
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "2f6b0535",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
from tensorflow.keras.preprocessing import timeseries_dataset_from_array

def slicing(entry, label, data=True):
    """Cut a recording into sliding windows that all carry the same label.

    Parameters
    ----------
    entry : dict or pd.DataFrame
        A record dict holding the frame under ``'data'`` (``data=True``) or
        the frame itself (``data=False``).
    label : object
        Target attached to every row, and therefore to every window.
    data : bool, default True
        Selects how ``entry`` is interpreted, see above.

    Returns
    -------
    The dataset produced by ``timeseries_dataset_from_array`` using windows
    of ``win_sz`` rows, a stride of ``stride_sz`` and batches of 8.
    """
    source = entry['data'] if data else entry
    centry = pickle.loads(pickle.dumps(source))

    # One target per row of the frame.
    targets = [label] * centry.shape[0]
    return timeseries_dataset_from_array(
        data=centry,
        targets=targets,
        sequence_length=win_sz,
        sequence_stride=stride_sz,
        batch_size=8,
        seed=177013
    )
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "be9a3bee",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 96/96 [00:05<00:00, 16.33it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
# Stack all recordings of the selected scenario into one frame.
# pd.concat on a list builds the result in one pass; the former
# DataFrame.append inside the loop re-copied the accumulated frame on every
# iteration and no longer exists in pandas >= 2.0.
acc_data = pd.concat([e['data'] for e in tqdm(cdata[cenario])], ignore_index=True)

# Same cleaning as the per-recording pipeline: drop metadata columns and
# blank low-accuracy hand frames, so the statistics ignore those values.
ddacc_data = rem_low_acc(drop(acc_data, False),False)

# Column groups selected by name: euler columns and position columns.
eula = ddacc_data[[c for c in ddacc_data if 'euler' in c.lower()]]
posi = ddacc_data[[c for c in ddacc_data if 'pos' in c.lower()]]

# Per-column statistics; norm() uses the min/max pairs for scaling.
eulamin = eula.min()
eulamax = eula.max()
eulamean = eula.mean()
eulastd = eula.std()
posimin = posi.min()
posimax = posi.max()
posimean = posi.mean()
posistd = posi.std()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "bf571416",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def minmaxscaler(entry, minimum, maximum):
    """Linearly rescale ``entry`` so that ``minimum`` maps to 0 and ``maximum`` to 1.

    Works on anything supporting ``-`` and ``/`` (scalars, arrays, frames).
    No guard is applied for ``maximum == minimum``; the division is performed
    as-is.
    """
    span = maximum - minimum
    offset = entry - minimum
    return offset / span
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "dc70c74b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
euler_ident = 'euler'
pos_ident = 'pos'

def norm(entry, data=True) -> pd.DataFrame:
    """Min-max scale the euler and position columns of a recording.

    Euler columns are scaled with the global ``eulamin``/``eulamax`` and
    position columns with ``posimin``/``posimax``; columns are picked by
    substring match on the lower-cased name.

    Parameters
    ----------
    entry : dict or pd.DataFrame
        A record dict holding the frame under ``'data'`` (``data=True``) or
        the frame itself (``data=False``).
    data : bool, default True
        Selects how ``entry`` is interpreted, see above.

    Returns
    -------
    pd.DataFrame
        A scaled deep copy; the input is left untouched.
    """
    # DataFrame.copy() is a deep copy by default and replaces the former
    # pickle round-trip.
    centry = (entry['data'] if data else entry).copy()

    euler_cols = [c for c in centry if euler_ident in c.lower()]
    pos_cols = [c for c in centry if pos_ident in c.lower()]

    centry[euler_cols] = minmaxscaler(centry[euler_cols], eulamin, eulamax)
    centry[pos_cols] = minmaxscaler(centry[pos_cols], posimin, posimax)
    return centry
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "45877405",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def drop_acc(entry, data=True) -> pd.DataFrame:
    """Return a copy of a recording without the two tracking-accuracy columns.

    Parameters
    ----------
    entry : dict or pd.DataFrame
        A record dict holding the frame under ``'data'`` (``data=True``) or
        the frame itself (``data=False``).
    data : bool, default True
        Selects how ``entry`` is interpreted, see above.

    Returns
    -------
    pd.DataFrame
        New frame without the accuracy columns; the input is not modified.

    Raises
    ------
    KeyError
        If either accuracy column is missing from the frame.
    """
    droptable = ['LeftHandTrackingAccuracy', 'RightHandTrackingAccuracy']
    frame = entry['data'] if data else entry

    # DataFrame.drop already returns a new frame, so the former pickle
    # round-trip deep copy of the input was redundant work.
    return frame.drop(droptable, axis=1)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "d7a30d7b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 96/96 [00:14<00:00, 6.67it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 13 s, sys: 1.93 s, total: 14.9 s\n",
|
|
"Wall time: 14.4 s\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
classes = 16 # dynamic

def preproc(data):
    """Run ``preproc_entry`` on every record of ``data`` and return the results as a list."""
    return [preproc_entry(e) for e in tqdm(data)]

def preproc_entry(entry, data = True):
    """Turn one raw record into a record whose ``'data'`` is a windowed dataset.

    Steps, in order: ``drop`` (metadata columns), ``rem_low_acc`` (blank
    low-accuracy hand frames), ``norm`` (min-max scaling), ``drop_acc``
    (accuracy columns) and ``slicing`` (sliding windows labelled with the
    record's ``'user'``). Padding and interpolation are deliberately not
    part of the pipeline, so windows may still contain NaN.

    Parameters
    ----------
    entry : dict
        Record with at least the keys ``'data'`` (DataFrame) and ``'user'``.
    data : bool, default True
        Forwarded to every step. The steps receive the record dict, so only
        ``True`` is meaningful here.

    Returns
    -------
    dict
        A new record; ``entry`` itself is not modified.
    """
    # A shallow copy is enough: every step builds a fresh frame, so rebinding
    # 'data' never touches the caller's record. This replaces five pickle
    # round-trips of the whole record (frame included) per recording.
    # NOTE(review): assumes the remaining values are immutable scalars, as
    # produced by dl_from_blob — confirm for records loaded from the cache.
    staged = dict(entry)

    for step in (drop, rem_low_acc, norm, drop_acc):
        staged['data'] = step(staged, data)

    staged['data'] = slicing(staged, staged['user'], data)
    return staged

pdata = preproc(cdata[cenario])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "c88f53a4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7fb1087bc370>]"
|
|
]
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
# Visual check, stage 1: first recording of the scenario after dropping the
# metadata columns. Plots one left-hand position channel together with the
# left-hand accuracy flag (1.0 where the value is 'High', else 0.0).
a = drop(cdata[cenario][0]['data'], False)
a['left_OVRHandPrefab_pos_X'].plot()
plt.plot((a['LeftHandTrackingAccuracy'] == 'High')*1.0)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "b9518087",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7fb1086d6820>]"
|
|
]
|
|
},
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
# Visual check, stage 2: same channel after rem_low_acc(); the accuracy
# column is numeric (1.0 / 0.0) at this point.
b = rem_low_acc(a, False)
b['left_OVRHandPrefab_pos_X'].plot()
plt.plot(b['LeftHandTrackingAccuracy'])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "09687aab",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7fb108669be0>]"
|
|
]
|
|
},
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
# Visual check, stage 3: same channel after min-max scaling with norm().
c = norm(b, False)
c['left_OVRHandPrefab_pos_X'].plot()
plt.plot(c['LeftHandTrackingAccuracy'])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "9ae9b71e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[<matplotlib.lines.Line2D at 0x7fb1085d5700>]"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
# Visual check, stage 4: same channel after interpol(). Shown for inspection
# only; preproc_entry() does not call interpol().
d = interpol(c, False)
d['left_OVRHandPrefab_pos_X'].plot()
plt.plot(d['LeftHandTrackingAccuracy'])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "29e9063e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 234 µs, sys: 0 ns, total: 234 µs\n",
|
|
"Wall time: 252 µs\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(48, 48)"
|
|
]
|
|
},
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
# Split the preprocessed recordings by session: session 1 -> train, session 2 -> test.
# Each element is the windowed dataset of one recording.
# NOTE: both names are rebound to plain lists by the window-collection cell
# further down, so these arrays only serve the length check below.
train = np.array([x['data'] for x in pdata if x['session'] == 1])
test = np.array([x['data'] for x in pdata if x['session'] == 2])

# Number of recordings per split.
len(train), len(test)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "a52352aa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 96/96 [00:36<00:00, 2.62it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(57800, 5, 336) (57800,) (37106, 5, 336) (37106,)\n",
|
|
"CPU times: user 1min 48s, sys: 14.9 s, total: 2min 3s\n",
|
|
"Wall time: 37 s\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
# Flat per-window samples and labels.
X_train = list()
y_train = list()

X_test = list()
y_test = list()

# Per-recording groups: one {'label': user, 'data': [windows]} dict each.
train = list()
test = list()

# Session 1 feeds the training containers, session 2 the test containers.
# One shared code path replaces the two copy-pasted branches.
_containers_by_session = {
    1: (X_train, y_train, train),
    2: (X_test, y_test, test),
}

for x in tqdm(pdata):
    containers = _containers_by_session.get(x['session'])
    if containers is None:
        continue
    X_out, y_out, groups = containers

    # Keep only windows without any NaN (e.g. blanked low-accuracy frames).
    valid = [
        (window, label)
        for window, label in x['data'].unbatch().as_numpy_iterator()
        if not np.isnan(window).any()
    ]
    # Recordings without a single valid window get no group entry.
    if not valid:
        continue

    X_out.extend(window for window, _ in valid)
    y_out.extend(label for _, label in valid)
    groups.append(
        {
            'label': x['user'],
            'data': [window for window, _ in valid]
        })

X_train = np.array(X_train)
y_train = np.array(y_train)
X_test = np.array(X_test)
y_test = np.array(y_test)

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "8c85c181",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Key: 1: 1347\n",
|
|
"Key: 2: 1583\n",
|
|
"Key: 3: 8568\n",
|
|
"Key: 4: 3034\n",
|
|
"Key: 5: 1960\n",
|
|
"Key: 6: 3311\n",
|
|
"Key: 7: 3971\n",
|
|
"Key: 8: 1407\n",
|
|
"Key: 9: 1135\n",
|
|
"Key: 10: 7466\n",
|
|
"Key: 11: 6494\n",
|
|
"Key: 12: 1813\n",
|
|
"Key: 13: 3596\n",
|
|
"Key: 14: 3260\n",
|
|
"Key: 15: 2825\n",
|
|
"Key: 16: 6030\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([<AxesSubplot:ylabel='0'>], dtype=object)"
|
|
]
|
|
},
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
Xy_train = list(zip(X_train, y_train))
Xy_test = list(zip(X_test, y_test))

# One bucket per user id 1..16, keyed by the id as a string. Users without
# samples keep an empty bucket so they still show up with a count of 0.
train_dict = {str(user): [] for user in range(1, 17)}

# Plain loops instead of list comprehensions used only for side effects.
for window, label in Xy_train:
    train_dict[str(label)].append(window)

for k, v in train_dict.items():
    print(f'Key: {k}: {len(v)}')

# Share of training windows per user.
pd.DataFrame.from_dict({k: len(v) for k, v in train_dict.items()}, orient='index').plot.pie(subplots=True, legend=False)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"id": "92991de2",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Key: 1: 790\n",
|
|
"Key: 2: 59\n",
|
|
"Key: 3: 4330\n",
|
|
"Key: 4: 0\n",
|
|
"Key: 5: 545\n",
|
|
"Key: 6: 348\n",
|
|
"Key: 7: 5245\n",
|
|
"Key: 8: 3558\n",
|
|
"Key: 9: 2565\n",
|
|
"Key: 10: 4163\n",
|
|
"Key: 11: 3654\n",
|
|
"Key: 12: 2868\n",
|
|
"Key: 13: 2130\n",
|
|
"Key: 14: 2360\n",
|
|
"Key: 15: 2390\n",
|
|
"Key: 16: 2101\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([<AxesSubplot:ylabel='0'>], dtype=object)"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
Xy_test = list(zip(X_test, y_test))

# One bucket per user id 1..16, keyed by the id as a string. Users without
# samples keep an empty bucket so they still show up with a count of 0.
test_dict = {str(user): [] for user in range(1, 17)}

# Plain loops instead of list comprehensions used only for side effects.
for window, label in Xy_test:
    test_dict[str(label)].append(window)

for k, v in test_dict.items():
    print(f'Key: {k}: {len(v)}')

# Share of test windows per user.
pd.DataFrame.from_dict({k: len(v) for k, v in test_dict.items()}, orient='index').plot.pie(subplots=True, legend=False)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "419d603a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 355 ms, sys: 13 ms, total: 368 ms\n",
|
|
"Wall time: 367 ms\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
from sklearn.preprocessing import LabelBinarizer


# One-hot encode the user labels. The binarizer is fitted on the training
# labels only and then applied unchanged to the test labels.
lb = LabelBinarizer()
yy_train = lb.fit_transform(y_train)
yy_test = lb.transform(y_test)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "da224750",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Convert the per-recording groups in place: the label becomes its one-hot
# row (shape (1, n_classes)) and the window list becomes a numpy array.
# NOTE: this mutates `test` and `train`; the cell is meant to run once per
# window-collection run.
for group in (test, train):
    for e in group:
        e['label'] = lb.transform([e['label']])
        e['data'] = np.array(e['data'])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"id": "073c2c51",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(57800, 5, 336)\n",
|
|
"(57800, 16)\n",
|
|
"(37106, 5, 336)\n",
|
|
"(37106, 16)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
# Sanity check: sample and one-hot label shapes for both splits.
for arr in (X_train, yy_train, X_test, yy_test):
    print(arr.shape)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cee9b1c3",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Building Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"id": "75c9ba6d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
import tensorflow as tf
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam

def build_model(shape, classes):
    """Build and compile the dense classifier for flattened windows.

    Layout: Flatten -> Dropout(drop_count) -> BatchNormalization, followed by
    up to ``layer_count`` pairs of Dropout(drop_count*i) + Dense(relu, L2
    0.001) for i = 2, 3, ..., where the dense width is
    ``int(shape[0]*shape[1] / dense_steps**i)``. Stacking stops as soon as a
    width would not exceed ``classes``. A softmax layer of ``classes`` units
    closes the model.

    Parameters
    ----------
    shape : tuple
        Shape of one input window; its first two entries are multiplied to
        get the flattened size.
    classes : int
        Number of output classes.

    Returns
    -------
    The compiled model (Adam, categorical cross-entropy, accuracy metric).
    The model summary is printed as a side effect.
    """
    ncount = shape[0]*shape[1]

    # Decide the hidden widths first, then assemble the layers.
    hidden_plan = []
    for i in range(2,layer_count+2):
        neurons = int(ncount/pow(dense_steps,i))
        if neurons <= classes:
            break
        hidden_plan.append((i, neurons))

    model = Sequential()
    model.add(Flatten(input_shape=shape, name='flatten'))
    model.add(Dropout(drop_count, name=f'dropout_{drop_count*100}'))
    model.add(BatchNormalization(name='batchNorm'))

    for i, neurons in hidden_plan:
        model.add(Dropout(drop_count*i, name=f'HiddenDropout_{drop_count*i*100:.0f}'))
        model.add(
            Dense(
                neurons,
                activation='relu',
                kernel_regularizer=l2(0.001),
                name=f'Hidden_{i}',
            )
        )

    model.add(Dense(classes, activation='softmax', name='Output'))

    model.compile(
        optimizer=Adam(),
        loss="categorical_crossentropy",
        metrics=["acc"],
    )

    model.summary()
    return model
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"id": "8f71c4bf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
checkpoint_file = './goat.weights'

def train_model(X_train, y_train, X_test, y_test):
    """Build, train and return the classifier with its best checkpoint loaded.

    Parameters
    ----------
    X_train, X_test : array-like
        Window samples; ``X_train[0].shape`` defines the model input shape.
    y_train, y_test : array-like, shape (n_samples, n_classes)
        One-hot encoded labels. The number of output classes is taken from
        the second dimension of ``y_train`` instead of being fixed to 16.

    Returns
    -------
    (model, history)
        The model after reloading the weights stored at ``checkpoint_path``
        and the history object returned by ``model.fit``.
    """
    n_classes = np.shape(y_train)[1]
    model = build_model(X_train[0].shape, n_classes)

    # Keeps the checkpoint with the lowest training loss seen so far.
    model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss',
                                       save_best_only=True)

    # Halve the learning rate after 5 epochs without training-loss improvement.
    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=0.0001)

    callbacks = [model_checkpoint, reduce_lr]

    history = model.fit(X_train,
                        y_train,
                        epochs=epoch,
                        batch_size=32,
                        verbose=2,
                        validation_data=(X_test, y_test),
                        callbacks=callbacks
                       )

    # Return the best checkpoint rather than the last-epoch weights.
    model.load_weights(checkpoint_path)
    return model, history
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"id": "77e0fc90",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loaded weights...\n",
|
|
"Model: \"sequential\"\n",
|
|
"_________________________________________________________________\n",
|
|
"Layer (type) Output Shape Param # \n",
|
|
"=================================================================\n",
|
|
"flatten (Flatten) (None, 1680) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"dropout_10.0 (Dropout) (None, 1680) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"batchNorm (BatchNormalizatio (None, 1680) 6720 \n",
|
|
"_________________________________________________________________\n",
|
|
"HiddenDropout_20 (Dropout) (None, 1680) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"Hidden_2 (Dense) (None, 186) 312666 \n",
|
|
"_________________________________________________________________\n",
|
|
"HiddenDropout_30 (Dropout) (None, 186) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"Hidden_3 (Dense) (None, 62) 11594 \n",
|
|
"_________________________________________________________________\n",
|
|
"HiddenDropout_40 (Dropout) (None, 62) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"Hidden_4 (Dense) (None, 20) 1260 \n",
|
|
"_________________________________________________________________\n",
|
|
"Output (Dense) (None, 16) 336 \n",
|
|
"=================================================================\n",
|
|
"Total params: 332,576\n",
|
|
"Trainable params: 329,216\n",
|
|
"Non-trainable params: 3,360\n",
|
|
"_________________________________________________________________\n",
|
|
"CPU times: user 80.5 ms, sys: 3.3 ms, total: 83.8 ms\n",
|
|
"Wall time: 79.5 ms\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
# Train from scratch when retraining is requested or no checkpoint directory
# exists for this scenario; otherwise rebuild the architecture and restore
# the stored weights. `history` is only defined on the training path.
needs_training = create_new or not os.path.isdir(checkpoint_dir)

if needs_training:
    tf.keras.backend.clear_session()
    model, history = train_model(
        np.array(X_train),
        np.array(yy_train),
        np.array(X_test),
        np.array(yy_test),
    )
else:
    print("Loaded weights...")
    model = build_model(X_train[0].shape, 16)
    model.load_weights(checkpoint_path)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f2e6f8ad",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Eval"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"id": "b7ede2b1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def predict(model, entry, classes=None):
    """Classify a whole recording by majority vote over its windows.

    Parameters
    ----------
    model : object
        Anything with ``predict(entry)`` returning per-window class scores
        of shape (n_windows, n_classes).
    entry : array-like
        The windows of one recording, passed to ``model.predict`` unchanged.
    classes : sequence, optional
        Maps an output index to its label (e.g. ``lb.classes_``). When
        omitted, index ``i`` is reported as ``i + 1``, which is only correct
        for contiguous 1-based labels.

    Returns
    -------
    The most frequent predicted label; on a tie, the class that was
    predicted first among the tied ones.
    """
    from collections import Counter

    predictions = np.argmax(model.predict(entry), axis=-1)
    # most_common keeps first-seen order among equal counts, matching the
    # tie-breaking of the former hand-written counting dict.
    winner = Counter(predictions).most_common(1)[0][0]
    if classes is not None:
        return classes[winner]
    return winner+1
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "a71bb247",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 3.36 s, sys: 529 ms, total: 3.89 s\n",
|
|
"Wall time: 2.95 s\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(43, 43)"
|
|
]
|
|
},
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
# True and predicted user per test recording (one majority vote each).
ltest = []
ptest = []
for e in test:
    ltest.append(lb.inverse_transform(e['label'])[0])
    ptest.append(predict(model, e['data']))

len(ltest), len(ptest)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "ab3ecfc9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 3.85 s, sys: 448 ms, total: 4.3 s\n",
|
|
"Wall time: 2.99 s\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(47, 47)"
|
|
]
|
|
},
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
# True and predicted user per training recording (one majority vote each).
ltrain = []
ptrain = []
for e in train:
    ltrain.append(lb.inverse_transform(e['label'])[0])
    ptrain.append(predict(model, e['data']))


len(ltrain), len(ptrain)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "ac226caa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},\n",
|
|
" {1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})"
|
|
]
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Distinct true labels found in the training split and in the test split.\n",
"{*ltrain}, {*ltest}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "3c3bac5d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 720x504 with 2 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 1 1.00 0.67 0.80 3\n",
|
|
" 2 0.00 0.00 0.00 1\n",
|
|
" 3 0.43 1.00 0.60 3\n",
|
|
" 4 0.00 0.00 0.00 0\n",
|
|
" 5 0.00 0.00 0.00 3\n",
|
|
" 6 0.50 0.33 0.40 3\n",
|
|
" 7 1.00 1.00 1.00 3\n",
|
|
" 8 1.00 0.33 0.50 3\n",
|
|
" 9 0.00 0.00 0.00 3\n",
|
|
" 10 0.40 0.67 0.50 3\n",
|
|
" 11 0.00 0.00 0.00 3\n",
|
|
" 12 0.60 1.00 0.75 3\n",
|
|
" 13 0.75 1.00 0.86 3\n",
|
|
" 14 0.00 0.00 0.00 3\n",
|
|
" 15 0.00 0.00 0.00 3\n",
|
|
" 16 0.50 1.00 0.67 3\n",
|
|
"\n",
|
|
" accuracy 0.49 43\n",
|
|
" macro avg 0.39 0.44 0.38 43\n",
|
|
"weighted avg 0.43 0.49 0.42 43\n",
|
|
"\n",
|
|
"CPU times: user 646 ms, sys: 195 ms, total: 840 ms\n",
|
|
"Wall time: 610 ms\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
"\n",
"from sklearn.metrics import classification_report, confusion_matrix\n",
"import seaborn as sn\n",
"\n",
"set_digits = set(ltrain)\n",
"# One sorted list drives both the matrix axes and the DataFrame labels, so\n",
"# their order is explicit instead of depending on set iteration order.\n",
"digit_labels = sorted(set_digits)\n",
"\n",
"# normalize='true' scales each row (true label) so that it sums to 1.\n",
"train_cm = confusion_matrix(ltrain, ptrain, labels=digit_labels, normalize='true')\n",
"test_cm = confusion_matrix(ltest, ptest, labels=digit_labels, normalize='true')\n",
"\n",
"# Only the test-split matrix is plotted below.\n",
"df_cm = pd.DataFrame(test_cm, index=digit_labels, columns=digit_labels)\n",
"plt.figure(figsize=(10, 7))\n",
"sn_plot = sn.heatmap(df_cm, annot=True, cmap=\"Greys\")\n",
"plt.ylabel(\"True Label\")\n",
"plt.xlabel(\"Predicted Label\")\n",
"plt.show()\n",
"\n",
"print(classification_report(ltest, ptest, zero_division=0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"id": "43acba77",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def plot_keras_history(history, name='', acc='acc'):\n",
"    \"\"\"Plot the accuracy and loss curves of a Keras training run.\n",
"\n",
"    Two figures are drawn: training vs. validation accuracy (y-axis fixed to\n",
"    [0, 1]) and training vs. validation loss.\n",
"\n",
"    Args:\n",
"        history: object whose `.history` dict holds the per-epoch series\n",
"            'loss', 'val_loss', `acc` and 'val_' + `acc`.\n",
"        name: text appended to both figure titles.\n",
"        acc: key of the accuracy series. If it is absent, the other common\n",
"            spelling ('acc' <-> 'accuracy') is used when that one is present.\n",
"\n",
"    Raises:\n",
"        KeyError: if a required series is missing from `history.history`.\n",
"    \"\"\"\n",
"    import matplotlib.pyplot as plt\n",
"\n",
"    logs = history.history\n",
"    if acc not in logs:\n",
"        # Keras logs a metric under the name it was compiled with, so a run\n",
"        # compiled with 'accuracy' has no 'acc' series and vice versa.\n",
"        alias = {'acc': 'accuracy', 'accuracy': 'acc'}.get(acc)\n",
"        if alias in logs:\n",
"            acc = alias\n",
"\n",
"    training_acc = logs[acc]\n",
"    validation_acc = logs['val_' + acc]\n",
"    loss = logs['loss']\n",
"    val_loss = logs['val_loss']\n",
"\n",
"    epochs = range(len(training_acc))\n",
"\n",
"    # Explicit figures: the accuracy curve must not land on a stale current figure.\n",
"    acc_fig, acc_ax = plt.subplots()\n",
"    acc_ax.set_ylim(0, 1)\n",
"    acc_ax.plot(epochs, training_acc, 'tab:blue', label='Training acc')\n",
"    acc_ax.plot(epochs, validation_acc, 'tab:orange', label='Validation acc')\n",
"    acc_ax.set_title('Training and validation accuracy ' + name)\n",
"    acc_ax.set_xlabel('Epoch')\n",
"    acc_ax.set_ylabel('Accuracy')\n",
"    acc_ax.legend()\n",
"\n",
"    loss_fig, loss_ax = plt.subplots()\n",
"    loss_ax.plot(epochs, loss, 'tab:green', label='Training loss')\n",
"    loss_ax.plot(epochs, val_loss, 'tab:red', label='Validation loss')\n",
"    loss_ax.set_title('Training and validation loss ' + name)\n",
"    loss_ax.set_xlabel('Epoch')\n",
"    loss_ax.set_ylabel('Loss')\n",
"    loss_ax.legend()\n",
"\n",
"    plt.show()\n",
"    # Close both figures; a bare plt.close() would only close the current one.\n",
"    plt.close(acc_fig)\n",
"    plt.close(loss_fig)\n",
"\n",
"\n",
"# `history` is only assigned when the model was trained in this session;\n",
"# when weights were loaded from the checkpoint there is nothing to plot.\n",
"if 'history' in locals():\n",
"    plot_keras_history(history)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"id": "af999e08",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Scenario: SYN\n",
|
|
"Window Size: 5\n",
|
|
"Strides: 1\n",
|
|
"Epochs: 50\n",
|
|
"HiddenL Count: 3\n",
|
|
"Neuron Factor: 3\n",
|
|
"Drop Factor: 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Summary of the settings used for this run, one per line.\n",
"print(\n",
"    f'Scenario: {cenario}',\n",
"    f'Window Size: {win_sz}',\n",
"    f'Strides: {stride_sz}',\n",
"    f'Epochs: {epoch}',\n",
"    f'HiddenL Count: {layer_count}',\n",
"    f'Neuron Factor: {dense_steps}',\n",
"    f'Drop Factor: {drop_count}',\n",
"    sep='\\n',\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"id": "b16af0c6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# End the session once every cell above has run\n",
"# (presumably to release the GPU selected via CUDA_VISIBLE_DEVICES -- TODO confirm).\n",
"exit()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
},
|
|
"toc-showtags": false
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|