1130 lines
61 KiB
Plaintext
1130 lines
61 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "b6131d61",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Constants"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "6144a350",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import os\n",
|
||
|
"\n",
|
||
|
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
|
||
|
"os.environ['CUDA_VISIBLE_DEVICES'] = '2' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "7aa3948f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"glob_path = '/opt/iui-datarelease3-sose2021/*.csv'\n",
|
||
|
"\n",
|
||
|
"pickle_file = '../data.pickle'"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "89eb31ab",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Config"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "e2be13e5",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Possibilities: 'SYY', 'SYN', 'SNY', 'SNN', \n",
|
||
|
"# 'JYY', 'JYN', 'JNY', 'JNN'\n",
|
||
|
"cenario = 'SYN' \n",
|
||
|
"\n",
|
||
|
"win_sz = 10\n",
|
||
|
"stride_sz = 5\n",
|
||
|
"\n",
|
||
|
"# divisor for neuron count step downs (hard to describe), e.g. dense_step = 3: layer1=900, layer2 = 300, layer3 = 100, layer4 = 33...\n",
|
||
|
"dense_steps = 3\n",
|
||
|
"# amount of dense/dropout layers\n",
|
||
|
"layer_count = 3\n",
|
||
|
"# how much to drop\n",
|
||
|
"drop_count = 0.1"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "0eca097f",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Helper Functions"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "82014801",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from matplotlib import pyplot as plt\n",
|
||
|
"\n",
|
||
|
"def pplot(dd):\n",
|
||
|
" x = dd.shape[0]\n",
|
||
|
" fix = int(x/3)+1\n",
|
||
|
" fiy = 3\n",
|
||
|
" fig, axs = plt.subplots(fix, fiy, figsize=(3*fiy, 9*fix))\n",
|
||
|
" \n",
|
||
|
" for i in range(x):\n",
|
||
|
" axs[int(i/3)][i%3].plot(dd[i])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "e8d9944b",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Loading Data"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "0ae277a0",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from glob import glob\n",
|
||
|
"import pandas as pd\n",
|
||
|
"from tqdm import tqdm\n",
|
||
|
"\n",
|
||
|
"def dl_from_blob(filename, user_filter=None):\n",
|
||
|
" \n",
|
||
|
" dic_data = []\n",
|
||
|
" \n",
|
||
|
" for p in tqdm(glob(glob_path)):\n",
|
||
|
" path = p\n",
|
||
|
" filename = path.split('/')[-1].split('.')[0]\n",
|
||
|
" splitname = filename.split('_')\n",
|
||
|
" user = int(splitname[0][1:])\n",
|
||
|
" if (user_filter):\n",
|
||
|
" if (user != user_filter):\n",
|
||
|
" continue\n",
|
||
|
" scenario = splitname[1][len('Scenario'):]\n",
|
||
|
" heightnorm = splitname[2][len('HeightNormalization'):] == 'True'\n",
|
||
|
" armnorm = splitname[3][len('ArmNormalization'):] == 'True'\n",
|
||
|
" rep = int(splitname[4][len('Repetition'):])\n",
|
||
|
" session = int(splitname[5][len('Session'):])\n",
|
||
|
" data = pd.read_csv(path)\n",
|
||
|
" dic_data.append(\n",
|
||
|
" {\n",
|
||
|
" 'filename': path,\n",
|
||
|
" 'user': user,\n",
|
||
|
" 'scenario': scenario,\n",
|
||
|
" 'heightnorm': heightnorm,\n",
|
||
|
" 'armnorm': armnorm,\n",
|
||
|
" 'rep': rep,\n",
|
||
|
" 'session': session,\n",
|
||
|
" 'data': data \n",
|
||
|
" }\n",
|
||
|
" )\n",
|
||
|
" return dic_data"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "a7b4f994",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import pickle\n",
|
||
|
"\n",
|
||
|
"def save_pickle(f, structure):\n",
|
||
|
" _p = open(f, 'wb')\n",
|
||
|
" pickle.dump(structure, _p)\n",
|
||
|
" _p.close()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "b6b6fa69",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def load_pickles(f) -> list:\n",
|
||
|
" _p = open(pickle_file, 'rb')\n",
|
||
|
" _d = pickle.load(_p)\n",
|
||
|
" _p.close()\n",
|
||
|
" \n",
|
||
|
" return _d"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "38d131f0",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Loading data...\n",
|
||
|
"../data.pickle found...\n",
|
||
|
"768\n",
|
||
|
"CPU times: user 596 ms, sys: 2.15 s, total: 2.75 s\n",
|
||
|
"Wall time: 2.75 s\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"def load_data() -> list:\n",
|
||
|
" if os.path.isfile(pickle_file):\n",
|
||
|
" print(f'{pickle_file} found...')\n",
|
||
|
" return load_pickles(pickle_file)\n",
|
||
|
" print(f'Didn\\'t find {pickle_file}...')\n",
|
||
|
" all_data = dl_from_blob(glob_path)\n",
|
||
|
" print(f'Creating {pickle_file}...')\n",
|
||
|
" save_pickle(pickle_file, all_data)\n",
|
||
|
" return all_data\n",
|
||
|
"\n",
|
||
|
"print(\"Loading data...\")\n",
|
||
|
"dic_data = load_data()\n",
|
||
|
"print(len(dic_data))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "967f81ef",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 398 µs, sys: 0 ns, total: 398 µs\n",
|
||
|
"Wall time: 402 µs\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"# Categorized Data\n",
|
||
|
"cdata = dict() \n",
|
||
|
"# Sorting, HeightNorm, ArmNorm\n",
|
||
|
"cdata['SYY'] = list() \n",
|
||
|
"cdata['SYN'] = list() \n",
|
||
|
"cdata['SNY'] = list() \n",
|
||
|
"cdata['SNN'] = list() \n",
|
||
|
"\n",
|
||
|
"# Jenga, HeightNorm, ArmNorm\n",
|
||
|
"cdata['JYY'] = list() \n",
|
||
|
"cdata['JYN'] = list() \n",
|
||
|
"cdata['JNY'] = list() \n",
|
||
|
"cdata['JNN'] = list() \n",
|
||
|
"for d in dic_data:\n",
|
||
|
" if d['scenario'] == 'Sorting':\n",
|
||
|
" if d['heightnorm']:\n",
|
||
|
" if d['armnorm']:\n",
|
||
|
" cdata['SYY'].append(d)\n",
|
||
|
" else:\n",
|
||
|
" cdata['SYN'].append(d)\n",
|
||
|
" else:\n",
|
||
|
" if d['armnorm']:\n",
|
||
|
" cdata['SNY'].append(d)\n",
|
||
|
" else:\n",
|
||
|
" cdata['SNN'].append(d)\n",
|
||
|
" elif d['scenario'] == 'Jenga':\n",
|
||
|
" if d['heightnorm']:\n",
|
||
|
" if d['armnorm']:\n",
|
||
|
" cdata['JYY'].append(d)\n",
|
||
|
" else:\n",
|
||
|
" cdata['JYN'].append(d)\n",
|
||
|
" else:\n",
|
||
|
" if d['armnorm']:\n",
|
||
|
" cdata['JNY'].append(d)\n",
|
||
|
" else:\n",
|
||
|
" cdata['JNN'].append(d)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "588af385",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Preprocessing"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "375fee1d",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def drop(entry) -> pd.DataFrame:\n",
|
||
|
" droptable = ['participantID', 'FrameID', 'Scenario', 'HeightNormalization', 'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0']\n",
|
||
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
||
|
" return centry['data'].drop(droptable, axis=1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "e2b0b2fc",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def floatize(entry) -> pd.DataFrame:\n",
|
||
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
||
|
" centry['data']['LeftHandTrackingAccuracy'] = (entry['data']['LeftHandTrackingAccuracy'] == 'High') * 1.0\n",
|
||
|
" centry['data']['RightHandTrackingAccuracy'] = (entry['data']['RightHandTrackingAccuracy'] == 'High') * 1.0\n",
|
||
|
" return centry['data']"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "9785e9f0",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import numpy as np\n",
|
||
|
"right_Hand_ident='right_Hand'\n",
|
||
|
"left_Hand_ident='left_hand'\n",
|
||
|
"\n",
|
||
|
"def rem_low_acc(entry) -> pd.DataFrame:\n",
|
||
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
||
|
" right_Hand_cols = [c for c in centry['data'] if right_Hand_ident in c]\n",
|
||
|
" left_Hand_cols = [c for c in centry['data'] if left_Hand_ident in c]\n",
|
||
|
" \n",
|
||
|
" centry['data'].loc[centry['data']['RightHandTrackingAccuracy'] == 0.0, right_Hand_cols] = np.nan\n",
|
||
|
" centry['data'].loc[centry['data']['LeftHandTrackingAccuracy'] == 0.0, left_Hand_cols] = np.nan\n",
|
||
|
" return centry['data']"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "3ec4cc1d",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
||
|
"\n",
|
||
|
"stride = 150\n",
|
||
|
"def pad(entry) -> pd.DataFrame:\n",
|
||
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
||
|
" cols = centry['data'].columns\n",
|
||
|
" pentry = pad_sequences(centry['data'].T.to_numpy(),\n",
|
||
|
" maxlen=(int(centry['data'].shape[0]/stride)+1)*stride,\n",
|
||
|
" dtype='float64',\n",
|
||
|
" padding='pre', \n",
|
||
|
" truncating='post',\n",
|
||
|
" value=np.nan\n",
|
||
|
" ) \n",
|
||
|
" pdentry = pd.DataFrame(pentry.T, columns=cols)\n",
|
||
|
" pdentry.loc[0] = [0 for _ in cols]\n",
|
||
|
" return pdentry"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"id": "0361a89c",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def interpol(entry) -> pd.DataFrame:\n",
|
||
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
||
|
" return centry['data'].interpolate(method='linear', axis=0)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"id": "52b62534",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from tensorflow.keras.preprocessing import timeseries_dataset_from_array\n",
|
||
|
"\n",
|
||
|
"def slicing(entry):\n",
|
||
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
||
|
" return timeseries_dataset_from_array(\n",
|
||
|
" data=centry['data'], \n",
|
||
|
" targets=[centry['user'] for _ in range(centry['data'].shape[0])], \n",
|
||
|
" sequence_length=win_sz,\n",
|
||
|
" sequence_stride=stride_sz, \n",
|
||
|
" batch_size=8, \n",
|
||
|
" seed=177013\n",
|
||
|
" )"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"id": "383a8b9f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.preprocessing import (StandardScaler, \n",
|
||
|
" MinMaxScaler, \n",
|
||
|
" MaxAbsScaler,\n",
|
||
|
" PowerTransformer,\n",
|
||
|
" Binarizer)\n",
|
||
|
"def scaling(entry,scale):\n",
|
||
|
" \n",
|
||
|
" standard = StandardScaler()\n",
|
||
|
" max_Abs = MaxAbsScaler()\n",
|
||
|
" binarizer = Binarizer()\n",
|
||
|
" entry = entry.to_numpy(dtype=np.float64)\n",
|
||
|
" \n",
|
||
|
" if (scale == 0 ):\n",
|
||
|
" entry = min_Max.fit_transform(entry)\n",
|
||
|
" \n",
|
||
|
" if (scale == 1 ):\n",
|
||
|
" for i in entry:\n",
|
||
|
" entry = standard.fit_transform(entry)\n",
|
||
|
" \n",
|
||
|
" if (scale == 2 ):\n",
|
||
|
" for i in entry:\n",
|
||
|
" entry = max_Abs.fit_transform(entry)\n",
|
||
|
" \n",
|
||
|
" if (scale == 3 ):\n",
|
||
|
" for i in entry:\n",
|
||
|
" entry = binarizer.fit_transform(entry)\n",
|
||
|
" return pd.DataFrame(entry)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def minScale(entry):\n",
|
||
|
" entry = entry.to_numpy(dtype=np.float64)\n",
|
||
|
" min_Max = MinMaxScaler()\n",
|
||
|
" entry = min_Max.fit_transform(entry)\n",
|
||
|
" return pd.DataFrame(entry)\n",
|
||
|
" \n",
|
||
|
"\n",
|
||
|
"def stanScale(entry):\n",
|
||
|
" entry = entry.to_numpy(dtype=np.float64)\n",
|
||
|
" standard = StandardScaler()\n",
|
||
|
" entry = standard.fit_transform(entry)\n",
|
||
|
" return pd.DataFrame(entry)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
" \n",
|
||
|
"def maxScale(entry):\n",
|
||
|
" entry = entry.to_numpy(dtype=np.float64)\n",
|
||
|
" binarizer = Binarizer()\n",
|
||
|
" entry = binarizer.fit_transform(entry)\n",
|
||
|
" return pd.DataFrame(entry)\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
"\n",
|
||
|
"def binScale(entry):\n",
|
||
|
" entry = entry.to_numpy(dtype=np.float64)\n",
|
||
|
" min_Max = MinMaxScaler()\n",
|
||
|
" entry = min_Max.fit_transform(entry)\n",
|
||
|
" return pd.DataFrame(entry)\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"id": "29134efc",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 96/96 [00:16<00:00, 5.95it/s]\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.preprocessing import (StandardScaler, \n",
|
||
|
" MinMaxScaler, \n",
|
||
|
" MaxAbsScaler,\n",
|
||
|
" PowerTransformer,\n",
|
||
|
" Binarizer)\n",
|
||
|
"\n",
|
||
|
"#%%time\n",
|
||
|
"\n",
|
||
|
"classes = 16 # dynamic\n",
|
||
|
"\n",
|
||
|
"def preproc(data):\n",
|
||
|
" res_list = list()\n",
|
||
|
" temp_list= list()\n",
|
||
|
" for e in tqdm(data):\n",
|
||
|
" res_list.append(preproc_entry(e))\n",
|
||
|
"# for a in tqdm(temp_list):\n",
|
||
|
"# res_list.append(preproc_entry(a))\n",
|
||
|
"# \n",
|
||
|
" return res_list\n",
|
||
|
" \n",
|
||
|
"def preproc_entry(entry):\n",
|
||
|
" entry2 = pickle.loads(pickle.dumps(entry))\n",
|
||
|
" entry2['data'] = drop(entry2)\n",
|
||
|
" \n",
|
||
|
" entry3 = pickle.loads(pickle.dumps(entry2))\n",
|
||
|
" entry3['data'] = floatize(entry3)\n",
|
||
|
" \n",
|
||
|
" entry4 = pickle.loads(pickle.dumps(entry3))\n",
|
||
|
" entry4['data'] = rem_low_acc(entry4)\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" entry5 = pickle.loads(pickle.dumps(entry4))\n",
|
||
|
" entry5['data'] = pad(entry5)\n",
|
||
|
" \n",
|
||
|
" entry6 = pickle.loads(pickle.dumps(entry5))\n",
|
||
|
" entry6['data'] = interpol(entry6)\n",
|
||
|
" \n",
|
||
|
" entry8 = pickle.loads(pickle.dumps(entry6))\n",
|
||
|
" entry8['data'] = minScale(entry8['data']) # 0 = minmax, 1 = standard, 2 = maxabs, 3 = binarizer\n",
|
||
|
" \n",
|
||
|
" entry7 = pickle.loads(pickle.dumps(entry8))\n",
|
||
|
" entry7['data'] = slicing(entry7)\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" return entry7\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"pdata = preproc(cdata[cenario])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "f8157b26",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Building Model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"id": "2eb9c242",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import tensorflow as tf\n",
|
||
|
"from tensorflow.keras.models import Sequential\n",
|
||
|
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout, LSTM\n",
|
||
|
"import tensorflow.keras as keras\n",
|
||
|
"\n",
|
||
|
"def build_model(shape, classes):\n",
|
||
|
" \n",
|
||
|
" model = Sequential()\n",
|
||
|
" ncount = shape[0]*shape[1]\n",
|
||
|
" \n",
|
||
|
" model.add(Flatten(input_shape=shape))\n",
|
||
|
" \n",
|
||
|
" model.add(Dropout(drop_count))\n",
|
||
|
" model.add(BatchNormalization())\n",
|
||
|
" \n",
|
||
|
" for i in range(1,layer_count):\n",
|
||
|
" neurons = int(ncount/pow(dense_steps,i))\n",
|
||
|
" if neurons <= classes*dense_steps:\n",
|
||
|
" break\n",
|
||
|
" model.add(Dropout(drop_count*i))\n",
|
||
|
" model.add(Dense(neurons, activation='relu'))\n",
|
||
|
" \n",
|
||
|
" model.add(Dense(classes, activation='softmax'))\n",
|
||
|
"\n",
|
||
|
" model.compile(\n",
|
||
|
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
|
||
|
" loss=\"categorical_crossentropy\", \n",
|
||
|
" metrics=[\"acc\"],\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" return model\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"id": "eb3212ae",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"checkpoint_file = './goat.weights'\n",
|
||
|
"\n",
|
||
|
"def train_model(X_train, y_train, X_test, y_test):\n",
|
||
|
" model = build_model(X_train[0].shape, 16)\n",
|
||
|
" \n",
|
||
|
" model.summary()\n",
|
||
|
"\n",
|
||
|
" model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
|
||
|
" filepath = checkpoint_file,\n",
|
||
|
" save_weights_only=True,\n",
|
||
|
" monitor='val_acc',\n",
|
||
|
" mode='max',\n",
|
||
|
" save_best_only=True\n",
|
||
|
" )\n",
|
||
|
" \n",
|
||
|
" history = model.fit(X_train, \n",
|
||
|
" y_train,\n",
|
||
|
" epochs=30,\n",
|
||
|
" batch_size=128,\n",
|
||
|
" shuffle=True,\n",
|
||
|
" verbose=2,\n",
|
||
|
" validation_data=(X_test, y_test),\n",
|
||
|
" callbacks=[model_checkpoint_callback]\n",
|
||
|
" \n",
|
||
|
" )\n",
|
||
|
" return model, history\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"id": "cb296665",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 375 µs, sys: 0 ns, total: 375 µs\n",
|
||
|
"Wall time: 396 µs\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(48, 48)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"train = np.array([x['data'] for x in pdata if x['session'] == 1])\n",
|
||
|
"test = np.array([x['data'] for x in pdata if x['session'] == 2])\n",
|
||
|
"\n",
|
||
|
"len(train), len(test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"id": "bf378c00",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 25.7 s, sys: 6.87 s, total: 32.6 s\n",
|
||
|
"Wall time: 9.2 s\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"X_train = list()\n",
|
||
|
"y_train = list()\n",
|
||
|
"\n",
|
||
|
"X_test = list()\n",
|
||
|
"y_test = list()\n",
|
||
|
"\n",
|
||
|
"train = list()\n",
|
||
|
"test = list()\n",
|
||
|
"\n",
|
||
|
"for x in pdata:\n",
|
||
|
" if x['session'] == 1:\n",
|
||
|
" train.append(\n",
|
||
|
" {\n",
|
||
|
" 'label': x['user'],\n",
|
||
|
" 'data': list()\n",
|
||
|
" })\n",
|
||
|
" for y in x['data'].unbatch().as_numpy_iterator():\n",
|
||
|
" X_train.append(y[0])\n",
|
||
|
" y_train.append(y[1])\n",
|
||
|
" \n",
|
||
|
" train[-1]['data'].append(y[0])\n",
|
||
|
" if x['session'] == 2:\n",
|
||
|
" test.append(\n",
|
||
|
" {\n",
|
||
|
" 'label': x['user'],\n",
|
||
|
" 'data': list()\n",
|
||
|
" })\n",
|
||
|
" for y in x['data'].unbatch().as_numpy_iterator():\n",
|
||
|
" X_test.append(y[0])\n",
|
||
|
" y_test.append(y[1])\n",
|
||
|
" \n",
|
||
|
" test[-1]['data'].append(y[0])\n",
|
||
|
"\n",
|
||
|
"X_train = np.array(X_train)\n",
|
||
|
"y_train = np.array(y_train)\n",
|
||
|
"X_test = np.array(X_test)\n",
|
||
|
"y_test = np.array(y_test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"id": "fdb1b754",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"((30432, 10, 338), (30432,), (20502, 10, 338), (20502,))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 22,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"X_train.shape, y_train.shape, X_test.shape, y_test.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"id": "4b29f6dd",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 241 ms, sys: 116 ms, total: 358 ms\n",
|
||
|
"Wall time: 357 ms\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"from sklearn.preprocessing import LabelBinarizer\n",
|
||
|
"\n",
|
||
|
"lb = LabelBinarizer()\n",
|
||
|
"yy_train = lb.fit_transform(y_train)\n",
|
||
|
"yy_test = lb.fit_transform(y_test)\n",
|
||
|
"\n",
|
||
|
"for e in test:\n",
|
||
|
" e['label'] = lb.transform([e['label']])\n",
|
||
|
" e['data'] = np.array(e['data'])\n",
|
||
|
" \n",
|
||
|
"for e in train:\n",
|
||
|
" e['label'] = lb.transform([e['label']])\n",
|
||
|
" e['data'] = np.array(e['data'])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"id": "e50d9d82",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"(30432, 10, 338)\n",
|
||
|
"(30432, 16)\n",
|
||
|
"(20502, 10, 338)\n",
|
||
|
"(20502, 16)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(X_train.shape)\n",
|
||
|
"print(yy_train.shape)\n",
|
||
|
"print(X_test.shape)\n",
|
||
|
"print(yy_test.shape)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"id": "29cab8e3",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model: \"sequential\"\n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"Layer (type) Output Shape Param # \n",
|
||
|
"=================================================================\n",
|
||
|
"flatten (Flatten) (None, 3380) 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dropout (Dropout) (None, 3380) 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"batch_normalization (BatchNo (None, 3380) 13520 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dropout_1 (Dropout) (None, 3380) 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dense (Dense) (None, 1126) 3807006 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dropout_2 (Dropout) (None, 1126) 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dense_1 (Dense) (None, 375) 422625 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dense_2 (Dense) (None, 16) 6016 \n",
|
||
|
"=================================================================\n",
|
||
|
"Total params: 4,249,167\n",
|
||
|
"Trainable params: 4,242,407\n",
|
||
|
"Non-trainable params: 6,760\n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"Epoch 1/30\n",
|
||
|
"238/238 - 2s - loss: 0.2785 - acc: 0.9139 - val_loss: 7.5618 - val_acc: 0.1295\n",
|
||
|
"Epoch 2/30\n",
|
||
|
"238/238 - 1s - loss: 0.0757 - acc: 0.9780 - val_loss: 9.8776 - val_acc: 0.1766\n",
|
||
|
"Epoch 3/30\n",
|
||
|
"238/238 - 1s - loss: 0.0565 - acc: 0.9830 - val_loss: 12.0728 - val_acc: 0.1515\n",
|
||
|
"Epoch 4/30\n",
|
||
|
"238/238 - 1s - loss: 0.0534 - acc: 0.9857 - val_loss: 14.3411 - val_acc: 0.1648\n",
|
||
|
"Epoch 5/30\n",
|
||
|
"238/238 - 1s - loss: 0.0376 - acc: 0.9897 - val_loss: 15.7724 - val_acc: 0.1598\n",
|
||
|
"Epoch 6/30\n",
|
||
|
"238/238 - 1s - loss: 0.0464 - acc: 0.9881 - val_loss: 17.0488 - val_acc: 0.1536\n",
|
||
|
"Epoch 7/30\n",
|
||
|
"238/238 - 1s - loss: 0.0417 - acc: 0.9889 - val_loss: 19.5126 - val_acc: 0.1550\n",
|
||
|
"Epoch 8/30\n",
|
||
|
"238/238 - 1s - loss: 0.0387 - acc: 0.9901 - val_loss: 19.9876 - val_acc: 0.1788\n",
|
||
|
"Epoch 9/30\n",
|
||
|
"238/238 - 1s - loss: 0.0339 - acc: 0.9908 - val_loss: 19.5807 - val_acc: 0.1572\n",
|
||
|
"Epoch 10/30\n",
|
||
|
"238/238 - 1s - loss: 0.0291 - acc: 0.9930 - val_loss: 20.1623 - val_acc: 0.1779\n",
|
||
|
"Epoch 11/30\n",
|
||
|
"238/238 - 1s - loss: 0.0433 - acc: 0.9914 - val_loss: 23.2585 - val_acc: 0.1521\n",
|
||
|
"Epoch 12/30\n",
|
||
|
"238/238 - 1s - loss: 0.0393 - acc: 0.9913 - val_loss: 25.2286 - val_acc: 0.1594\n",
|
||
|
"Epoch 13/30\n",
|
||
|
"238/238 - 1s - loss: 0.0262 - acc: 0.9946 - val_loss: 24.5537 - val_acc: 0.1794\n",
|
||
|
"Epoch 14/30\n",
|
||
|
"238/238 - 1s - loss: 0.0264 - acc: 0.9947 - val_loss: 26.0528 - val_acc: 0.1804\n",
|
||
|
"Epoch 15/30\n",
|
||
|
"238/238 - 1s - loss: 0.0484 - acc: 0.9910 - val_loss: 25.6410 - val_acc: 0.1610\n",
|
||
|
"Epoch 16/30\n",
|
||
|
"238/238 - 1s - loss: 0.0211 - acc: 0.9948 - val_loss: 28.7820 - val_acc: 0.1696\n",
|
||
|
"Epoch 17/30\n",
|
||
|
"238/238 - 1s - loss: 0.0177 - acc: 0.9957 - val_loss: 25.7378 - val_acc: 0.1955\n",
|
||
|
"Epoch 18/30\n",
|
||
|
"238/238 - 1s - loss: 0.0233 - acc: 0.9956 - val_loss: 27.1410 - val_acc: 0.1924\n",
|
||
|
"Epoch 19/30\n",
|
||
|
"238/238 - 1s - loss: 0.0380 - acc: 0.9934 - val_loss: 30.4740 - val_acc: 0.1707\n",
|
||
|
"Epoch 20/30\n",
|
||
|
"238/238 - 1s - loss: 0.0286 - acc: 0.9938 - val_loss: 27.3403 - val_acc: 0.1771\n",
|
||
|
"Epoch 21/30\n",
|
||
|
"238/238 - 1s - loss: 0.0205 - acc: 0.9954 - val_loss: 30.5033 - val_acc: 0.1706\n",
|
||
|
"Epoch 22/30\n",
|
||
|
"238/238 - 1s - loss: 0.0288 - acc: 0.9949 - val_loss: 31.7822 - val_acc: 0.1682\n",
|
||
|
"Epoch 23/30\n",
|
||
|
"238/238 - 1s - loss: 0.0309 - acc: 0.9950 - val_loss: 28.9407 - val_acc: 0.1791\n",
|
||
|
"Epoch 24/30\n",
|
||
|
"238/238 - 1s - loss: 0.0173 - acc: 0.9961 - val_loss: 32.9953 - val_acc: 0.1817\n",
|
||
|
"Epoch 25/30\n",
|
||
|
"238/238 - 1s - loss: 0.0189 - acc: 0.9965 - val_loss: 33.6316 - val_acc: 0.1817\n",
|
||
|
"Epoch 26/30\n",
|
||
|
"238/238 - 1s - loss: 0.0276 - acc: 0.9953 - val_loss: 33.3303 - val_acc: 0.1635\n",
|
||
|
"Epoch 27/30\n",
|
||
|
"238/238 - 1s - loss: 0.0243 - acc: 0.9961 - val_loss: 35.7127 - val_acc: 0.1422\n",
|
||
|
"Epoch 28/30\n",
|
||
|
"238/238 - 1s - loss: 0.0308 - acc: 0.9949 - val_loss: 33.3842 - val_acc: 0.1697\n",
|
||
|
"Epoch 29/30\n",
|
||
|
"238/238 - 1s - loss: 0.0342 - acc: 0.9952 - val_loss: 39.3381 - val_acc: 0.1698\n",
|
||
|
"Epoch 30/30\n",
|
||
|
"238/238 - 1s - loss: 0.0231 - acc: 0.9959 - val_loss: 38.9394 - val_acc: 0.1641\n",
|
||
|
"CPU times: user 1min 7s, sys: 26.2 s, total: 1min 33s\n",
|
||
|
"Wall time: 29.6 s\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"model, history = train_model(np.array(X_train), np.array(yy_train), np.array(X_test), np.array(yy_test))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "941c82f8",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Eval"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"id": "bdf45d51",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def predict(model, entry):\n",
|
||
|
" p_dict = dict()\n",
|
||
|
" predictions = np.argmax(model.predict(entry['data']), axis=-1)\n",
|
||
|
" for p in predictions:\n",
|
||
|
" if p in p_dict:\n",
|
||
|
" p_dict[p] += 1\n",
|
||
|
" else:\n",
|
||
|
" p_dict[p] = 1\n",
|
||
|
" prediction = max(p_dict, key=p_dict.get)\n",
|
||
|
" return prediction+1"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"id": "5dbc1e1e",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 2.59 s, sys: 335 ms, total: 2.92 s\n",
|
||
|
"Wall time: 2.17 s\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"ltest = [lb.inverse_transform(e['label'])[0] for e in test]\n",
|
||
|
"ptest = [predict(model, e) for e in test]\n",
|
||
|
"\n",
|
||
|
"# for e in test:\n",
|
||
|
"# print(f\"Label: {lb.inverse_transform(e['label'])[0]:2d}\")\n",
|
||
|
"# print(f\"Prediction: {predict(model, e):2d}\\n_______________\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"id": "10056f7d",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 3.2 s, sys: 264 ms, total: 3.47 s\n",
|
||
|
"Wall time: 2.44 s\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"ltrain = [lb.inverse_transform(e['label'])[0] for e in train]\n",
|
||
|
"ptrain = [predict(model, e) for e in train]\n",
|
||
|
"\n",
|
||
|
"# for e in train:\n",
|
||
|
"# print(f\"Label: {lb.inverse_transform(e['label'])[0]:2d}\")\n",
|
||
|
"# print(f\"Prediction: {predict(model, e):2d}\\n_______________\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"id": "48aad447",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAGtCAYAAADnIyVRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAABVCklEQVR4nO3dfXhU533n//dXo/BgQmLJRENRCFEx3g1basuQKNoE6ghcOwEntECaxNDQDWjjbZPu7g+oN43B0IVta+p23aZuSEBp3OJs0pKsQWntIpJAY0JMwRoMacGoBIXAWIDIZkMWwcz394eGiSQQIGnO0Zwzn9d1zcWchzmf+76Zh1v3eTJ3R0RERCRqyoa7ACIiIiKDoU6MiIiIRJI6MSIiIhJJ6sSIiIhIJKkTIyIiIpGkToyIiIhEkjoxIiIiEjgz22Jmr5nZK/0sNzN7ysxeNbOUmd17s22qEyMiIiJh+ALw4A2WvxeYkns0Ak/fbIPqxIiIiEjg3H03cP4Gq3wA+KJ3+w5wu5n93I22WV7IAhbS3r17Q72UcH19fZhxIgOyd+/e0LLC/iyEWTfQZ12KnoUaZlbI39r/SPcIylWb3H3TAF5fDbT3mP5Bbt7p/l5QtJ0YERERiY5ch2UgnZYhUydGRESkRJmFOvBzM6eAiT2m35yb1y8dEyMiIlKizKxgjwJ4Dvj13FlK7wR+5O797koCjcSIiIhICMzsWeA+YJyZ/QBYA7wOwN3/Avg68D7gVeAi8Bs326Y6MSIiIiUqzN1J7v7hmyx34DcHsk11YkREREpUWVm0jyqJdulFRESkZEVqJCaVSrF161ay2SyzZs1i3rx5vZbv2rWLXbt2YWaMGjWKpUuXUl1dTVtbG01NTfn15s+fz/Tp04dcnt27d7N+/Xqy2SyLFi2isbHx5i+KQJbyij8v7p+FuNevmPLiXLdSyBuqIjs7aeDcvSgfL774ovd8/OM//qO/613v8q997Wu+e/dunz17tn/lK1/ptU5LS0v++Wc+8xn/1V/9VX/xxRf9G9/4hu/Zs8dffPFF//rXv+7Tp0/PT199DNSVK1d89uzZfvLkSb906ZI/9NBDfuzYsQFvp9iylFeceXH+LOizPnx5ca5bhPNC/a0dOXKkF+oRdtndPTq7k9ra2kgmk1RVVVFeXk5dXR0HDx7stc7o0aPzzy9dupTvYY4cOZJEIgHA5cuXC9LzTKVSTJo0iYkTJzJixAjmzp1LS0vLkLc73FnKK/68uH8W4l6/YsqLc91KIU+GYXeSmf2GuzfdfM3eOjs7qayszE9XVFTQ1tZ2zXo7d+7k+eefJ5PJsGrVqvz848ePs3nzZs6dO0djY2P+i26w0uk048ePz08nk0lSqdSQtlkMWcor/ry4fxbiXr9iyotz3UohrxCivjtpOEZi1va3wMwazWy/me3/2te+NqiNz5kzhyeeeIJFixaxffv2/PzJkyezYcMG1qxZw44dO+jq6hrU9kWiIu6fhbjXTyQMRXaxuwELpBNjZql+HoeAZH+vc/dN7j7D3WfMnz+/17KKigrOn//ZzS87OzupqKjotwx1dXUcOHDgmvkTJkxg1KhRnDp1wysZ31QymeTMmTP56XQ6TTLZb9Uik6W84s+L+2ch7vUrprw4160U8iS4kZgk8OvAQ9d5nBvMBmtqakin03R0dHDlyhX27dtHbW1tr3V6vnlaW1vzb56Ojg4ymQwAZ8+e5fTp04wbN24wxcibNm0aJ06coL29na6uLpqbm2loaBjSNoshS3nFnxf3z0Lc61dMeXGuWynkFULUR2KCOiZmB/B6d3+57wIz++ZgNphIJFi8eDEbN24km80yc+ZMqqur2bZtGzU1NdTW1tLS0sLhw4dJJBKMGTOG5cuXA3D06FGam5tJJBKUlZWxZMkSxo4dO5T6UV5ezurVq1m2bBmZTIYFCxYwZcqUIW2zGLKUV/x5cf8sxL1+xZQX57qVQl4hRP1id+buw12G69q7d2+oBauvrw8zTmRA9u7dG1pW2J+FMOsG+qxL0Qt1SGPs2LEF+6398Y9/HPpwTKQudiciIiKFE/Wzk9SJERERKVFR78REe2eYiIiIlCyNxIiIiJSoqI/EqBMjIiJSoqLeidHuJBEREYkkjcSIiIiUqKiPxBRtJ0bXcoi2Rx55JNS8p59+OtQ8KRx91kWGT9Qvdhft0ouIiEjJKtqRGBEREQmWdieJiIhIJEW9E6PdSSIiIhJJGokREREpUVEfiVEnRkREpESpEyMiIiKRpE5MEdm9ezfr168nm82yaNEiGhsbY5MX9bpNnTqVD37wg5gZ3/72t3nhhReuWefee+9l3rx5uDunTp1iy5Yt3HXXXSxcuDC/zvjx49m8eTOtra1DKk/U2zOVSrF161ay2SyzZs1i3rx5vZbv2rWLXbt2YWaMGjWKpUuXUl1dTVtbG01NTfn15s+fz/Tp04dUFoh+e5ZyXpzrVgp5pS42nZhMJsO6detoamoimUyycOFCGhoauPPOOyOfF/W6mRkf+tCHeOqpp+js7OTRRx8llUpx5syZ/DpvetObePDBB9m4cSMXL15k7NixABw9epQNGzYAcNttt7Fu3TqOHDlSVPULOy+bzfLMM8+wcuVKKisrWbt2LbW1tVRXV+fXqa+vp6GhAYCDBw/y7LPPsmLFCqqrq3n88cdJJBJcuHCBxx57jHvuuYdEIlE09VOevluUFx5d7K4fZvZvzWy2mb2+z/wHg8hLpVJMmjSJiRMnMmLECObOnUtLS0sQUaHnRb1ub33rW+no6ODs2bNkMhn279/P3Xff3Wudd7/73XzrW9/i4sWLAPz4xz++Zjv33nsvhw8f5vLly4MuC0S/Pdva2kgmk1RVVVFeXk5dXR0HDx7stc7o0aPzzy9dupQfMh45cmS+w3L58uWCDCVHvT1LOS/OdSuFvEIws4I9hkMgnRgz+yTwv4FPAK+Y2Qd6LN4QRGY6nWb8+PH56WQySTqdDiIq9Lyo1+3222+ns7MzP93Z2cntt9/ea52qqiqqqqpYsWIFq1atYurUqddsZ8aMGbz00kuDLsdVUW/Pzs5OKisr89MVFRW92veqnTt3snLlSr785S/z8MMP5+cfP36cT33qU3z605/mox/96JBGYSD67VnKeXGuWynkSXAjMcuB6e4+H7gPeMzMfju3rN/umpk1mtl+M9u/adOmgIomxSiRSFBVVcWTTz7J5s2befjhh3uNJrzhDW9gwoQJQ96VVErmzJnDE088waJFi9i+fXt+/uTJk9mwYQNr1qxhx44ddHV1DWMpRWQ4RX0kJqhjYsrc/f8CuPsJM7sP+Bszm8QNOjHuvgm42nvxgQQmk8lex1ik02mSyeQAi12ceVGv24ULF6ioqMhPV1RUcOHChV7rdHZ2cuLECbLZLOfOneO1116jqqqK73//+wBMnz6dl19+mWw2O+hyXBX19qyoqOD8+fP56c7Ozl7t21ddXR1f/OIXr5k/YcIERo0axalTp6ipqRl0eaLenqWcF+e6lUJeIUT97KSgRmLSZnbP1Ylch2YeMA6YFkTgtGnTOHHiBO3t7XR1ddHc3Jw/sDHqeVGv2/e//32qqqq44447SCQSzJgxg1Qq1Wud1tZW7rrrLgDGjBlDVVUVZ8+ezS9/+9vfzv79+wddhp6i3p41NTWk02k6Ojq4cuUK+/bto7a2ttc6Pb9IW1tb81+kHR0dZDIZAM6ePcvp06cZN27coMsC0W/PUs6Lc91KIU+CG4n5deBKzxnufgX4dTP7bBCB5eXlrF69mmXLlpHJZFiwYAFTpkwJIir0vKjXLZvN8qUvfYlPfOITlJWV8eKLL3L69GnmzZvHyZMnSaVSHDlyhLe97W2sXr2abDbLV7/6VX7yk58AUFlZSUVFBceOHSvK+oWdl0gkWLx4MRs3biSbzTJz5kyqq6vZtm0bNTU11NbW0tLSwuHDh0kkEowZM4bly5cD3Wd7NTc3k0gkKCsrY8mSJfkzwYqlfsrTd4vywhP1kRhzH9BemzAVbcHk5h555JFQ855++ulQ88K2d+/e0LLq6+tDyxKRa4T
|
||
|
"text/plain": [
|
||
|
"<Figure size 720x504 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" precision recall f1-score support\n",
|
||
|
"\n",
|
||
|
" 1 0.33 0.33 0.33 3\n",
|
||
|
" 2 0.00 0.00 0.00 3\n",
|
||
|
" 3 0.00 0.00 0.00 3\n",
|
||
|
" 4 0.00 0.00 0.00 3\n",
|
||
|
" 5 0.14 0.33 0.20 3\n",
|
||
|
" 6 0.00 0.00 0.00 3\n",
|
||
|
" 7 0.00 0.00 0.00 3\n",
|
||
|
" 8 0.00 0.00 0.00 3\n",
|
||
|
" 9 0.75 1.00 0.86 3\n",
|
||
|
" 10 0.00 0.00 0.00 3\n",
|
||
|
" 11 0.00 0.00 0.00 3\n",
|
||
|
" 12 0.38 1.00 0.55 3\n",
|
||
|
" 13 0.50 0.33 0.40 3\n",
|
||
|
" 14 0.00 0.00 0.00 3\n",
|
||
|
" 15 0.17 0.33 0.22 3\n",
|
||
|
" 16 0.00 0.00 0.00 3\n",
|
||
|
"\n",
|
||
|
" accuracy 0.21 48\n",
|
||
|
" macro avg 0.14 0.21 0.16 48\n",
|
||
|
"weighted avg 0.14 0.21 0.16 48\n",
|
||
|
"\n",
|
||
|
"CPU times: user 649 ms, sys: 204 ms, total: 853 ms\n",
|
||
|
"Wall time: 623 ms\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"\n",
|
||
|
"from sklearn.metrics import confusion_matrix\n",
|
||
|
"import seaborn as sn\n",
|
||
|
"\n",
|
||
|
"from sklearn.metrics import classification_report\n",
|
||
|
"\n",
|
||
|
"set_digits = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }\n",
|
||
|
"\n",
|
||
|
"train_cm = confusion_matrix(ltrain, ptrain, normalize='true')\n",
|
||
|
"test_cm = confusion_matrix(ltest, ptest, normalize='true')\n",
|
||
|
"\n",
|
||
|
"df_cm = pd.DataFrame(test_cm, index=set_digits, columns=set_digits)\n",
|
||
|
"plt.figure(figsize = (10,7))\n",
|
||
|
"sn_plot = sn.heatmap(df_cm, annot=True, cmap=\"Greys\")\n",
|
||
|
"plt.ylabel(\"True Label\")\n",
|
||
|
"plt.xlabel(\"Predicted Label\")\n",
|
||
|
"plt.show()\n",
|
||
|
"\n",
|
||
|
"print(classification_report(ltest, ptest, zero_division=0))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"id": "9c334bde",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cenario: SYN\n",
|
||
|
"win_sz: 10\n",
|
||
|
"stride_sz: 5\n",
|
||
|
"dense_steps: 3\n",
|
||
|
"layer_count: 3\n",
|
||
|
"drop_count: 0.1\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(f'cenario: {cenario}')\n",
|
||
|
"print(f'win_sz: {win_sz}')\n",
|
||
|
"print(f'stride_sz: {stride_sz}')\n",
|
||
|
"print(f'dense_steps: {dense_steps}')\n",
|
||
|
"print(f'layer_count: {layer_count}')\n",
|
||
|
"print(f'drop_count: {drop_count}')\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 31,
|
||
|
"id": "15fa9b96",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"exit()"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.10"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|