2021-06-23 10:12:11 +02:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 1,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "1a0d0dda",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"import os\n",
|
|
|
|
"\n",
|
|
|
|
"# Location of the per-letter CSV recordings (one sub-directory per user).\n",
|
|
|
|
"glob_path = '/opt/iui-datarelease2-sose2021/*/split_letters_csv/*'\n",
|
|
|
|
"\n",
|
|
|
|
"# Cache file for the parsed recordings.\n",
|
|
|
|
"pickle_file = 'data.pickle'\n",
|
|
|
|
"\n",
|
|
|
|
"# Set to True to retrain even if a checkpoint directory already exists.\n",
|
|
|
|
"create_new = False\n",
|
|
|
|
"checkpoint_path = \"training_1/cp.ckpt\"\n",
|
|
|
|
"checkpoint_dir = os.path.split(checkpoint_path)[0]\n",
|
|
|
|
"\n",
|
|
|
|
"# Hidden layer i gets (flattened input size) / dense_steps**i neurons,\n",
|
|
|
|
"# e.g. dense_steps = 3 with 2700 inputs gives 900, 300, 100, 33, ...\n",
|
|
|
|
"dense_steps = 2\n",
|
|
|
|
"# Maximum number of hidden dropout/dense layer pairs.\n",
|
|
|
|
"layer_count = 3\n",
|
|
|
|
"# Base dropout rate; the dropout before hidden layer i uses drop_count * i.\n",
|
|
|
|
"drop_count = 0.1"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 2,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "592dd9b6",
|
2021-08-06 23:18:21 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# These only take effect if set before TensorFlow initialises the GPU.\n",
|
|
|
|
"os.environ.update({\n",
|
|
|
|
"    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',  # allocate GPU memory on demand; this is required\n",
|
|
|
|
"    'CUDA_VISIBLE_DEVICES': '2',  # '0' for GPU0, '1' for GPU1 or '2' for GPU2; check \"gpustat\" in a terminal.\n",
|
|
|
|
"})"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 3,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "f02f3401",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from glob import glob\n",
|
|
|
|
"import pandas as pd\n",
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
"\n",
|
|
|
|
"def dl_from_blob(filename) -> list:\n",
|
|
|
|
"    \"\"\"Read every CSV matched by the glob pattern `filename`.\n",
|
|
|
|
"\n",
|
|
|
|
"    Paths are expected to look like <root>/<user>/split_letters_csv/<file>:\n",
|
|
|
|
"    the user is the third-to-last path component and the label is the first\n",
|
|
|
|
"    letter of the file name.\n",
|
|
|
|
"\n",
|
|
|
|
"    Returns a list of dicts with keys 'file', 'data' (DataFrame), 'user', 'label'.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    all_data = []\n",
|
|
|
|
"    # glob order is filesystem-dependent; sorting keeps the seeded train/test split reproducible.\n",
|
|
|
|
"    for path in tqdm(sorted(glob(filename))):\n",
|
|
|
|
"        parts = path.split('/')\n",
|
|
|
|
"        all_data.append({\n",
|
|
|
|
"            'file': path,\n",
|
|
|
|
"            # sep must be passed by keyword: read_csv arguments after the path are keyword-only in pandas 2.x.\n",
|
|
|
|
"            'data': pd.read_csv(path, sep=';'),\n",
|
|
|
|
"            # Indexed from the end so the lookup does not depend on how deep the root directory is.\n",
|
|
|
|
"            'user': parts[-3],\n",
|
|
|
|
"            'label': ''.join(filter(str.isalpha, parts[-1]))[0],\n",
|
|
|
|
"        })\n",
|
|
|
|
"    return all_data"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 4,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "591292e0",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import pickle\n",
|
|
|
|
"\n",
|
|
|
|
"def save_pickle(f, structure):\n",
|
|
|
|
"    \"\"\"Pickle `structure` into the file at path `f`.\"\"\"\n",
|
|
|
|
"    # `with` closes the file even if pickling fails part-way.\n",
|
|
|
|
"    with open(f, 'wb') as _p:\n",
|
|
|
|
"        pickle.dump(structure, _p)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 5,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "efbe6b1d",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import pickle\n",
|
|
|
|
"\n",
|
|
|
|
"def load_pickles(f) -> list:\n",
|
|
|
|
"    \"\"\"Return the object pickled in the file at path `f`.\n",
|
|
|
|
"\n",
|
|
|
|
"    Only use on trusted files: unpickling can execute arbitrary code.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    with open(f, 'rb') as _p:\n",
|
|
|
|
"        return pickle.load(_p)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 6,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "a0b68deb",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Loading data...\n",
|
|
|
|
"data.pickle found...\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"import os\n",
|
|
|
|
"\n",
|
|
|
|
"def load_data() -> list:\n",
|
|
|
|
"    \"\"\"Return the raw recordings, from the `pickle_file` cache when present.\n",
|
|
|
|
"\n",
|
|
|
|
"    Otherwise the CSVs matched by `glob_path` are parsed and the cache is written.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    if not os.path.isfile(pickle_file):\n",
|
|
|
|
"        print(f'Didn\\'t find {pickle_file}...')\n",
|
|
|
|
"        records = dl_from_blob(glob_path)\n",
|
|
|
|
"        print(f'Creating {pickle_file}...')\n",
|
|
|
|
"        save_pickle(pickle_file, records)\n",
|
|
|
|
"        return records\n",
|
|
|
|
"\n",
|
|
|
|
"    print(f'{pickle_file} found...')\n",
|
|
|
|
"    return load_pickles(pickle_file)\n",
|
|
|
|
"\n",
|
|
|
|
"print(\"Loading data...\")\n",
|
|
|
|
"data = load_data()"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 7,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "238b73fb",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
"\n",
|
|
|
|
"# Channels drawn into the 5x3 grid in row-major order: Acc1, Acc2, Gyro and Mag (X, Y, Z)\n",
|
|
|
|
"# fill rows 0-3 and Time goes to row 4, column 0. Force is handled separately.\n",
|
|
|
|
"_PD_KEYS = ['Acc1 X', 'Acc1 Y', 'Acc1 Z',\n",
|
|
|
|
"            'Acc2 X', 'Acc2 Y', 'Acc2 Z',\n",
|
|
|
|
"            'Gyro X', 'Gyro Y', 'Gyro Z',\n",
|
|
|
|
"            'Mag X', 'Mag Y', 'Mag Z',\n",
|
|
|
|
"            'Time']\n",
|
|
|
|
"# Same grid for array rows: 0-11 then 13 fill the slots above, 12 is drawn as Force\n",
|
|
|
|
"# (presumably the CSV column order after dropping Millis - TODO confirm).\n",
|
|
|
|
"_NP_KEYS = list(range(12)) + [13]\n",
|
|
|
|
"\n",
|
|
|
|
"def _plot_grid(data, keys, force_key, force):\n",
|
|
|
|
"    \"\"\"Plot data[k] for each k in `keys` into a 5x3 grid (row-major), then draw\n",
|
|
|
|
"    data[force_key] over every subplot (force=True) or in the next free one.\"\"\"\n",
|
|
|
|
"    fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
|
|
|
|
"    cells = axs.ravel()\n",
|
|
|
|
"    for ax, key in zip(cells, keys):\n",
|
|
|
|
"        ax.plot(data[key])\n",
|
|
|
|
"    if force:\n",
|
|
|
|
"        for ax in cells:\n",
|
|
|
|
"            ax.plot(data[force_key])\n",
|
|
|
|
"    else:\n",
|
|
|
|
"        cells[len(keys)].plot(data[force_key])\n",
|
|
|
|
"\n",
|
|
|
|
"def plot_pd(data, force=True):\n",
|
|
|
|
"    \"\"\"Plot one recording held as a DataFrame with named sensor columns.\"\"\"\n",
|
|
|
|
"    _plot_grid(data, _PD_KEYS, 'Force', force)\n",
|
|
|
|
"\n",
|
|
|
|
"def plot_np(data, force=True):\n",
|
|
|
|
"    \"\"\"Plot one padded recording held as a (channels, time) array.\"\"\"\n",
|
|
|
|
"    _plot_grid(data, _NP_KEYS, 12, force)\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 8,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "8b2bba94",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def mill_drop(entry):\n",
|
|
|
|
"    \"\"\"Remove the 'Millis' column from entry['data'] and renumber its rows.\n",
|
|
|
|
"\n",
|
|
|
|
"    The entry dict is updated in place and returned.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    entry['data'] = entry['data'].drop(columns='Millis').reset_index(drop=True)\n",
|
|
|
|
"    return entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 9,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "836bab5a",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import numpy as np\n",
|
|
|
|
"\n",
|
|
|
|
"def cut_force(drop_entry, thresh=70, rel_thresh=0.05, leeway=10):\n",
|
|
|
|
"    \"\"\"Trim entry['data'] to the rows where 'Force' is active, plus `leeway` rows of margin.\n",
|
|
|
|
"\n",
|
|
|
|
"    Rows are active when Force > `thresh`; if none are, the cut-off falls back to\n",
|
|
|
|
"    `rel_thresh` times the recording's maximum force. Raises ValueError if no row\n",
|
|
|
|
"    exceeds either cut-off. The entry dict is updated in place and returned.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    data = drop_entry['data']\n",
|
|
|
|
"    force = data['Force']\n",
|
|
|
|
"    active = np.where(force > thresh)[0]\n",
|
|
|
|
"    if active.size == 0:\n",
|
|
|
|
"        # Weak recordings never reach the absolute cut-off; use one relative to their peak.\n",
|
|
|
|
"        active = np.where(force > force.max() * rel_thresh)[0]\n",
|
|
|
|
"    # NOTE(review): the end index is exclusive and capped at len - 1, so the last row is never\n",
|
|
|
|
"    # kept and only leeway - 1 rows follow the last active one. Left unchanged because it\n",
|
|
|
|
"    # determines the sequence lengths and thus the input shape of saved checkpoints.\n",
|
|
|
|
"    start = max(active.min() - leeway, 0)\n",
|
|
|
|
"    end = min(len(force) - 1, active.max() + leeway)\n",
|
|
|
|
"    drop_entry['data'] = data[start:end].reset_index(drop=True)\n",
|
|
|
|
"    return drop_entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 10,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "aa85eade",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def norm_force(shorten_entry, flist):\n",
|
|
|
|
"    \"\"\"Z-score entry['data']['Force'] with the mean/std of the same user's force samples.\n",
|
|
|
|
"\n",
|
|
|
|
"    `flist` maps each user to a Series of all that user's raw force readings.\n",
|
|
|
|
"    The entry dict is updated in place and returned.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    user_force = flist[shorten_entry['user']]\n",
|
|
|
|
"    frame = shorten_entry['data']\n",
|
|
|
|
"    frame['Force'] = (frame['Force'] - user_force.mean()) / user_force.std()\n",
|
|
|
|
"    shorten_entry['data'] = frame.reset_index(drop=True)\n",
|
|
|
|
"    return shorten_entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 11,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "42579003",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def time_trans(fnorm_entry):\n",
|
|
|
|
"    \"\"\"Shift entry['data']['Time'] so that the first sample is at 0.\n",
|
|
|
|
"\n",
|
|
|
|
"    The entry dict is updated in place and returned.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    frame = fnorm_entry['data']\n",
|
|
|
|
"    frame['Time'] = frame['Time'] - frame['Time'][0]\n",
|
|
|
|
"    fnorm_entry['data'] = frame.reset_index(drop=True)\n",
|
|
|
|
"    return fnorm_entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 12,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "4fc8f3a2",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# Divisor applied to each raw sensor channel by norm().\n",
|
|
|
|
"# NOTE(review): 32768 = 2**15 and 8192 = 2**13 look like sensor full-scale ranges - confirm.\n",
|
|
|
|
"_SENSOR_SCALE = {\n",
|
|
|
|
"    'Acc1 X': 32768, 'Acc1 Y': 32768, 'Acc1 Z': 32768,\n",
|
|
|
|
"    'Acc2 X': 8192, 'Acc2 Y': 8192, 'Acc2 Z': 8192,\n",
|
|
|
|
"    'Gyro X': 32768, 'Gyro Y': 32768, 'Gyro Z': 32768,\n",
|
|
|
|
"    'Mag X': 8192, 'Mag Y': 8192, 'Mag Z': 8192,\n",
|
|
|
|
"}\n",
|
|
|
|
"\n",
|
|
|
|
"def norm(time_entry):\n",
|
|
|
|
"    \"\"\"Divide the Acc/Gyro/Mag channels of entry['data'] by their fixed scale factors.\n",
|
|
|
|
"\n",
|
|
|
|
"    The entry dict is updated in place and returned.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    frame = time_entry['data']\n",
|
|
|
|
"    for column, divisor in _SENSOR_SCALE.items():\n",
|
|
|
|
"        frame[column] = frame[column] / divisor\n",
|
|
|
|
"    time_entry['data'] = frame.reset_index(drop=True)\n",
|
|
|
|
"    return time_entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 13,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "2a8cf8f1",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Preprocessing...\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"100%|██████████| 26179/26179 [01:28<00:00, 294.18it/s]\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"def preproc(d):\n",
|
|
|
|
"    \"\"\"Preprocess every recording in `d` and return the results as a new list.\n",
|
|
|
|
"\n",
|
|
|
|
"    Force is z-scored per user, so each user's raw force samples are gathered first.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    forces_by_user = {}\n",
|
|
|
|
"    for e in d:\n",
|
|
|
|
"        forces_by_user.setdefault(e['user'], []).append(e['data']['Force'])\n",
|
|
|
|
"    # One concat per user: Series.append was removed in pandas 2.0 and repeated appends are quadratic.\n",
|
|
|
|
"    flist = {user: pd.concat(forces) for user, forces in forces_by_user.items()}\n",
|
|
|
|
"\n",
|
|
|
|
"    return [preproc_entry(e, flist) for e in tqdm(d)]\n",
|
|
|
|
"\n",
|
|
|
|
"def preproc_entry(entry, flist):\n",
|
|
|
|
"    \"\"\"Drop Millis, cut to the active force window, z-score force per user,\n",
|
|
|
|
"    shift time to start at 0 and scale the sensor channels of one recording.\"\"\"\n",
|
|
|
|
"    # Shallow copy: the raw `data` records stay untouched, so this cell can be re-run\n",
|
|
|
|
"    # (mill_drop would otherwise fail on the already-dropped 'Millis' column).\n",
|
|
|
|
"    entry = dict(entry)\n",
|
|
|
|
"    drop_entry = mill_drop(entry)\n",
|
|
|
|
"    shorten_entry = cut_force(drop_entry)\n",
|
|
|
|
"    fnorm_entry = norm_force(shorten_entry, flist)\n",
|
|
|
|
"    time_entry = time_trans(fnorm_entry)\n",
|
|
|
|
"    return norm(time_entry)\n",
|
|
|
|
"\n",
|
|
|
|
"print(\"Preprocessing...\")\n",
|
|
|
|
"pdata = preproc(data)"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 14,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "daca6878",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Truncating...\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"def throw(pdata):\n",
|
|
|
|
"    \"\"\"Drop recordings longer than the `threshold_p` quantile of all lengths.\n",
|
|
|
|
"\n",
|
|
|
|
"    Returns the kept entries as a NumPy object array, in their original order.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    lengths = pd.Series([len(x['data']) for x in pdata])\n",
|
|
|
|
"    limit = int(lengths.quantile(threshold_p))\n",
|
|
|
|
"    keep = np.where(lengths <= limit)[0]\n",
|
|
|
|
"    return np.array(pdata)[keep]\n",
|
|
|
|
"\n",
|
|
|
|
"# Length quantile used as cut-off; `threshold` is that length and also the padding length for elong().\n",
|
|
|
|
"llist = pd.Series([len(x['data']) for x in pdata])\n",
|
|
|
|
"threshold_p = 0.75\n",
|
|
|
|
"threshold = int(llist.quantile(threshold_p))\n",
|
|
|
|
"\n",
|
|
|
|
"print(\"Truncating...\")\n",
|
|
|
|
"tpdata = throw(pdata)"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 15,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "7321c532",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
" 18%|█▊ | 3624/19640 [00:00<00:00, 18199.99it/s]"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Padding...\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-08-06 23:49:37 +02:00
|
|
|
"100%|██████████| 19640/19640 [00:01<00:00, 19054.38it/s]\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
|
|
"\n",
|
|
|
|
"def elong(tpdata):\n",
|
|
|
|
"    \"\"\"Turn each entry['data'] DataFrame into a (channels, threshold) float array.\n",
|
|
|
|
"\n",
|
|
|
|
"    Every channel is zero-padded at the end to the global `threshold` length.\n",
|
|
|
|
"    The entries are updated in place and `tpdata` is returned.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    for entry in tqdm(tpdata):\n",
|
|
|
|
"        channels = entry['data'].to_numpy().T\n",
|
|
|
|
"        entry['data'] = pad_sequences(channels, dtype=float, padding='post', maxlen=threshold)\n",
|
|
|
|
"    return tpdata\n",
|
|
|
|
"\n",
|
|
|
|
"print(\"Padding...\")\n",
|
|
|
|
"ltpdata = elong(tpdata)"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 16,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "863d612a",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import tensorflow as tf\n",
|
|
|
|
"from tensorflow.keras.regularizers import l2\n",
|
|
|
|
"from tensorflow.keras.models import Sequential\n",
|
|
|
|
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
|
|
|
|
"from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
|
|
|
|
"from tensorflow.keras.optimizers import Adam\n",
|
|
|
|
"\n",
|
|
|
|
"def build_model(shape, classes):\n",
|
|
|
|
"    \"\"\"Build, compile and summarise an MLP classifier.\n",
|
|
|
|
"\n",
|
|
|
|
"    `shape` is the 2-D per-sample input shape and `classes` the number of softmax\n",
|
|
|
|
"    outputs. The flattened input passes through dropout (drop_count) and batch\n",
|
|
|
|
"    normalisation, then through up to `layer_count` hidden ReLU layers of width\n",
|
|
|
|
"    (flattened size) / dense_steps**i with L2(0.001) kernel regularisation, each\n",
|
|
|
|
"    preceded by dropout drop_count * i. Stacking stops early once a hidden layer\n",
|
|
|
|
"    would be no wider than `classes`.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    flat_size = shape[0] * shape[1]\n",
|
|
|
|
"    model = Sequential([\n",
|
|
|
|
"        Flatten(input_shape=shape, name='flatten'),\n",
|
|
|
|
"        Dropout(drop_count, name=f'dropout_{drop_count*100}'),\n",
|
|
|
|
"        BatchNormalization(name='batchNorm'),\n",
|
|
|
|
"    ])\n",
|
|
|
|
"\n",
|
|
|
|
"    for depth in range(1, layer_count + 1):\n",
|
|
|
|
"        width = int(flat_size / pow(dense_steps, depth))\n",
|
|
|
|
"        if width <= classes:\n",
|
|
|
|
"            break\n",
|
|
|
|
"        rate = drop_count * depth\n",
|
|
|
|
"        model.add(Dropout(rate, name=f'HiddenDropout_{rate*100:.0f}'))\n",
|
|
|
|
"        model.add(Dense(width, activation='relu',\n",
|
|
|
|
"                        kernel_regularizer=l2(0.001), name=f'Hidden_{depth}'))\n",
|
|
|
|
"\n",
|
|
|
|
"    model.add(Dense(classes, activation='softmax', name='Output'))\n",
|
|
|
|
"    model.compile(optimizer=Adam(), loss=\"categorical_crossentropy\", metrics=[\"acc\"])\n",
|
|
|
|
"    model.summary()\n",
|
|
|
|
"    return model"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 17,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "a046d6e5",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2021-07-05 15:01:40 +02:00
|
|
|
"checkpoint_file = './goat.weights'\n",
|
|
|
|
"\n",
|
|
|
|
"def train(X_train, y_train, X_test, y_test, epochs=100, batch_size=32):\n",
|
|
|
|
"    \"\"\"Train a new model on (X_train, y_train) and return (model, history).\n",
|
|
|
|
"\n",
|
|
|
|
"    (X_test, y_test) is used as validation data for per-epoch reporting. The model\n",
|
|
|
|
"    with the lowest training loss is saved to `checkpoint_path` and its weights are\n",
|
|
|
|
"    reloaded before the final test-set evaluation. Labels must be one-hot encoded.\n",
|
|
|
|
"    \"\"\"\n",
|
|
|
|
"    # Number of classes comes from the label encoding instead of a hard-coded 52.\n",
|
|
|
|
"    model = build_model(X_train[0].shape, np.shape(y_train)[-1])\n",
|
|
|
|
"\n",
|
|
|
|
"    # Save the model with the lowest training loss.\n",
|
|
|
|
"    # NOTE(review): 'val_loss' would be more usual, but the validation data here is the\n",
|
|
|
|
"    # test split, so selecting on it would leak test information into model selection.\n",
|
|
|
|
"    model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss',\n",
|
|
|
|
"                                       save_best_only=True)\n",
|
|
|
|
"\n",
|
|
|
|
"    history = model.fit(X_train, y_train,\n",
|
|
|
|
"                        epochs=epochs,\n",
|
|
|
|
"                        batch_size=batch_size,\n",
|
|
|
|
"                        shuffle=True,\n",
|
|
|
|
"                        validation_data=(X_test, y_test),\n",
|
|
|
|
"                        verbose=2,\n",
|
|
|
|
"                        callbacks=[model_checkpoint]\n",
|
|
|
|
"                        )\n",
|
|
|
|
"\n",
|
|
|
|
"    model.load_weights(checkpoint_path)\n",
|
|
|
|
"    print(\"Evaluate on test data\")\n",
|
|
|
|
"    model.evaluate(X_test, y_test, verbose=2)\n",
|
|
|
|
"    return model, history"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 18,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "93428481",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-07-05 15:01:40 +02:00
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
|
|
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
|
|
|
|
"\n",
|
|
|
|
"# One (channels, threshold) array per recording, and its letter label.\n",
|
|
|
|
"X = np.array([x['data'] for x in ltpdata])\n",
|
|
|
|
"y = np.array([x['label'] for x in ltpdata])\n",
|
|
|
|
"\n",
|
|
|
|
"# One-hot encode the letters; `lb` is reused later to map predictions back to letters.\n",
|
|
|
|
"lb = LabelBinarizer()\n",
|
|
|
|
"y_tran = lb.fit_transform(y)\n",
|
|
|
|
"\n",
|
|
|
|
"X_train, X_test, y_train, y_test = train_test_split(X, y_tran, test_size=0.2, random_state=177013)\n",
|
|
|
|
"\n",
|
|
|
|
"# X is already (samples, channels, time), so no reshape is needed; check the split is consistent.\n",
|
|
|
|
"assert X.ndim == 3 and len(X) == len(y_tran)\n",
|
|
|
|
"assert len(X_train) + len(X_test) == len(X)\n",
|
|
|
|
"\n",
|
|
|
|
"train_shape = X_train[0].shape\n",
|
|
|
|
"classes = y_train[0].shape[0]"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 19,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "046ff9ca",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Loaded weights...\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"Model: \"sequential\"\n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Layer (type) Output Shape Param # \n",
|
|
|
|
"=================================================================\n",
|
|
|
|
"flatten (Flatten) (None, 1050) 0 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"dropout_10.0 (Dropout) (None, 1050) 0 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"batchNorm (BatchNormalizatio (None, 1050) 4200 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"HiddenDropout_10 (Dropout) (None, 1050) 0 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Hidden_1 (Dense) (None, 525) 551775 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"HiddenDropout_20 (Dropout) (None, 525) 0 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Hidden_2 (Dense) (None, 262) 137812 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"HiddenDropout_30 (Dropout) (None, 262) 0 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Hidden_3 (Dense) (None, 131) 34453 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Output (Dense) (None, 52) 6864 \n",
|
|
|
|
"=================================================================\n",
|
|
|
|
"Total params: 735,104\n",
|
|
|
|
"Trainable params: 733,004\n",
|
|
|
|
"Non-trainable params: 2,100\n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"CPU times: user 338 ms, sys: 217 ms, total: 554 ms\n",
|
|
|
|
"Wall time: 574 ms\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-07-05 15:01:40 +02:00
|
|
|
"source": [
|
2021-08-06 20:20:52 +02:00
|
|
|
"%%time\n",
|
|
|
|
"if create_new or not os.path.isdir(checkpoint_dir):\n",
|
|
|
|
"    tf.keras.backend.clear_session()\n",
|
|
|
|
"    # Train on the training split only, so the test-set evaluation below is not inflated by leakage.\n",
|
|
|
|
"    model, history = train(X_train, y_train, X_test, y_test)\n",
|
|
|
|
"else:\n",
|
|
|
|
"    # NOTE(review): an existing checkpoint may come from a run that trained on the full data\n",
|
|
|
|
"    # set (test split included); set create_new = True to retrain cleanly.\n",
|
|
|
|
"    print(\"Loaded weights...\")\n",
|
|
|
|
"    model = build_model(X_train[0].shape, classes)\n",
|
|
|
|
"    model.load_weights(checkpoint_path)"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
|
|
|
},
|
2021-08-06 22:16:00 +02:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "e95e5144",
|
2021-08-06 22:16:00 +02:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Evaluation"
|
|
|
|
]
|
|
|
|
},
|
2021-07-05 15:01:40 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 20,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "e6c40138",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
2021-08-06 22:16:00 +02:00
|
|
|
"outputs": [],
|
2021-07-05 15:01:40 +02:00
|
|
|
"source": [
|
2021-08-06 22:16:00 +02:00
|
|
|
"# Predicted and true letters for the test split.\n",
|
|
|
|
"predicted_idx = np.argmax(model.predict(X_test), axis=-1)\n",
|
|
|
|
"ptest = [lb.classes_[i] for i in predicted_idx]\n",
|
|
|
|
"ltest = lb.inverse_transform(y_test)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:18:21 +02:00
|
|
|
"execution_count": 21,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "4dacd4bd",
|
2021-08-06 22:16:00 +02:00
|
|
|
"metadata": {},
|
2021-08-06 23:18:21 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABB4AAAMqCAYAAAAl6oIGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAABfj0lEQVR4nO3deZikWVUn/u+p6qUa6VZAEAcaEEEUUBCaTcAFREFRgYERFPnhVi6gII6jzjCAOM6476CdIAqKyzAC9gAKCDjsSwLNvrVgs4gi0EID3XRX1fn9kZFkUl1LZkbczDezPp/niafijYj33JtvRLxRceLce6u7AwAAADDCvp3uAAAAALB3STwAAAAAw0g8AAAAAMNIPAAAAADDSDwAAAAAw0g8AAAAAMNIPAAAAABJkqp6SlV9pKreepz7q6p+t6ouqqo3V9VtThZT4gEAAABY9SdJ7nmC+++V5Kazy8Ekf3CygBIPAAAAQJKku1+a5OMneMh3JXlar3h1ki+qqi89UUyJBwAAAGCjrpfkA+u2Pzi77bhOG9qdOZx77rk9b4yLL754EV3Jvn17Jz9z5MiRuWMs4ngsoh/J3npupsJzA+wlzmkAk1Y73YGRqmru77SD/EhWhkisWurupZENTjbxAAAAACzWLMkwT6LhQ0nOXbd9/dltxyV9DwAAAGzUBUkeMlvd4o5JPtHdHz7RDioeAAAAYMGqdudIkqr6iyTfmOSLq+qDSR6b5PQk6e4/TPK8JN+W5KIkn0ny/SeLuW2Jh6q6S5IHdffDtqtNAAAAYOO6+0Enub+TbOp7/dDEQ1V9bZLvSfKAJO9L8syR7QEAAADTsvDEQ1V9RZIHzS4fTfJXSaq7v2nRbQEAAMAU7dahFiOMmFzynUnuluTe3X2X7v69JIc3smNVHayq5apa/tSnPjWgawAAAMB2GpF4uF+SDyd5SVU9qarung2uz9rdS919Xnefd/WrX31A1wAAAIDttPChFt397CTPrqovSPJdSR6Z5DpV9QdJntXdL1h0mwAAADAlhlqsGVHxkCTp7k93959393ckuX6SNyb52VHtAQAAANMzLPGwXndfMhtGcfftaA8AAACYhqHLaQIAAMCpaN++bfmdf1dwJAAAAIBhqrt3ug/HdOWVV87dsXvd616L6Er+/u//fiFxWLwrr7xy7hinn376AnoyDY4HAKcan32Lt4hjmjiubMienn3xjDPOmOSX7SuuuGLbj7uhFgAAALBgVrVYY6gFAAAAMIzEAwAAADCMoRYAAACwYIZarFHxAAAAAAyzbYmHqvrikvIBAACAU8qQxENV3bGq/qGqnllVX1tVb03y1iT/WlX3HNEmAAAATEVVTfKyE0ZVPPx+kv+Z5C+SvDjJD3X3dZN8fZL/dbydqupgVS1X1fKTn/zkQV0DAAAAtsuoySVP6+4XJElVPb67X50k3f3OE2VYunspyVKSXHnllT2obwAAAMA2GZV4OLLu+mVH3SehAAAAwJ5misM1oxIPt6qqTyapJGfNrme2fWBQmwAAAMDEDEk8dPf+EXEBAACA3WVUxQMAAACcsgy1WDNqVQsAAACAVPdk53qcu2OL+tt+/Md/fO4Yf/AHf7CAnnC0RTzHU8lEHjly5OQPOol9+6aTSzx06NDcMU47TVEW7ISPfexjc8e41rWuNXeMvXSOT/be3zOvyy47ev7xrTnrrLMWEofF85pnA/b0E3z22WdP8sv2pZdeuu3HfTrfUgAAAIA9R+IBAAAAGEYdMwAAACyYoUJrVDwAAAAAwwxJPFTVTarqzse4/c5V9eUj2gQAAACmZ1TFw28n+eQxbv/k7D4AAADYs6pqkpedMCrx8CXd/Zajb5zddqPj7VRVB6tquaqWl5aWBnUNAAAA2C6jJpf8ohPcd9zFlrt7KclqxmGSa54CAAAAGzcq8bBcVT/c3U9af2NV/VCS1w9qEwAAACbBqhZrRiUeHpnkWVX1vVlLNJ
yX5Iwk9x3UJgAAADAxQxIP3f2vSb6uqr4pyS1nNz+3u188oj0AAABgmkZVPCRJuvslSV4ysg0AAACYGkMt1oxa1QIAAABA4gEAAAAYp7onu2rlZDu2Fde//vXnjvHBD35wAT2B3eOzn/3s3DHOPPPMuWMs4jyp1A6AU43PTzZgTz/B17zmNSf5nfbjH//4th93FQ8AAADAMBIPAAAAwDBDV7UAAACAU9G+fX7nX+VIAAAAAMMMTzxU1bWr6tqj2wEAAACmZ8hQi1qZfvaxSR6eleRGVdWhJL/X3Y8f0SYAAABMhVVZ1oyqePipJHdOcrvuvmZ3XyPJHZLcuap+6ng7VdXBqlququWlpaVBXQMAAAC2Sy1ifd2rBK16Y5J7dPdHj7r92kle0N1fu4Ewk1zzdKuuf/3rzx3jgx/84AJ6ArvHZz/72bljnHnmmXPHsA45AGyez082YE8/wde+9rUn+Z323/7t37b9uI9a1eL0o5MOSdLd/1ZVpw9qEwAAACZB4mzNqKEWV2zxPgAAAGAPGVXxcKuq+uQxbq8kBwa1CQAAAEzMkMRDd+8fERcAAAB2A0Mt1owaagEAAAAwbKgFR1nEihTf933fN3eMJz3pSXPHOHBgb42Wufzyy+eOsdeOyVQsYkUKzy8A7IxF/Nrrcxz2BokHAAAAWDBDLdYYagEAAAAMI/EAAAAADGOoBQAAACyYoRZrVDwAAAAAwwxJPFTVf1l3/QFH3fc/R7QJAAAATM+oiocHrrv+80fdd89BbQIAAMAk7Nu3b5KXHTkWg+LWca4fa3vtjqqDVbVcVctLS0tjegYAAABsm1GTS/Zxrh9re+2O7qUkSyd7HAAAALA7jEo83KqqPpmV6oazZtcz2z4wqE0AAACYBKtarBmSeOju/SPiAgAAALuL5TQBAACAYUYNtQAAAIBTlqEWa1Q8AAAAAMNU9zQXjzh8+PDcHdu/31QTR7vhDW84d4yLL754AT2B7XH48OG5YziXwO51xRVXLCTOGWecsZA4AHyePV0ScO65507yy/YHPvCBbT/uhloAAADAghlqscZQCwAAAGAYiQcAAABgGEMtAAAAYMEMtVgzpOKhqm4wIi4AAACwu4waavHs1StV9deD2gAAAAAmblTiYX1NyY03vFPVwaparqrlJz3pSQO6BQAAAONV1SQvO2HUHA99nOsn3ql7KclSkhw+fHiSa54CAAAAGzcq8XCrqvpkViofzppdz2y7u/ucQe0CAAAAEzIk8dDd+0fEBQAAgN1g375RMxvsPo4EAAAAMIzEAwAAADDMqDkeAAAA4JS1UytITJGKBwAAAGCYyVY87N9vfsoRLr744rlj3OQmN5k7xkUXXTR3DK6qe/5VaBeVmb3yyivnjnH66afPHWMq55IpPTdM1+HDh+eOMZXX/FScccYZC4njPQwAWzfZxAMAAADsVhLOawy1AAAAAIaReAAAAACGMdQCAAAAFsxQizVDKh6q6ruq6mHrtl9TVe+dXe4/ok0AAABgekYNtfgvSS5Yt31mktsl+cYkPzaoTQAAAGBiRiUezujuD6zbfnl3f6y735/kC463U1UdrKrlqlpeWloa1DUAAAAYq6omedkJo+Z4uMb6je5++LrNax9vp+5eSrKacZh/wWwAAABgR42qeHhNVf3w0TdW1Y8kee2gNgEAAICJGVXx8FNJnl1V35PkDbPbbpuVuR7uM6hNAAAAmIR9+0b9zr/7DEk8dPdHknxdVd0tyS1mNz+3u188oj0AAABgmkZVPCRJZokGyQYAAAA4RQ1NPAAAAMCpaKdWkJgiiYdTzOWXXz53jIsuumjuGDe84Q3njpEkF1988ULi7BWLOLkdOnRoAT1JTj/99IXE2St88LAR+/fv3+kucBzewwCwdWa7AAAAAIZR8QAAAAALplpujYoHAAAAYBiJBwAAAGAYQy0AAABgwfbt8zv/qiGJh6r6vSR9vPu7+ydHtAsAAABMy6iKh+V1138hyW
MHtQMAAABM2JDEQ3c/dfV6VT1y/faJVNXBJAeT5Pzzz8/BgwdHdA8AAACGsqrFmu2Y4+G4Qy6u8sDupSRLm90PAAA
|
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 1440x1008 with 2 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
" precision recall f1-score support\n",
|
|
|
|
"\n",
|
|
|
|
" A 0.94 0.90 0.92 52\n",
|
|
|
|
" B 0.86 0.79 0.83 24\n",
|
|
|
|
" C 0.74 0.67 0.70 93\n",
|
|
|
|
" D 0.94 0.91 0.92 65\n",
|
|
|
|
" E 1.00 0.71 0.83 14\n",
|
|
|
|
" F 0.87 0.89 0.88 38\n",
|
|
|
|
" G 0.97 0.88 0.92 67\n",
|
|
|
|
" H 0.93 0.90 0.91 29\n",
|
|
|
|
" I 0.70 0.81 0.75 96\n",
|
|
|
|
" J 0.88 0.88 0.88 101\n",
|
|
|
|
" K 0.81 0.85 0.83 60\n",
|
|
|
|
" L 0.86 0.78 0.82 92\n",
|
|
|
|
" M 0.83 0.87 0.85 55\n",
|
|
|
|
" N 0.89 0.89 0.89 82\n",
|
|
|
|
" O 0.67 0.76 0.71 80\n",
|
|
|
|
" P 0.64 0.53 0.58 55\n",
|
|
|
|
" Q 0.90 1.00 0.95 36\n",
|
|
|
|
" R 0.95 0.93 0.94 42\n",
|
|
|
|
" S 0.61 0.66 0.63 86\n",
|
|
|
|
" T 0.80 0.79 0.80 89\n",
|
|
|
|
" U 0.82 0.42 0.56 97\n",
|
|
|
|
" V 0.57 0.82 0.67 76\n",
|
|
|
|
" W 0.90 0.57 0.70 67\n",
|
|
|
|
" X 0.79 0.56 0.66 80\n",
|
|
|
|
" Y 0.60 0.46 0.52 76\n",
|
|
|
|
" Z 0.72 0.85 0.78 59\n",
|
|
|
|
" a 0.76 0.77 0.77 96\n",
|
|
|
|
" b 0.82 0.89 0.86 93\n",
|
|
|
|
" c 0.69 0.77 0.72 77\n",
|
|
|
|
" d 0.86 0.88 0.87 82\n",
|
|
|
|
" e 0.81 0.93 0.86 95\n",
|
|
|
|
" f 0.88 0.96 0.92 76\n",
|
|
|
|
" g 0.87 0.85 0.86 71\n",
|
|
|
|
" h 0.85 0.87 0.86 94\n",
|
|
|
|
" i 0.81 0.90 0.86 82\n",
|
|
|
|
" j 0.96 0.86 0.91 59\n",
|
|
|
|
" k 0.91 0.77 0.83 78\n",
|
|
|
|
" l 0.68 0.75 0.71 100\n",
|
|
|
|
" m 0.87 0.93 0.90 58\n",
|
|
|
|
" n 0.77 0.86 0.81 98\n",
|
|
|
|
" o 0.66 0.71 0.68 82\n",
|
|
|
|
" p 0.75 0.90 0.82 96\n",
|
|
|
|
" q 0.85 0.68 0.75 65\n",
|
|
|
|
" r 0.76 0.83 0.79 118\n",
|
|
|
|
" s 0.67 0.68 0.68 109\n",
|
|
|
|
" t 0.84 0.78 0.81 82\n",
|
|
|
|
" u 0.62 0.43 0.51 104\n",
|
|
|
|
" v 0.53 0.55 0.54 92\n",
|
|
|
|
" w 0.62 0.89 0.73 62\n",
|
|
|
|
" x 0.60 0.75 0.67 85\n",
|
|
|
|
" y 0.57 0.58 0.57 83\n",
|
|
|
|
" z 0.87 0.59 0.70 80\n",
|
|
|
|
"\n",
|
|
|
|
" accuracy 0.77 3928\n",
|
|
|
|
" macro avg 0.79 0.78 0.78 3928\n",
|
|
|
|
"weighted avg 0.77 0.77 0.76 3928\n",
|
|
|
|
"\n",
|
2021-08-06 23:49:37 +02:00
|
|
|
"CPU times: user 998 ms, sys: 195 ms, total: 1.19 s\n",
|
|
|
|
"Wall time: 963 ms\n"
|
2021-08-06 23:18:21 +02:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-08-06 22:16:00 +02:00
|
|
|
"source": [
|
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
|
|
|
"from sklearn.metrics import classification_report, confusion_matrix\n",
|
|
|
|
"import seaborn as sn\n",
|
|
|
|
"\n",
|
|
|
|
"# All letters present in the test labels (sorted), used for both axes of the matrix.\n",
|
|
|
|
"set_digits = sorted(set(ltest))\n",
|
|
|
|
"\n",
|
|
|
|
"# Row-normalised: each row shows how one true letter is distributed over the predictions.\n",
|
|
|
|
"test_cm = confusion_matrix(ltest, ptest, labels=set_digits, normalize='true')\n",
|
|
|
|
"df_cm = pd.DataFrame(test_cm, index=set_digits, columns=set_digits)\n",
|
|
|
|
"\n",
|
|
|
|
"fig, ax = plt.subplots(figsize=(20, 14))\n",
|
|
|
|
"sn_plot = sn.heatmap(df_cm, cmap=\"Greys\", ax=ax)\n",
|
|
|
|
"ax.set_ylabel(\"True Label\")\n",
|
|
|
|
"ax.set_xlabel(\"Predicted Label\")\n",
|
|
|
|
"plt.show()\n",
|
|
|
|
"\n",
|
|
|
|
"print(classification_report(ltest, ptest, zero_division=0))"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
|
|
|
},
|
2021-08-06 23:18:21 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 22,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "772b43c9",
|
2021-08-06 23:18:21 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def plot_keras_history(history, name='', acc='acc'):\n",
|
|
|
|
"    \"\"\"Plot training vs. validation accuracy, then loss, from a Keras History object.\"\"\"\n",
|
|
|
|
"    import matplotlib.pyplot as plt\n",
|
|
|
|
"\n",
|
|
|
|
"    logs = history.history\n",
|
|
|
|
"    train_acc, val_acc = logs[acc], logs['val_' + acc]\n",
|
|
|
|
"    train_loss, val_loss = logs['loss'], logs['val_loss']\n",
|
|
|
|
"    epochs = range(len(train_acc))\n",
|
|
|
|
"\n",
|
|
|
|
"    # Accuracy curves on a fixed 0-1 scale.\n",
|
|
|
|
"    plt.ylim(0, 1)\n",
|
|
|
|
"    plt.plot(epochs, train_acc, 'tab:blue', label='Training acc')\n",
|
|
|
|
"    plt.plot(epochs, val_acc, 'tab:orange', label='Validation acc')\n",
|
|
|
|
"    plt.title('Training and validation accuracy ' + name)\n",
|
|
|
|
"    plt.legend()\n",
|
|
|
|
"\n",
|
|
|
|
"    # Loss curves in a second figure.\n",
|
|
|
|
"    plt.figure()\n",
|
|
|
|
"    plt.plot(epochs, train_loss, 'tab:green', label='Training loss')\n",
|
|
|
|
"    plt.plot(epochs, val_loss, 'tab:red', label='Validation loss')\n",
|
|
|
|
"    plt.title('Training and validation loss ' + name)\n",
|
|
|
|
"    plt.legend()\n",
|
|
|
|
"    plt.show()\n",
|
|
|
|
"    plt.close()\n",
|
|
|
|
"\n",
|
|
|
|
"# `history` only exists when the model was trained in this session rather than loaded.\n",
|
|
|
|
"if 'history' in locals():\n",
|
|
|
|
"    plot_keras_history(history)"
|
|
|
|
]
|
|
|
|
},
|
2021-07-05 15:01:40 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-08-06 23:49:37 +02:00
|
|
|
"execution_count": 23,
|
2021-08-06 23:50:15 +02:00
|
|
|
"id": "07e20a8e",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
2021-08-06 23:49:37 +02:00
|
|
|
"source": [
|
|
|
|
"# Shut the kernel down when the notebook finishes (NOTE(review): presumably to release GPU memory held by TensorFlow).\n",
|
|
|
|
"exit()"
|
|
|
|
]
|
2021-06-23 10:12:11 +02:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
2021-07-14 10:15:52 +02:00
|
|
|
"version": "3.8.10"
|
2021-06-23 10:12:11 +02:00
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 5
|
|
|
|
}
|