2021-07-05 15:01:40 +02:00
|
|
|
|
{
|
|
|
|
|
"cells": [
|
2021-07-17 17:40:05 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "2bc4ab88",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"# Constants"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-05 15:01:40 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"execution_count": 1,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "c767cb34",
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"metadata": {},
|
2021-07-17 03:32:34 +02:00
|
|
|
|
"outputs": [],
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"source": [
|
|
|
|
|
"import os\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"\n",
|
|
|
|
|
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
|
2021-07-21 03:01:19 +02:00
|
|
|
|
"os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
|
2021-07-14 10:15:52 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 2,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "f783fc7f",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"glob_path = '/opt/iui-datarelease3-sose2021/*.csv'\n",
|
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"pickle_file = '../data.pickle'\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"checkpoint_path = \"training_1/cp.ckpt\"\n",
|
|
|
|
|
"checkpoint_dir = os.path.dirname(checkpoint_path)"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "bb1c9c9b",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Config"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 3,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "3d812543",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2021-07-19 02:20:02 +02:00
|
|
|
|
"# Possible keys (task S=Sorting/J=Jenga, HeightNorm Y/N, ArmNorm Y/N): 'SYY', 'SYN', 'SNY', 'SNN',\n",
|
|
|
|
|
"# 'JYY', 'JYN', 'JNY', 'JNN'\n",
|
2021-07-21 03:01:19 +02:00
|
|
|
|
"cenario = 'SYN'\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"win_sz = 30\n",
|
|
|
|
|
"stride_sz = 2\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"\n",
|
|
|
|
|
"# divisor applied to the neuron count from one dense layer to the next, e.g. dense_steps = 3: layer1 = 900, layer2 = 300, layer3 = 100, layer4 = 33, ...\n",
|
|
|
|
|
"dense_steps = 3\n",
|
|
|
|
|
"# number of dense/dropout layers\n",
|
2021-07-21 03:01:19 +02:00
|
|
|
|
"layer_count = 5\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"# dropout fraction, i.e. how much to drop (0.2 = 20 %)\n",
|
2021-07-19 02:20:02 +02:00
|
|
|
|
"drop_count = 0.2"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "8cef4021",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Helper Functions"
|
2021-07-14 10:15:52 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 4,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "cde65835",
|
2021-07-17 03:32:34 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"from matplotlib import pyplot as plt\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"def pplot(dd):\n",
"    \"\"\"Plot every series in `dd` on its own subplot, three subplots per row.\n",
"\n",
"    `dd` must expose `shape[0]` (the number of series) and integer indexing `dd[i]`.\n",
"    \"\"\"\n",
"    n_series = dd.shape[0]\n",
"    n_cols = 3\n",
"    # Ceiling division: just enough rows, and at least one so subplots() gets a valid grid.\n",
"    n_rows = max(1, -(-n_series // n_cols))\n",
"    # squeeze=False keeps `axs` two-dimensional even for a single row; without it\n",
"    # axs[row][col] fails when fewer than three series are passed.\n",
"    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 9*n_rows), squeeze=False)\n",
"\n",
"    for i in range(n_series):\n",
"        axs[i // n_cols][i % n_cols].plot(dd[i])"
|
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-17 17:40:05 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "476851ec",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Loading Data"
|
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-17 03:32:34 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 5,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "199e4435",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {
|
|
|
|
|
"tags": []
|
|
|
|
|
},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"from glob import glob\n",
|
|
|
|
|
"import pandas as pd\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"from tqdm import tqdm\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"def dl_from_blob(filename, user_filter=None):\n",
"    \"\"\"Read every CSV matching the glob pattern `filename` into a list of record dicts.\n",
"\n",
"    The base name of each file is split on '_' and parsed as\n",
"    {prefix char}{user}_Scenario{name}_HeightNormalization{True|False}_\n",
"    ArmNormalization{True|False}_Repetition{int}_Session{int}; the parsed\n",
"    values are stored next to the loaded frame.\n",
"\n",
"    Args:\n",
"        filename: glob pattern of the CSV files to load.\n",
"        user_filter: if given, only files belonging to this user id are loaded.\n",
"\n",
"    Returns:\n",
"        list of dicts with keys 'filename', 'user', 'scenario', 'heightnorm',\n",
"        'armnorm', 'rep', 'session' and 'data' (the pandas DataFrame).\n",
"    \"\"\"\n",
"    dic_data = []\n",
"\n",
"    # Use the pattern passed by the caller; the argument used to be ignored in\n",
"    # favour of the global glob_path and was then overwritten inside the loop.\n",
"    for path in tqdm(glob(filename)):\n",
"        stem = os.path.basename(path).split('.')[0]\n",
"        splitname = stem.split('_')\n",
"        user = int(splitname[0][1:])  # skip the one-character prefix before the numeric id\n",
"        # Compare against None so that a user id of 0 can be filtered on as well.\n",
"        if user_filter is not None and user != user_filter:\n",
"            continue\n",
"        scenario = splitname[1][len('Scenario'):]\n",
"        heightnorm = splitname[2][len('HeightNormalization'):] == 'True'\n",
"        armnorm = splitname[3][len('ArmNormalization'):] == 'True'\n",
"        rep = int(splitname[4][len('Repetition'):])\n",
"        session = int(splitname[5][len('Session'):])\n",
"        dic_data.append(\n",
"            {\n",
"                'filename': path,\n",
"                'user': user,\n",
"                'scenario': scenario,\n",
"                'heightnorm': heightnorm,\n",
"                'armnorm': armnorm,\n",
"                'rep': rep,\n",
"                'session': session,\n",
"                'data': pd.read_csv(path),\n",
"            }\n",
"        )\n",
"    return dic_data"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 6,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "9e2817c1",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import pickle\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"def save_pickle(f, structure):\n",
"    \"\"\"Serialize `structure` with pickle into the file at path `f` (overwriting it).\"\"\"\n",
"    # The context manager closes the file even if pickle.dump raises.\n",
"    with open(f, 'wb') as _p:\n",
"        pickle.dump(structure, _p)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 7,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "12c5098e",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"def load_pickles(f) -> list:\n",
"    \"\"\"Return the object unpickled from the file at path `f`.\n",
"\n",
"    NOTE: pickle.load can execute arbitrary code; only load files you created yourself.\n",
"    \"\"\"\n",
"    # Read from the path that was passed in; the argument used to be ignored\n",
"    # and the global pickle_file was opened instead.\n",
"    with open(f, 'rb') as _p:\n",
"        return pickle.load(_p)"
|
2021-07-05 15:01:40 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 8,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "00ee7490",
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"metadata": {},
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Loading data...\n",
|
|
|
|
|
"../data.pickle found...\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"768\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"CPU times: user 548 ms, sys: 2.56 s, total: 3.11 s\n",
|
|
|
|
|
"Wall time: 3.11 s\n"
|
2021-07-14 10:15:52 +02:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"%%time\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"def load_data() -> list:\n",
|
|
|
|
|
" if os.path.isfile(pickle_file):\n",
|
|
|
|
|
" print(f'{pickle_file} found...')\n",
|
|
|
|
|
" return load_pickles(pickle_file)\n",
|
|
|
|
|
" print(f'Didn\\'t find {pickle_file}...')\n",
|
|
|
|
|
" all_data = dl_from_blob(glob_path)\n",
|
|
|
|
|
" print(f'Creating {pickle_file}...')\n",
|
|
|
|
|
" save_pickle(pickle_file, all_data)\n",
|
|
|
|
|
" return all_data\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"print(\"Loading data...\")\n",
|
|
|
|
|
"dic_data = load_data()\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"print(len(dic_data))"
|
2021-07-14 10:15:52 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 9,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "d1db1537",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {
|
|
|
|
|
"tags": []
|
|
|
|
|
},
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"CPU times: user 95 µs, sys: 297 µs, total: 392 µs\n",
|
|
|
|
|
"Wall time: 396 µs\n"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"%%time\n",
|
|
|
|
|
"\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
|
"# Categorized Data\n",
|
|
|
|
|
"cdata = dict() \n",
|
|
|
|
|
"# Sorting, HeightNorm, ArmNorm\n",
|
|
|
|
|
"cdata['SYY'] = list() \n",
|
|
|
|
|
"cdata['SYN'] = list() \n",
|
|
|
|
|
"cdata['SNY'] = list() \n",
|
|
|
|
|
"cdata['SNN'] = list() \n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Jenga, HeightNorm, ArmNorm\n",
|
|
|
|
|
"cdata['JYY'] = list() \n",
|
|
|
|
|
"cdata['JYN'] = list() \n",
|
|
|
|
|
"cdata['JNY'] = list() \n",
|
|
|
|
|
"cdata['JNN'] = list() \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"for d in dic_data:\n",
|
|
|
|
|
" if d['scenario'] == 'Sorting':\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
|
" if d['heightnorm']:\n",
|
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
|
" cdata['SYY'].append(d)\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" cdata['SYN'].append(d)\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
|
" cdata['SNY'].append(d)\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" cdata['SNN'].append(d)\n",
|
|
|
|
|
" elif d['scenario'] == 'Jenga':\n",
|
|
|
|
|
" if d['heightnorm']:\n",
|
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
|
" cdata['JYY'].append(d)\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" cdata['JYN'].append(d)\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" if d['armnorm']:\n",
|
|
|
|
|
" cdata['JNY'].append(d)\n",
|
|
|
|
|
" else:\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
" cdata['JNN'].append(d)"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "46382aad",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Preprocessing"
|
2021-07-14 10:15:52 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 10,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "f7842338",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {
|
|
|
|
|
"tags": []
|
|
|
|
|
},
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"outputs": [],
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"def drop(entry, data=True) -> pd.DataFrame:\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" droptable = ['participantID', 'FrameID', 'Scenario', 'HeightNormalization', 'ArmNormalization', 'Repetition', 'Session', 'Unnamed: 0']\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
" if data:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" return centry.drop(droptable, axis=1)\n",
|
|
|
|
|
" \n"
|
2021-07-14 10:15:52 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 11,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "b73d9485",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {},
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"outputs": [],
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"source": [
|
2021-07-17 03:32:34 +02:00
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
"right_Hand_ident='right_Hand'\n",
|
|
|
|
|
"left_Hand_ident='left_hand'\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"def rem_low_acc(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
|
" if data:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
" centry['LeftHandTrackingAccuracy'] = (centry['LeftHandTrackingAccuracy'] == 'High') * 1.0\n",
|
|
|
|
|
" centry['RightHandTrackingAccuracy'] = (centry['RightHandTrackingAccuracy'] == 'High') * 1.0\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" right_Hand_cols = [c for c in centry if right_Hand_ident in c]\n",
|
|
|
|
|
" left_Hand_cols = [c for c in centry if left_Hand_ident in c]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" centry.loc[centry['RightHandTrackingAccuracy'] == 0.0, right_Hand_cols] = np.nan\n",
|
|
|
|
|
" centry.loc[centry['LeftHandTrackingAccuracy'] == 0.0, left_Hand_cols] = np.nan\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" return centry\n",
|
|
|
|
|
"\n"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"execution_count": 12,
|
|
|
|
|
"id": "1a298d6d",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"def pad(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
|
" if data:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" cols = centry.columns\n",
|
|
|
|
|
" pentry = pad_sequences(centry.T.to_numpy(),\n",
|
|
|
|
|
" maxlen=(int(centry.shape[0]/stride_sz)+1)*stride_sz,\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" dtype='float64',\n",
|
|
|
|
|
" padding='pre', \n",
|
|
|
|
|
" truncating='post',\n",
|
|
|
|
|
" value=np.nan\n",
|
|
|
|
|
" ) \n",
|
|
|
|
|
" pdentry = pd.DataFrame(pentry.T, columns=cols)\n",
|
|
|
|
|
" pdentry.loc[0] = [0 for _ in cols]\n",
|
|
|
|
|
" return pdentry"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"execution_count": 13,
|
|
|
|
|
"id": "3be1bd3f",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"def interpol(entry, data=True) -> pd.DataFrame:\n",
|
|
|
|
|
" if data:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" return centry.interpolate(method='linear', axis=0)"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"execution_count": 14,
|
|
|
|
|
"id": "2a7f4e26",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"from tensorflow.keras.preprocessing import timeseries_dataset_from_array\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"def slicing(entry, label, data=True):\n",
|
|
|
|
|
" if data:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry['data']))\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" centry = pickle.loads(pickle.dumps(entry))\n",
|
|
|
|
|
" \n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" return timeseries_dataset_from_array(\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
" data=centry, \n",
|
|
|
|
|
" targets=[label for _ in range(centry.shape[0])], \n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" sequence_length=win_sz,\n",
|
|
|
|
|
" sequence_stride=stride_sz, \n",
|
|
|
|
|
" batch_size=8, \n",
|
|
|
|
|
" seed=177013\n",
|
|
|
|
|
" )"
|
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-27 16:00:03 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 15,
|
|
|
|
|
"id": "b012b0f7",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"tags": []
|
|
|
|
|
},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# %%time \n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# acc_data = pd.DataFrame()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# for d in tqdm(cdata['SYY']):\n",
|
|
|
|
|
"# acc_data = acc_data.append(d['data'])\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# dacc_data = drop(acc_data, False)\n",
|
|
|
|
|
"# ddacc_data = rem_low_acc(dacc_data, False)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# for c in ddacc_data:\n",
|
|
|
|
|
"# print(f\"{c}: {dacc_data[c].min()}, {dacc_data[c].max()}\")"
|
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-17 17:40:05 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 16,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "a2440d77",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"100%|██████████| 96/96 [00:09<00:00, 10.55it/s]"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"CPU times: user 8.07 s, sys: 1.19 s, total: 9.27 s\n",
|
|
|
|
|
"Wall time: 9.1 s\n"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"\n"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"%%time\n",
|
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"classes = 16 # dynamic\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"def preproc(data):\n",
|
|
|
|
|
" res_list = list()\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" for e in tqdm(data):\n",
|
|
|
|
|
" res_list.append(preproc_entry(e))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" return res_list\n",
|
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"def preproc_entry(entry, data = True):\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" entry2 = pickle.loads(pickle.dumps(entry))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
" entry2['data'] = drop(entry2, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
" entry4 = pickle.loads(pickle.dumps(entry2))\n",
|
|
|
|
|
" entry4['data'] = rem_low_acc(entry4, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" \n",
|
|
|
|
|
" entry5 = pickle.loads(pickle.dumps(entry4))\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
" entry5['data'] = pad(entry5, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"# entry6 = pickle.loads(pickle.dumps(entry5))\n",
|
|
|
|
|
"# entry6['data'] = interpol(entry6, data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
" entry7 = pickle.loads(pickle.dumps(entry5))\n",
|
|
|
|
|
" entry7['data'] = slicing(entry7, entry7['user'], data)\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" \n",
|
|
|
|
|
" return entry7\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"pdata = preproc(cdata[cenario])"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"execution_count": 17,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "11e96fef",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
2021-07-19 01:21:53 +02:00
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"CPU times: user 96 µs, sys: 107 µs, total: 203 µs\n",
|
|
|
|
|
"Wall time: 214 µs\n"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-05 15:01:40 +02:00
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"(48, 48)"
|
2021-07-05 15:01:40 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"execution_count": 17,
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"%%time\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"train = np.array([x['data'] for x in pdata if x['session'] == 1])\n",
|
|
|
|
|
"test = np.array([x['data'] for x in pdata if x['session'] == 2])\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"len(train), len(test)"
|
2021-07-05 15:01:40 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"execution_count": 18,
|
|
|
|
|
"id": "1807a2f7",
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"CPU times: user 1min, sys: 13.2 s, total: 1min 13s\n",
|
|
|
|
|
"Wall time: 21.1 s\n"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
2021-07-05 15:01:40 +02:00
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"%%time\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"X_train = list()\n",
|
|
|
|
|
"y_train = list()\n",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"X_test = list()\n",
|
|
|
|
|
"y_test = list()\n",
|
|
|
|
|
"\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"# train = list()\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"test = list()\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"for x in pdata:\n",
|
|
|
|
|
" if x['session'] == 1:\n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"# train.append(\n",
|
|
|
|
|
"# {\n",
|
|
|
|
|
"# 'label': x['user'],\n",
|
|
|
|
|
"# 'data': list()\n",
|
|
|
|
|
"# })\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" for y in x['data'].unbatch().as_numpy_iterator():\n",
|
|
|
|
|
" X_train.append(y[0])\n",
|
|
|
|
|
" y_train.append(y[1])\n",
|
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"# train[-1]['data'].append(y[0])\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
" if x['session'] == 2:\n",
|
|
|
|
|
" test.append(\n",
|
|
|
|
|
" {\n",
|
|
|
|
|
" 'label': x['user'],\n",
|
|
|
|
|
" 'data': list()\n",
|
|
|
|
|
" })\n",
|
|
|
|
|
" for y in x['data'].unbatch().as_numpy_iterator():\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
" X_test.append(y[0])\n",
|
|
|
|
|
" y_test.append(y[1])\n",
|
|
|
|
|
" \n",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"# test[-1]['data'].append(y[0])\n",
|
2021-07-17 03:32:34 +02:00
|
|
|
|
"\n",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"X_train = np.array(X_train)\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"y_train = np.array(y_train)\n",
|
|
|
|
|
"X_test = np.array(X_test)\n",
|
|
|
|
|
"y_test = np.array(y_test)"
|
2021-07-14 10:15:52 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-27 16:00:03 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 19,
|
|
|
|
|
"id": "44330b34",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"(37902, 30, 338) (73692, 30, 338) (37902,) (73692,)\n",
|
|
|
|
|
"(32745, 30, 338) (48773, 30, 338) (32745,) (48773,)\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"XX_train = list()\n",
|
|
|
|
|
"yy_train = list()\n",
|
|
|
|
|
"XX_test = list()\n",
|
|
|
|
|
"yy_test = list()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"for X,y in zip(X_train, y_train):\n",
|
|
|
|
|
" if not np.isnan(X).any():\n",
|
|
|
|
|
" XX_train.append(X)\n",
|
|
|
|
|
" yy_train.append(y)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"for X,y in zip(X_test, y_test):\n",
|
|
|
|
|
" if not np.isnan(X).any():\n",
|
|
|
|
|
" XX_test.append(X)\n",
|
|
|
|
|
" yy_test.append(y)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
"XX_train = np.array(XX_train)\n",
|
|
|
|
|
"yy_train = np.array(yy_train)\n",
|
|
|
|
|
"XX_test = np.array(XX_test)\n",
|
|
|
|
|
"yy_test = np.array(yy_test)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"print(np.array(XX_train).shape, X_train.shape, np.array(yy_train).shape, np.array(y_train).shape)\n",
|
|
|
|
|
"print(np.array(XX_test).shape, X_test.shape, np.array(yy_test).shape, np.array(y_test).shape)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 20,
|
|
|
|
|
"id": "bd805e81",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/html": [
|
|
|
|
|
"<div>\n",
|
|
|
|
|
"<style scoped>\n",
|
|
|
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
|
|
|
" vertical-align: middle;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe tbody tr th {\n",
|
|
|
|
|
" vertical-align: top;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe thead th {\n",
|
|
|
|
|
" text-align: right;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"</style>\n",
|
|
|
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
|
|
|
" <thead>\n",
|
|
|
|
|
" <tr style=\"text-align: right;\">\n",
|
|
|
|
|
" <th></th>\n",
|
|
|
|
|
" <th>0</th>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </thead>\n",
|
|
|
|
|
" <tbody>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>0</th>\n",
|
|
|
|
|
" <td>11</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>1</th>\n",
|
|
|
|
|
" <td>11</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>2</th>\n",
|
|
|
|
|
" <td>11</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>3</th>\n",
|
|
|
|
|
" <td>11</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>4</th>\n",
|
|
|
|
|
" <td>11</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>...</th>\n",
|
|
|
|
|
" <td>...</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>37897</th>\n",
|
|
|
|
|
" <td>9</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>37898</th>\n",
|
|
|
|
|
" <td>9</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>37899</th>\n",
|
|
|
|
|
" <td>9</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>37900</th>\n",
|
|
|
|
|
" <td>9</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>37901</th>\n",
|
|
|
|
|
" <td>9</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </tbody>\n",
|
|
|
|
|
"</table>\n",
|
|
|
|
|
"<p>37902 rows × 1 columns</p>\n",
|
|
|
|
|
"</div>"
|
|
|
|
|
],
|
|
|
|
|
"text/plain": [
|
|
|
|
|
" 0\n",
|
|
|
|
|
"0 11\n",
|
|
|
|
|
"1 11\n",
|
|
|
|
|
"2 11\n",
|
|
|
|
|
"3 11\n",
|
|
|
|
|
"4 11\n",
|
|
|
|
|
"... ..\n",
|
|
|
|
|
"37897 9\n",
|
|
|
|
|
"37898 9\n",
|
|
|
|
|
"37899 9\n",
|
|
|
|
|
"37900 9\n",
|
|
|
|
|
"37901 9\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"[37902 rows x 1 columns]"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 20,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"a = pd.DataFrame(yy_train)\n",
|
|
|
|
|
"a"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 21,
|
|
|
|
|
"id": "f6416fc1",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"16 0.113055\n",
|
|
|
|
|
"11 0.110680\n",
|
|
|
|
|
"10 0.103662\n",
|
|
|
|
|
"3 0.102739\n",
|
|
|
|
|
"5 0.095404\n",
|
|
|
|
|
"13 0.069469\n",
|
|
|
|
|
"15 0.061738\n",
|
|
|
|
|
"7 0.048625\n",
|
|
|
|
|
"9 0.044826\n",
|
|
|
|
|
"6 0.043955\n",
|
|
|
|
|
"14 0.043375\n",
|
|
|
|
|
"1 0.038890\n",
|
|
|
|
|
"4 0.036489\n",
|
|
|
|
|
"8 0.035433\n",
|
|
|
|
|
"12 0.032030\n",
|
|
|
|
|
"2 0.019630\n",
|
|
|
|
|
"dtype: float64"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 21,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"b = a.value_counts(normalize=True)\n",
|
|
|
|
|
"b"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 22,
|
|
|
|
|
"id": "1885329b",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/html": [
|
|
|
|
|
"<div>\n",
|
|
|
|
|
"<style scoped>\n",
|
|
|
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
|
|
|
" vertical-align: middle;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe tbody tr th {\n",
|
|
|
|
|
" vertical-align: top;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe thead th {\n",
|
|
|
|
|
" text-align: right;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"</style>\n",
|
|
|
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
|
|
|
" <thead>\n",
|
|
|
|
|
" <tr style=\"text-align: right;\">\n",
|
|
|
|
|
" <th></th>\n",
|
|
|
|
|
" <th>0</th>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>0</th>\n",
|
|
|
|
|
" <th></th>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </thead>\n",
|
|
|
|
|
" <tbody>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>16</th>\n",
|
|
|
|
|
" <td>0.113055</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>11</th>\n",
|
|
|
|
|
" <td>0.110680</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>10</th>\n",
|
|
|
|
|
" <td>0.103662</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>3</th>\n",
|
|
|
|
|
" <td>0.102739</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>5</th>\n",
|
|
|
|
|
" <td>0.095404</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>13</th>\n",
|
|
|
|
|
" <td>0.069469</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>15</th>\n",
|
|
|
|
|
" <td>0.061738</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>7</th>\n",
|
|
|
|
|
" <td>0.048625</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>9</th>\n",
|
|
|
|
|
" <td>0.044826</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>6</th>\n",
|
|
|
|
|
" <td>0.043955</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>14</th>\n",
|
|
|
|
|
" <td>0.043375</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>1</th>\n",
|
|
|
|
|
" <td>0.038890</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>4</th>\n",
|
|
|
|
|
" <td>0.036489</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>8</th>\n",
|
|
|
|
|
" <td>0.035433</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>12</th>\n",
|
|
|
|
|
" <td>0.032030</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>2</th>\n",
|
|
|
|
|
" <td>0.019630</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </tbody>\n",
|
|
|
|
|
"</table>\n",
|
|
|
|
|
"</div>"
|
|
|
|
|
],
|
|
|
|
|
"text/plain": [
|
|
|
|
|
" 0\n",
|
|
|
|
|
"0 \n",
|
|
|
|
|
"16 0.113055\n",
|
|
|
|
|
"11 0.110680\n",
|
|
|
|
|
"10 0.103662\n",
|
|
|
|
|
"3 0.102739\n",
|
|
|
|
|
"5 0.095404\n",
|
|
|
|
|
"13 0.069469\n",
|
|
|
|
|
"15 0.061738\n",
|
|
|
|
|
"7 0.048625\n",
|
|
|
|
|
"9 0.044826\n",
|
|
|
|
|
"6 0.043955\n",
|
|
|
|
|
"14 0.043375\n",
|
|
|
|
|
"1 0.038890\n",
|
|
|
|
|
"4 0.036489\n",
|
|
|
|
|
"8 0.035433\n",
|
|
|
|
|
"12 0.032030\n",
|
|
|
|
|
"2 0.019630"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 22,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"c = pd.DataFrame(b)\n",
|
|
|
|
|
"c"
|
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-05 15:01:40 +02:00
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2021-07-21 03:01:19 +02:00
|
|
|
|
"execution_count": 23,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "7dfe2339",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"<AxesSubplot:ylabel='0'>"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
2021-07-21 03:01:19 +02:00
|
|
|
|
"execution_count": 23,
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
2021-07-27 16:00:03 +02:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAADnCAYAAAA+T+sCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAABEvklEQVR4nO2deXxU5dXHf+fOkp0kZA8TiKwJENl3QgDXt2jV9lWrqLi1UnFDbEv1bTtvfa1pXVtLtbvaRa27gjtI2EFkCxBkE8gKCSEJWWfm3vP+cW8wJDOZe2fubMn9fj7zIbnz3OceYM69z5znnN8hZoaBgUH/Qgi1AQYGBsHHcHwDg36I4fgGBv0Qw/ENDPohhuMbGPRDDMc3MOiHGI5vYNAPMRzfwKAfYji+gUE/xHB8A4N+iOH4Bgb9EMPxDQz6IYbjGxj0QwzHNzDohxiOb2DQDzEc38CgH2I4fgggohgiKiEiExF9REQNRLSy25gLiGgrER0moteIyOpmniuI6JfBs9ygr2A4fmi4HcBbzCwCeALAzW7G/BrAM8w8HMAZAHe4GbMKwJVEFBswSw36JIbjh4aFAN4FAGZeDeBs1zeJiADMB/CGcuglAFd3n4Rl3bS1AK4InKkGfRHD8YOMsmQfyszHehmWAqCBmV3K7xUABnkYux1AoX4WGvQHDMcPPqkAGnSc7xSAbB3nM+gHGI4ffNoARHsZcxpAEhGZld9tACo9jI1W5jQwUI3h+EGGmc8AMBGRR+dXvrt/DuC/lUOLoMQEiOgaInq8y/CRAPYGyFyDPorh+KHhEwCzAYCI1gN4HcBFRFRBRJcpY34C4EEiOgz5O/9flePDADR1mWse5Oi+gYFqyGioEXyIaCKApczsbhvP27n/VM6tJaIMAP9m5ot0N9KgT2M4foggotsBvKTs5fs6xxQATmbepZthBv0Cw/ENDPohZu9DDCKNgpcKzACGQo4HZABIg7yN2PlKA5AM+f+/887PXX6WIGcL1kLeLuz650kARwEcL11UKgXhr6MaIooB8BHk5KdVAKYD2MDMV3QZcw+AByD/26Qxc52beQoALGPmW4NgdkgwnvgRTsFLBSMBjAdQoLzyITt9oG/q7QAOv1lRvWmk03kCwB4Au2FvPBHg63qEiJYAMDPzb4noIgCxAO7q5vgTIN/U1gKY7M7xlXGfAbidmUP29wkkhuNHEAUvFRCAMQCKlNccyE/00MDs3HGsHBbA0uVoJeStyDUAVgfzRkBEmwDc2JkVSURzATzU1fG7jD2G3h3/fgBRzPybQNkbSoylfphT8FJBBoDvALgYsqOnhtaib7AyTljkJXNXBgG4SXkB9sSjkG8CawB8CnujW0fzF5Wp0FrYDmA5AMPxDYJDwUsF6QC+C+A6yM4elvkWGaLrFHo6fneGKq87AbhgT/wMwL8BvA17Y7OO5hip0BowHD9MKHipIBnA9fjG2U2htcg7eQ6nQ+MpZgCXK69W2BPfh3wT+BD2Rqef5qhJhdZCn06FNhw/xBS8VDAWwH2QS3Ujqq5+YntHD3EQDcRCvtFdD6Ae9sRXADwLe+NhXyZj5jOKsEk0M7drPZ+IpgK4h5lvUQ716VRow/FDQMFLBQKAqwDcCznlNiKZ2N6eptNUAwEsAfBD2BPfBfAE7I2bfZinMxX6MyUVOg9APBFVALiDmT8movsA/BhAJoA9RPQBM98JYDDOf8L36VRoI6ofRApeKogFsBiyw+eG1ho/YXbtOFbO3SL6erIZwJMA3oG9UVW+gJ+p0E8A+Acz7yGiKAAlAGZ30UToUxiOHwQKXiqIBvBDyFHi9BCbowtWiY98ebzcW2BPDw4DKAbwIuyNXtObdUqFHgFgEDOv9XWOcMdw/ABSlpdvAnD7O9Ppyn/PM10Zanv0JMfp3PxBRfWMIF5yP4DlsDe+H8Rr9lnCcpuoL1CWl38F5Gy2P317K0+PdrCeW1chx4eIvr+MBvAe7IlrYU
8cH+Rr9zkMx9eZsrx8W1le/rsA3of8YYXASPvBh9L20FqmL35G9P2hCMCXsCe+AHti2CQzRRqG4+tEWV6+UJaXvwTykvTb3d+ftZ8nJrRyffAtCwwT2jv0iuj7ggDgLgAHYU+8xdtgg54Yjq8DZXn5YwBsAPB7AAnuxhAw4L53pdKgGhYomF0jHI7BoTYDcoXhS7AnvgV7Yp8ImgYLw/H9oCwv31yWl28HsAOA10DXhcd4WmojVwfcsABjZRy3AqFa6rvjGgB7YU+8JtSGRAqG4/tIWV7+IMhVaL+ASicgIHrZW6JPmWnhRIboqg21DW5IA/AW7Ikvw56YGGpjwp2IdPwuveeGENEOItpFRPuIaLGH8a8qe7O6UJaXfymAnVAEM7UwtAYzbbX8tV62hIJRDqfmlNggcjPkp7/m/5v+REQ6PpTecwCqAcxg5vEApgFYTkTuKqqeh5ym6RdKAO9RAB9CfsJohgDTQ2+KNf7aEkomtXdEhdoGL9gArIY98c5QGxKuRKrjLwTwLjM7mLlDORYFz3+f9QAu7tKgQjNlefkZAD4D8D+9XEcVWWcwfVQ5l/kzRygJcURfLVYAf4Y98TnYE42alG5EnON3F1wgohwi2gOgHMCvmbmq+znMLEFO/RznyzXL8vLzAGyFTgU1BNCDb4steswVdMInoq+WewB8BHviwFAbEk5EnOOjm+ACM5cz84UAhgNYpGjNu8MnYYWyvPyZADYCGKLdVM8kt2DypEPSLj3nDAZW4ESYRfTVcBGAbbAnjgm1IeFCJDq+W8EF5Um/F547x2oWVijLy78K8vI+IE+Le96Xwl5sozvpLtepUNvgI8MAbIY9sSjUhoQDEef4XXvPEZFNkVQGESVDjrJ/pfz+siKu0IkmYYWyvPzFAN4EEKOb8d2I60DBvN3StkDNHwjywjui740EAB/AntjvOw9FnOMrdAou5APYSkS7IddPP8nMndlxFwKoAgBl+d/GzKqi6UpSzvMIgvzV7Z9IySTHICKCieEf0fdGLICVsCde5nVkHyZSHX8FgEXM/CkzX8jM45Q//wQARDQAwCFmrlDG3wjgj2omVpz+F4Ew2h1RLoy4ajP7ojYTEia2d/SFwphoAO/CnthDdru/EJGOz8w7AHxORG6fyMzcxMzXdjnUAOAlb/OW5eU/giA6fSfXrZcGm0T2V2wy8MgRfV2DnCEkCnKmX79M841IxwcAZv6bWpUVZv67Nwmlsrz8BwD8nx62acUsIWfh51LYP/UjNKLfGxYA/4E98TuhNiTYRKzj60lZXv4tAJ4OpQ3f2s55UQ4O6739CI7o94YZwL9hT/S0G9Qn6feOryjl/BUAhdIOgZH+/Y+kL0JpgzfyHM4O76MikijI3/nzQ21IsOjXjq/U0b+CMJEZL9zHE+LbuCHUdnhiYntHoBR1w4HkvVLub3KXrwpdL8Ig0m8dvywvPxnAuwDiQ21LJwQk3vuetDvUdnhiYmTk6PvEanFCyRWOxxYAeCd3+apI37L0Sr90fEX99jV47/sWdMYf5akpTeryDYJK34ron4MZrudcV6+/w/mjIoAIwHQAfwm1XYGmXzo+5A6ol4TaCHcQELP0bfFgqO3oTh+M6IMZjUuc9+95ynVd98DeTbnLVy0LiVFBot85flle/s0AHgy1Hb0xogozB9Xx8VDb0ZW+FtF3sVBxheOx2g+kaRM9DPlV7vJVnt6LePqV45fl5Y+Cygy+UEKA+aG3xMpQ29GVvhTRb+bo/TM7novaxxcM72WYFcC/c5eviqhGpmrpN46vfK9/GQEsutGT7NOYMbySvwq1HZ30lYj+CSlty5SO53NPIVlNoHIUgN8G2qZQ0G8cH8DDAKZ6HRUmEEDL3hKbQm1HJ30hor9eHFtS5HhmahuitDzF78xdvqrPZfb1C8cvy8ufCOBnobZDKynNmDL+iLQn1HZEekSfGeIfXQvW3ex8uIgh+PKZ/3Pu8lU23Q0LIX3e8cvy8qMgL/
Ejcql633tSSDMKgciO6DPj7FLn3Tsfdy2c48c0AyF/hvoMfd7xATwKIGIll+LbUVBUGtpU3kiN6LtYqL7a8cvqd6T
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"c.plot.pie(y=0, legend=False)"
|
2021-07-19 01:21:53 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "fd3f3f1e",
 "metadata": {},
 "outputs": [],
 "source": []
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "615cccc9",
 "metadata": {},
 "outputs": [],
 "source": [
  "%%time\n",
  "\n",
  "from sklearn.preprocessing import LabelBinarizer\n",
  "\n",
  "# Fit the encoder on the training labels only; every other split reuses the\n",
  "# fitted mapping so a given class always ends up in the same output column.\n",
  "lb = LabelBinarizer()\n",
  "yyy_train = lb.fit_transform(yy_train)\n",
  "# transform, not fit_transform: refitting here would rebuild lb.classes_ from\n",
  "# the test labels and could reorder or drop columns relative to training.\n",
  "yyy_test = lb.transform(yy_test)\n",
  "\n",
  "# Encode every test entry in place with the same fitted encoder and turn its\n",
  "# data into an ndarray.\n",
  "# NOTE(review): this mutates `test`, so re-running the cell on already-encoded\n",
  "# entries is not expected to work -- restart the kernel instead.\n",
  "for e in test:\n",
  "    e['label'] = lb.transform([e['label']])\n",
  "    e['data'] = np.array(e['data'])\n",
  "\n",
  "# for e in train:\n",
  "#     e['label'] = lb.transform([e['label']])\n",
  "#     e['data'] = np.array(e['data'])"
 ]
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "f56e5055",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sanity check: one shape per line, features then labels, train before test.\n",
  "print(\n",
  "    *(split.shape for split in (XX_train, yyy_train, XX_test, yyy_test)),\n",
  "    sep='\\n',\n",
  ")"
 ]
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "f046e211",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"metadata": {},
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"source": [
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"# Building Model"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "53832797",
 "metadata": {},
 "outputs": [],
 "source": [
  "import tensorflow as tf\n",
  "from tensorflow.keras.regularizers import l2\n",
  "from tensorflow.keras.models import Sequential\n",
  "from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
  "from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
  "from tensorflow.keras.optimizers import Adam\n",
  "\n",
  "def build_model(shape, classes):\n",
  "    \"\"\"Assemble and compile the dense classifier.\n",
  "\n",
  "    The input is flattened, passed through dropout and batch normalisation,\n",
  "    then through a funnel of L2-regularised ReLU layers, and finally through\n",
  "    a softmax layer with `classes` outputs.\n",
  "\n",
  "    Reads the notebook-level settings `drop_count`, `layer_count` and\n",
  "    `dense_steps`.\n",
  "\n",
  "    shape:   shape of one sample; its first two dimensions are multiplied\n",
  "             to size the hidden layers.\n",
  "    classes: number of output units.\n",
  "\n",
  "    Returns the compiled Sequential model (Adam optimizer, categorical\n",
  "    cross-entropy loss, \"acc\" metric).\n",
  "    \"\"\"\n",
  "    net = Sequential()\n",
  "    base_width = shape[0] * shape[1]\n",
  "\n",
  "    net.add(Flatten(input_shape=shape))\n",
  "    net.add(Dropout(drop_count))\n",
  "    net.add(BatchNormalization())\n",
  "\n",
  "    # Level k is dense_steps**k times narrower than the base width; stop\n",
  "    # adding layers once a level would be no wider than classes * dense_steps.\n",
  "    narrowest = classes * dense_steps\n",
  "    for level in range(1, layer_count):\n",
  "        width = int(base_width / dense_steps ** level)\n",
  "        if width <= narrowest:\n",
  "            break\n",
  "        # The dropout rate grows linearly with depth.\n",
  "        net.add(Dropout(drop_count * level))\n",
  "        net.add(Dense(width, activation='relu', kernel_regularizer=l2(0.001)))\n",
  "\n",
  "    net.add(Dense(classes, activation='softmax'))\n",
  "\n",
  "    net.compile(optimizer=Adam(), loss=\"categorical_crossentropy\", metrics=[\"acc\"])\n",
  "    return net\n",
  "\n"
 ]
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "64f2aaa7",
 "metadata": {},
 "outputs": [],
 "source": [
  "# NOTE(review): checkpoint_file is not referenced in this cell; train_model\n",
  "# reads the notebook-level checkpoint_path instead.\n",
  "checkpoint_file = './goat.weights'\n",
  "\n",
  "def train_model(X_train, y_train, X_test, y_test, classes=16, epochs=50, batch_size=32):\n",
  "    \"\"\"Build, fit and return the classifier together with its training history.\n",
  "\n",
  "    X_train, y_train: training samples and their one-hot labels.\n",
  "    X_test, y_test:   data passed to Keras as validation_data.\n",
  "    classes:          number of output classes handed to build_model.\n",
  "    epochs:           number of training epochs.\n",
  "    batch_size:       mini-batch size.\n",
  "\n",
  "    The defaults reproduce the values that used to be hard-coded, so existing\n",
  "    four-argument calls behave exactly as before.\n",
  "\n",
  "    Returns (model, history); the model carries the weights of the best\n",
  "    checkpoint, not necessarily those of the last epoch.\n",
  "    \"\"\"\n",
  "    model = build_model(X_train[0].shape, classes)\n",
  "\n",
  "    model.summary()\n",
  "\n",
  "    # Keep only the best epoch on disk, judged by training loss. The whole\n",
  "    # model is saved because save_weights_only is left at its default.\n",
  "    model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss',\n",
  "                                       save_best_only=True)\n",
  "\n",
  "    # Halve the learning rate after 5 epochs without loss improvement, never\n",
  "    # going below 1e-4.\n",
  "    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=0.0001)\n",
  "\n",
  "    callbacks = [model_checkpoint, reduce_lr]\n",
  "\n",
  "    history = model.fit(X_train,\n",
  "                        y_train,\n",
  "                        epochs=epochs,\n",
  "                        batch_size=batch_size,\n",
  "                        verbose=2,\n",
  "                        validation_data=(X_test, y_test),\n",
  "                        callbacks=callbacks\n",
  "                        )\n",
  "\n",
  "    # Restore the best checkpoint before handing the model back.\n",
  "    model.load_weights(checkpoint_path)\n",
  "    return model, history\n"
 ]
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "858670a8",
 "metadata": {},
 "outputs": [],
 "source": [
  "%%time\n",
  "\n",
  "# Convert each split to an ndarray, then train; keep the model and its history.\n",
  "model, history = train_model(\n",
  "    *(np.array(split) for split in (XX_train, yyy_train, XX_test, yyy_test))\n",
  ")"
 ]
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "6e905067",
|
2021-07-19 01:21:53 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Eval"
|
2021-07-17 17:40:05 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "196dae1d",
 "metadata": {},
 "outputs": [],
 "source": [
  "def predict(model, entry):\n",
  "    \"\"\"Classify one entry by majority vote over its prediction rows.\n",
  "\n",
  "    model.predict is run on entry['data'] and every output row votes for its\n",
  "    arg-max class index. The most frequent index wins (on a tie, the index\n",
  "    that was seen first) and is returned shifted up by one.\n",
  "    \"\"\"\n",
  "    row_classes = np.argmax(model.predict(entry['data']), axis=-1)\n",
  "\n",
  "    # Tally the votes; dict insertion order keeps the first-seen tie-break.\n",
  "    votes = {}\n",
  "    for cls in row_classes:\n",
  "        votes[cls] = votes.get(cls, 0) + 1\n",
  "\n",
  "    winner = max(votes, key=votes.get)\n",
  "    return winner + 1"
 ]
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "49aebfaa",
 "metadata": {},
 "outputs": [],
 "source": [
  "%%time\n",
  "\n",
  "# Decoded true label and majority-vote prediction for every test entry,\n",
  "# kept in the same order so the two lists line up element by element.\n",
  "ltest = []\n",
  "ptest = []\n",
  "for e in test:\n",
  "    ltest.append(lb.inverse_transform(e['label'])[0])\n",
  "    ptest.append(predict(model, e))"
 ]
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "6577cf5f",
 "metadata": {},
 "outputs": [],
 "source": [
  "%%time\n",
  "\n",
  "# Decoded true label and majority-vote prediction for every train entry.\n",
  "# NOTE(review): the loop that encodes `train` entries is commented out in the\n",
  "# label-encoding cell -- confirm their 'label'/'data' fields are already in\n",
  "# the form lb.inverse_transform and predict expect.\n",
  "ltrain = []\n",
  "ptrain = []\n",
  "for e in train:\n",
  "    ltrain.append(lb.inverse_transform(e['label'])[0])\n",
  "    ptrain.append(predict(model, e))"
 ]
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "d62e2063",
 "metadata": {},
 "outputs": [],
 "source": [
  "%%time\n",
  "\n",
  "from sklearn.metrics import classification_report, confusion_matrix\n",
  "import seaborn as sn\n",
  "\n",
  "# Take the tick labels from the data instead of a hard-coded set: a set has\n",
  "# no defined order, and fixed 0..15 need not match the labels in ltest/ptest.\n",
  "# The same sorted list is given to confusion_matrix, so rows and columns are\n",
  "# guaranteed to line up with the DataFrame index.\n",
  "test_labels = sorted(set(ltest) | set(ptest))\n",
  "\n",
  "train_cm = confusion_matrix(ltrain, ptrain, normalize='true')\n",
  "test_cm = confusion_matrix(ltest, ptest, labels=test_labels, normalize='true')\n",
  "\n",
  "df_cm = pd.DataFrame(test_cm, index=test_labels, columns=test_labels)\n",
  "plt.figure(figsize=(10, 7))\n",
  "sn_plot = sn.heatmap(df_cm, annot=True, cmap=\"Greys\")\n",
  "plt.ylabel(\"True Label\")\n",
  "plt.xlabel(\"Predicted Label\")\n",
  "plt.show()\n",
  "\n",
  "print(classification_report(ltest, ptest, zero_division=0))"
 ]
},
|
|
|
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "5041115f",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Echo the run configuration, one setting per line.\n",
  "print(\n",
  "    f'cenario: {cenario}',\n",
  "    f'win_sz: {win_sz}',\n",
  "    f'stride_sz: {stride_sz}',\n",
  "    f'dense_steps: {dense_steps}',\n",
  "    f'layer_count: {layer_count}',\n",
  "    f'drop_count: {drop_count}',\n",
  "    sep='\\n',\n",
  ")"
 ]
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
2021-07-27 16:00:03 +02:00
|
|
|
|
"id": "d25d662c",
|
2021-07-17 17:40:05 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": []
|
2021-07-05 15:01:40 +02:00
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "Python 3",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
2021-07-14 10:15:52 +02:00
|
|
|
|
"version": "3.8.10"
|
2021-07-05 15:01:40 +02:00
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 5
|
|
|
|
|
}
|