iui-group-l-name-zensiert/1-first-project/tdt/NeuralNetwork.ipynb

675 lines
125 KiB
Plaintext
Raw Normal View History

2021-06-07 19:58:49 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
2021-07-14 10:15:52 +02:00
"id": "cd4df4d6",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 2,
2021-07-14 10:15:52 +02:00
"id": "e74682bc",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Field separator of the per-letter CSV recordings.\n",
"delim = ';'\n",
"\n",
"# Root of the raw dataset: one numbered directory per user (see load_data).\n",
"# NOTE(review): absolute, machine-specific path - consider making it configurable.\n",
"base_path = '/opt/iui-datarelease1-sose2021/'\n",
"\n",
"# Pickle caches for the samples (X) and labels (y); load_data() reads them when both exist.\n",
"Xpickle_file = '../X2.pickle'\n",
"ypickle_file = '../y2.pickle'"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 3,
2021-07-14 10:15:52 +02:00
"id": "2cba70e6",
"metadata": {},
"outputs": [],
"source": [
"# Preprocessing (used by shorten): rows whose 'Force' exceeds THRESH delimit\n",
"# the kept segment, which is widened by LEEWAY rows on each side.\n",
"THRESH = 70\n",
"LEEWAY = 0\n",
"\n",
"# Training: epochs per run, and number of independent runs to average over.\n",
"EPOCH = 50\n",
"AVG_FROM = 1\n",
"\n",
"# Model (used by build_model): a first stack of wide Dense layers followed by\n",
"# a second, narrower stack.\n",
"DENSE_COUNT = 2\n",
"DENSE_NEURONS = 2400\n",
"DENSE2_COUNT = 3\n",
"DENSE2_NEURONS = 600"
]
},
{
"cell_type": "code",
"execution_count": 4,
2021-07-14 10:15:52 +02:00
"id": "25708ad0",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"def shorten(npList):\n",
"    \"\"\"Trim a recording to the span where its 'Force' column exceeds THRESH.\n",
"\n",
"    Args:\n",
"        npList: DataFrame with a 'Force' column.\n",
"\n",
"    Returns:\n",
"        The rows from the first to the last row with Force > THRESH (both\n",
"        inclusive), widened by LEEWAY rows on each side and clipped to the\n",
"        frame. If no row exceeds THRESH, npList is returned unchanged.\n",
"    \"\"\"\n",
"    above = np.flatnonzero(npList['Force'].to_numpy() > THRESH)\n",
"    if len(above) == 0:\n",
"        return npList\n",
"    start = max(above[0] - LEEWAY, 0)\n",
"    # +1 because slice ends are exclusive: without it the last above-threshold\n",
"    # row (and, when clipping, the final row of the frame) would be dropped.\n",
"    end = min(len(npList), above[-1] + LEEWAY + 1)\n",
"    return npList.iloc[start:end]"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 5,
2021-07-14 10:15:52 +02:00
"id": "028b40fd",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"def load_pickles():\n",
"    \"\"\"Load the cached samples and labels from Xpickle_file / ypickle_file.\n",
"\n",
"    Returns:\n",
"        (X, y): an array of the unpickled samples and a string array of labels.\n",
"\n",
"    Warning: pickle.load can execute arbitrary code - only load trusted files.\n",
"    \"\"\"\n",
"    # `with` closes each file even when unpickling raises.\n",
"    with open(Xpickle_file, 'rb') as fh:\n",
"        X = pickle.load(fh)\n",
"    with open(ypickle_file, 'rb') as fh:\n",
"        y = pickle.load(fh)\n",
"    return (np.asarray(X, dtype=pd.DataFrame), np.asarray(y, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 6,
2021-07-14 10:15:52 +02:00
"id": "909c1ae1",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"def load_data(user_count=None):\n",
"    \"\"\"Load every per-letter CSV recording together with its letter label.\n",
"\n",
"    The pickle caches (see load_pickles) are used when both files exist.\n",
"\n",
"    Args:\n",
"        user_count: read users 0 .. user_count-1. When None (default), every\n",
"            numeric sub-directory of base_path that contains a\n",
"            'split_letters_csv' folder is used.\n",
"\n",
"    Returns:\n",
"        (X, y): array of per-file DataFrames and string array of labels -\n",
"        the same 2-tuple that load_pickles() returns.\n",
"    \"\"\"\n",
"    if os.path.isfile(Xpickle_file) and os.path.isfile(ypickle_file):\n",
"        return load_pickles()\n",
"\n",
"    if user_count is None:\n",
"        # Discover users from the directory layout instead of an undefined global.\n",
"        users = sorted(\n",
"            (d for d in os.listdir(base_path)\n",
"             if d.isdecimal() and os.path.isdir(os.path.join(base_path, d, 'split_letters_csv'))),\n",
"            key=int,\n",
"        )\n",
"    else:\n",
"        users = [str(u) for u in range(user_count)]\n",
"\n",
"    data = []\n",
"    label = []\n",
"    for user in users:\n",
"        user_path = os.path.join(base_path, user, 'split_letters_csv')\n",
"        # os.listdir order is arbitrary; sorting keeps the sample order reproducible.\n",
"        for file in sorted(os.listdir(user_path)):\n",
"            # The label is the first alphabetic character of the file name.\n",
"            letter = ''.join(filter(lambda x: x.isalpha(), file))[0]\n",
"            data.append(pd.read_csv(os.path.join(user_path, file), sep=delim))\n",
"            label.append(letter)\n",
"    # Same (X, y) pair as load_pickles(); callers unpack exactly two values.\n",
"    return (np.asarray(data, dtype=pd.DataFrame), np.asarray(label, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 7,
2021-07-14 10:15:52 +02:00
"id": "a1422936",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 5.12 s, sys: 415 ms, total: 5.53 s\n",
"Wall time: 5.54 s\n"
2021-06-07 19:58:49 +02:00
]
},
{
"data": {
"text/plain": [
"(26179,)"
2021-06-07 19:58:49 +02:00
]
},
"execution_count": 7,
2021-06-07 19:58:49 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# Load the recordings (or their pickle cache) and the letter labels.\n",
"X, y = load_data()\n",
"X.shape"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 8,
2021-07-14 10:15:52 +02:00
"id": "b50696c3",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.22 s, sys: 23.5 ms, total: 6.24 s\n",
"Wall time: 6.24 s\n"
2021-06-07 19:58:49 +02:00
]
}
],
"source": [
"%%time\n",
"# Trim every recording to its above-threshold 'Force' segment.\n",
"XX = np.array([shorten(sample) for sample in X], dtype=object)"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 9,
2021-07-14 10:15:52 +02:00
"id": "e9f71bad",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 26179.000000\n",
"mean 47.987394\n",
"std 35.114351\n",
"min 2.000000\n",
"50% 43.000000\n",
"95% 88.000000\n",
"96% 94.000000\n",
"97% 101.000000\n",
"98% 116.000000\n",
"99% 150.000000\n",
"max 1512.000000\n",
2021-06-07 19:58:49 +02:00
"dtype: float64"
]
},
"execution_count": 9,
2021-06-07 19:58:49 +02:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAeyElEQVR4nO3deXxU9b3/8dcn+x4IhH0VEAwIFVNA7WJbtWhbvbb2Ci5Vi9Lbam8Xf+11+7XVe3u7eNvb2lorrnWrtrZWbLHaqtXWCgLiwk7YJBEhEAgkIdvkc/+YAYcYzACTnFnez8djHsxZMnlzYN45+Z5z5pi7IyIiyS8j6AAiIhIfKnQRkRShQhcRSREqdBGRFKFCFxFJEVlBfeP+/fv7qFGjgvr2IiJJaenSpTvcvbyrZYEV+qhRo1iyZElQ315EJCmZ2eZDLdOQi4hIilChi4ikCBW6iEiK6LbQzexuM9tuZssPsdzM7BYzqzKz181savxjiohId2LZQ78XmPkey88ExkUec4Hbjj6WiIgcrm4L3d1fAOreY5VzgPs8bCHQx8wGxyugiIjEJh5j6EOBLVHT1ZF572Jmc81siZktqa2tjcO3FhGR/Xr1PHR3nwfMA6isrNTn9opI0nB32juc1vYOWto7aI08WtpD4elQBy1t4T+jl7V2WtbS3sHHJgxgyvA+cc8Yj0KvAYZHTQ+LzBMROSrxLNED87t4rf1ff/D8/dOhA68Rr9tHDCjOTdhCnw9cZWYPA9OBenffGofXFZEE5u7sawtRv6+N+n1t7NnXHvkzMt3cxr62xCnRnKwMcjMzyM3OICczIzydlUlO1v7nGRTmZr1rWW7kkZMV/rp3vv6d5dGvkZuVQU5m5kHfJ6fTa5hZfP5SnXRb6Gb2a+BUoL+ZVQPfBrIB3P2XwALgLKAKaAIu65GkIhJ3oQ5nT6R831XMzW0HFXR4Xnt4/ch0e8d7t+3BJZdBbnZmVGGG/yzMzYo8P3hZbhdFGP31OZ3WOVDOnUq7p0s0kXRb6O4+u5vlDlwZt0QiclT2NrexeWcTb9Y1sXlnEzsaWt5VzHubw6Xd0NL+nq+VnWmU5mdTkpdNSX42pfnZjCgroCQvi9LI9P75JXnZUfOyKM7LJjMj9Us0kQT24VwicmTcnbrGVjbXNbF5ZyObdzZFHuHnOxtbD1q/MCfzQPGW5GczrG/BQcV7UBkXvPO8JD+L/OzMtNizTRUqdJEE1dDSzvrtDVRtb6CqtuGg8o7eszaDIaX5jCgr4IyJAxlRVsiofgWM6FfAyH6FFOXqbZ4u9C8tErBdja1U1Tawblu4vNdt38v67Q28Vd98YJ2sDGN4WQEj+xVQObIvI/sVMjJS2MP65pOXnRng30AShQpdpJc0t4VYtXUPy9/aw+qte6ja3sD62gZ2NLwzRJKfncmYAYVMG13GuIHFjCkvYtzAIkaUFZCdqc/Sk/emQhfpAXub21j5Vri8V9TUs/ytetbXNhKKnBVSkpfFuIHFfGzCQMYOKGLswCLGlhcxtE8+GTqQKEdIhS4SB1vqmnh65TZeeXMXK2rq2bSz6cCygSW5TBxSysyJg5g4tJSJQ0oY2idfBxsl7lToIkdofW0Df17+Nn9e/jZv1NQDMLwsn0lDSvls5XAqhpQwcUgJA4rzAk4q6UKFLhIjd2f123t5cvnb/Hn5VtZuawDgfcP7cO2ZE5g5aRAj+xUGnFLSmQpd5D24O69X1x8o8U07mzCD948q49ufquDjEwcxpE9+0DFFABW6yLuEOpylm3fx5PKtPL1iGzW795GVYZw0ph9zPzSG0ysGUl6cG3RMkXdRoYsAbaEOFm2o48nlW3lqxTZ2NLSQk5XBh8aV87XTj+W04wbQpyAn6Jgi70mFLmnt9erd3P/SZv6yahu7m9ooyMnkI+MHMHPSID4yYYCuspSkov+tkpaWbKrjZ89W8fzaWopzszitYiAzJw3iw8eW66pLSVoqdE
kb7s5L63dyy7PrWLihjn6FOfzHzAlcNGMExXnZQccTOWoqdEl57s7f1tTys2fX8cqbuxlQnMv//2QFs6cNpyBHbwFJHfrfLClt8aY6bnxiBctr9jC0Tz7/9S+TOO/EYRpWkZSkQpeU9fzaWubet4Ty4lx+eN5kzj1hqD7gSlKaCl1STqjD+d0r1dzw2HLGDijigcunU1aoUw4l9anQJWW0hTp4bFkNt/1tPRt3NDJ1RB/uuXQapQU64CnpQYUuKeGJ197i+0+upmb3PiYOKeGXF03ljIpB+ihaSSsqdEl6v/rnJr49fwWTh5XyX/8yiVPHl+ujaSUtqdAlqd3xwga+u2AVZ1QM5OcXTCUnSwc9JX2p0CVp/fzZdfzP02v5xOTB/OT89+kMFkl7KnRJOu7O//5lLbc8W8W5Jwzl5vMmk6UyF1GhS3Jxd77/5Gpuf2ED51cO578/fTyZOvApAqjQJYm4Ozc+sZJ7/7mJi2aM4KazJ+ksFpEoKnRJCh0dzg2PL+ehRW8y5wOjueETx+lMFpFOVOiS8EIdzn/87nUeXVrNF08dwzc/Pl5lLtIFFboktPZQB1f/9jUef/UtvnraOL7ysXEqc5FDUKFLwmpt7+CrjyxjwRtv882Z4/nSqWODjiSS0FTokpBa2kNc+eAr/HXVdm74xHFc/sFjgo4kkvBiOnnXzGaa2RozqzKza7pYPsLMnjOzZWb2upmdFf+oki6a20Jccd9S/rpqO/95zkSVuUiMui10M8sEbgXOBCqA2WZW0Wm1G4DfuPsJwCzgF/EOKumhqbWdy+5ZzN/X1fL9Tx/PxSeNCjqSSNKIZQ99GlDl7hvcvRV4GDin0zoOlESelwJvxS+ipIuGlnYuvXsxizbu5EefncKsaSOCjiSSVGIZQx8KbImargamd1rnO8DTZvZloBA4rasXMrO5wFyAESP0ZpV31O9r49J7Xub16np+OusEPjVlSNCRRJJOvD4AYzZwr7sPA84C7jezd722u89z90p3rywvL4/Tt5Zkt7uplYvuXMTymnp+ceFUlbnIEYplD70GGB41PSwyL9ocYCaAu79kZnlAf2B7PEJK6trZ0MKFdy5iw45Gbr/4RD46YWDQkUSSVix76IuBcWY22sxyCB/0nN9pnTeBjwGY2XFAHlAbz6CSerbvaWbWvIVs2tnIXZdUqsxFjlK3e+ju3m5mVwFPAZnA3e6+wsxuApa4+3zgauAOM/sa4QOkl7q792RwSW5b6/dxwR2L2LanmXsvm8aMY/oFHUkk6cV0YZG7LwAWdJr3rajnK4FT4htNUtWWuiYuuHMhuxvbuH/ONE4cWRZ0JJGUoCtFpVdt2tHIhXcuYm9zGw9cPp0pw/sEHUkkZajQpde8tXsf5897idb2Dn49dwYTh5QGHUkkpajQpdc8t2Y72/a08NiXTlaZi/QA3YhRek17KHycfGjf/ICTiKQmFbr0itVv7+Fnz65jYEkupfnZQccRSUkqdOlxy2vqmT1vIZkZxkNXzCA3KzPoSCIpSYUuPerVLbu54I6F5Gdn8sjckxhTXhR0JJGUpYOi0mOWbq7jkrsXU1aYw0NXTGdY34KgI4mkNO2hS494af1OLr7rZQYU5/LIF2aozEV6gfbQJe6Wbq7jsntfZnjfAh68YjoDivOCjiSSFlToEndPvLYVgIfnzqBfUW7AaUTSh4ZcJO7aQh3kZWeqzEV6mQpd4uqZVdv47ZJqJgwqDjqKSNpRoUvc/H1dLf/2wFImDC7mlxedGHQckbSjMXSJm8eW1VCUm8UDl0+nJE9Xg4r0Nu2hS9y0h5yivCyVuUhAVOgSF4+/WsOf3tjKcYNKgo4ikrZU6HLUnnxjK1995FUqR/blx+e/L+g4ImlLY+hy1B5bVsPgkjzuvWwa+Tn64C2RoGgPXY5aW6iD0oIclblIwFToclQeWvQmz62ppWKwxs5FgqZClyP22yVbuO6xN/jI+HK+e+6koOOIpD2NocsR+90r1Y
wdUMTtF1eSk6V9A5Gg6V0oR6w95JQV5qjMRRKE3olyRH75/HqWbN7FxCEaOxdJFBpykcP2s2fW8aO/rOVTU4Zw3Vn
2021-06-07 19:58:49 +02:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Length (number of rows) of every trimmed recording.\n",
"X_len = np.asarray([len(sample) for sample in XX])\n",
"sq_xlen = pd.Series(X_len)\n",
"\n",
"# Empirical CDF of the lengths: x = length at quantile p, y = p.\n",
"ptiles = [x * 0.01 for x in range(100)]\n",
"fig, ax = plt.subplots()\n",
"ax.plot(sq_xlen.quantile(ptiles).to_numpy(), ptiles)\n",
"ax.set(title='Trimmed recording lengths', xlabel='length (rows)', ylabel='quantile')\n",
"\n",
"sq_xlen.describe(percentiles=[x * 0.01 for x in range(95, 100)])"
]
},
{
"cell_type": "code",
"execution_count": 10,
2021-07-14 10:15:52 +02:00
"id": "e38a87d6",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"def plot_data(data):\n",
"    \"\"\"Plot the 12 sensor channels (Acc1, Acc2, Gyro, Mag x X/Y/Z) of one raw\n",
"    recording against 'Millis' in a 4 x 3 grid, overlaying 'Force' on each panel.\n",
"\n",
"    NOTE(review): plot_data is defined again further down; that later\n",
"    definition shadows this one.\n",
"    \"\"\"\n",
"    fig, axs = plt.subplots(4, 3, figsize=(3*3, 3*4))\n",
"    t = data['Millis']\n",
"    channels = [\n",
"        ['Acc1 X', 'Acc1 Y', 'Acc1 Z'],\n",
"        ['Acc2 X', 'Acc2 Y', 'Acc2 Z'],\n",
"        ['Gyro X', 'Gyro Y', 'Gyro Z'],\n",
"        ['Mag X', 'Mag Y', 'Mag Z'],\n",
"    ]\n",
"    for ax_row, names in zip(axs, channels):\n",
"        for ax, name in zip(ax_row, names):\n",
"            ax.plot(t, data[name])\n",
"            # Second line on every panel: the force trace.\n",
"            ax.plot(t, data['Force'])"
]
},
{
"cell_type": "code",
"execution_count": 11,
2021-07-14 10:15:52 +02:00
"id": "61d0460c",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((25918,), (52, 15))"
2021-06-07 19:58:49 +02:00
]
},
"execution_count": 11,
2021-06-07 19:58:49 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop recordings longer than the 99th length percentile before padding.\n",
"threshold_p = 0.99\n",
"threshold = int(sq_xlen.quantile(threshold_p))\n",
"len_mask = X_len <= threshold\n",
"\n",
"X_filter, y_filter = XX[len_mask], y[len_mask]\n",
"\n",
"X_filter.shape, X_filter[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
2021-07-14 10:15:52 +02:00
"id": "f616324d",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"\n",
"# Feature frames without the 'Millis' timestamp column.\n",
"# NOTE(review): `a` is not used afterwards - the padding cell reads X_filter.\n",
"a = [frame.drop(columns='Millis') for frame in X_filter]"
]
},
{
"cell_type": "code",
"execution_count": 13,
2021-07-14 10:15:52 +02:00
"id": "5d59fccb",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Zero-pad every recording at the end so all samples share one shape.\n",
"# The 'Millis' timestamp column is dropped first: it is a time axis, not a\n",
"# sensor feature (the 14-panel plot below expects 14 feature channels).\n",
"X_filter = pad_sequences([frame.drop(columns='Millis') for frame in X_filter],\n",
"                         dtype=float, padding='post')"
]
},
{
"cell_type": "code",
"execution_count": 14,
2021-07-14 10:15:52 +02:00
"id": "972fc363",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"def plot_data(data):\n",
"    \"\"\"Plot rows 0-13 of `data` in a 5 x 3 grid (the last panel stays empty).\n",
"\n",
"    Called with a transposed padded sample, so each row is one feature channel.\n",
"    \"\"\"\n",
"    fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
"    for idx in range(14):\n",
"        row, col = divmod(idx, 3)\n",
"        axs[row][col].plot(data[idx])"
]
},
{
"cell_type": "code",
"execution_count": 15,
2021-07-14 10:15:52 +02:00
"id": "f9949967",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20734, 150, 15)\n",
"(5184, 150, 15)\n",
"(20734, 52)\n",
"(5184, 52)\n"
2021-06-07 19:58:49 +02:00
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj4AAANZCAYAAAAPtDT6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9eZhcZZn3/7mrqvd0Z0/ISgKEJYkQIAZURJBVfMeA28A4goKD48AsOq8jDK8Dg4ODvq/j8nMZozIiM4K4MGSUxYAiigQIEEIIBEIIJCF7d5Jeq2u5f3+c51Sfqq6tu6ururruz3X11VXPWerp5KlzvudeRVUxDMMwDMOoBUKVnoBhGIZhGEa5MOFjGIZhGEbNYMLHMAzDMIyawYSPYRiGYRg1gwkfwzAMwzBqBhM+hmEYhmHUDONa+IjIbSKyV0Q2FrHvmSLyjIjEReSDgfFlIvK4iLwgIhtE5E9Hd9aGYRiGYYwW41r4AD8ELixy3zeAjwE/zhjvAS5X1SXuXF8TkUklmp9hGIZhGGUkUukJjCaq+qiILAiOicjRwLeA6Xii5i9U9SVV3ea2JzPO8XLg9Zsistcde3BUJ28YhmEYRskZ18InB6uAv1TVV0TkNODbwLuLOVBEVgD1wKujOD/DMAzDMEaJmhI+IjIBeDvwUxHxhxuKPHYWcAdwhaomC+1vGIZhGMbYo6aED15M00FVXTaUg0SkDfgVcIOqrh2NiRmGYRiGMfqM9+DmNFT1MPCaiHwIQDxOyneMiNQD9wA/UtWflWGahmEYhmGMEjKeu7OLyJ3AWcA0YA9wI/Ab4DvALKAOuEtVbxaRt+IJnMlAH7BbVZeIyJ8D/wG8EDj1x1R1fbn+DsMwDMMwSsO4Fj6GUQpEZBvQCSSAuKouF5EpwE+ABcA24MOq2iFe8NjXgYvwsgY/pqrPuPNcAfwfd9p/UdXby/l3GIZhGDXm6jKMEXC2qi5T1eXu/XXAw6q6CHjYvQd4D7DI/VyNZ13ECaUbgdOAFcCNIjK5jPM3xjkiMk9Efisim1zB1b914zeJyE4RWe9+Lgocc72IbBGRzSJyQWD8Qje2RUSuC4wvFJEn3PhPXCiAYVQV49biM23aNF2wYEGlp2GMA55//nn6+/sPqOo0f0xENgNnqeoul/H3iKoeJyLfda/vDO7n/6jqJ9142n7ZsDVsDIVYLEYsFqO5uZlEIsGLL77I0UcfTUdHB6FQiJ07d+5X1en+/iKyGLgTT4jPBh4CjnWbXwbOA3YATwGXqeomEbkb+IWq3iUi/w48p6rfyTcvW8dGqXj66afT1vBwGbdZXQsWLGDdunWVnoYxDli4cCHbtm1rEpGnge+q6ipgpqrucrvsBma613OA7YHDd7ixXONpiMjVeJYi5s+fb2vYGDYrV67k2muv5bHHHmPChAl89rOffT1zF7wYxyhe0scWPBEEsEVVtwKIyF3AShF5Ea/m2Z+5fW4HbsJZNXNh12KjVIhI5hoeFubqMowC/OEPfwB4Ec+NdY2InBncrp7ZtCSmU1VdparLVXX59OkjfrAxapRt27bx7LPPctpppwHwzW9+E2Cx61/ou1iHKtKn4pUDiWeMG0ZVYcLHMAowZ453bVfVvXiZfyuAPc7F5Re33Ot23wnMCxw+143lGjeMktLV1cUHPvABvva1r9HW1sanPvUpXn31VYBNwC7gK6M9BxG5WkTWici6ffv2jfbHGcaQMOFjGHno7u6ms7MTABFpAc4HNgKrgSvcblcA97rXq4HLXY2o04FDziX2IHC+iEx2T9znuzHDKBmxWIwPfOADfOQjH+H9738/ADNnziQcDvu7fI8Bd9ZQRfoBYJKIRDLGB2GWS2MsY8LHMPKwZ88ezjjjDIDFwJPAr1T1AeBW4DwReQU4170HuA/YCmzBu8n8FYCqtgNfwAsUfQq42Y0ZRklQVa666ipOOOEEPvOZz6TGd+3aFdztEjzhDp5Iv1REGk
RkIV4m4pN463ORy+CqBy4FVjuX7m+BD7rjg4LfMKqGcRvcnItlN/+avlii0tMYMmERvvzBk3jvibMqPZWa4qijjuK5555DRDYFUtlR1QPAOZn7u5vDNdnOpaq3AbeNxjyffaODS779R/543buZPalpND7CGOM89thj3HHHHbzlLW9h2bJlAHzxi1/kzjvvZP369eCJ97OBTwKo6gsuS2sTEAeuUdUEgIhci2eRDAO3qapfwPVzwF0i8i/As8APyvTnDZvHXz3ArQ+8xE8/+TbqI/asb9Sg8PnIafOJJ6ovhf+7j25l855O3osJH2Mwd6z1kh0e27KfDy2fV2BvYzxyxhlnkK08yUUXeWV7nHh/X3Cbqt4C3JJ5jKreh2e9zBzfyoCrrCp44c1DPLf9IJ19MaZOKKontTHOKZvwEZHbgP8F7FXVpVm2n4VnNn3NDf1CVW922y7Eq4YbBr6vqrdmHl8sn73g+OEeWlFW/X5r1ouaYQAkk97aCIekwjMxjLFFwn034km7fhoe5bT7/RC4sMA+v3fVcZcFRE8Y+BZeKvFi4DJXeKumCIlgusfIhW/ENOFjGOkk3IUzlkhWeCbGWKFswkdVHwWGE8y5AldMS1X7gbvwCm/VFCGBpCkfIwe+xSckJnwMI0jCPRUkzOJjOMZapNfbROQ5EblfRJa4saIq3o53RAT73hq5SJiryzCy4ru4YlUY22mMDmMpuPkZ4EhV7XJN9P4bL72yaDLL/Y8nBCzGx8iJbw00i49hpON/N+JJc3UZHmPG4qOqh1W1y72+D6gTkWkMoeLteC6aFRIxV5eRkwHhU+GJGMYYw7f4VGM2rzE6jBnhIyJHiHiPqyKyAm9uB8hRTKtyM60MXoxPpWdhjFXM1WUY2UlaVpeRQTnT2e8EzgKmicgO4EagDkBV/x2vGuinRCQO9AKXumJw8TzFtGoGy+oy8uE/zIZM+BhGGgMWH3N1GR5lEz6qelmB7d8EvpljW9ZiWrWEWFaXkYdUHR+L8TGMNKyOj5HJmHF1GfkREQtuNnLii2LTPYaRTsJifIwMTPhUCRbjY+TDv7ibNjaMdFLp7JbVZThM+FQJltVl5MNfG7ZGDCMd3w2cMIuP4TDhUyWICPa1NXKRsvhUeB6GMdZIBTebxcdwmPCpEkJiBQyN3PgPs7ZGDCOdhBM8VrnZ8DHhUyWIgD2wGLnwzfm2RgwjHV/vmMXH8DHhUyVYjI+RD4vxMYzs+BYfy+oyfEz4VAkha1Jq5MGP8bE1YhjpWB0fIxMTPlWCCKiFrho5SKQu6rZGDCNIwio3GxmY8KkSrGWFkY8BV1eFJ2JUjO3bt3P22WezePFilixZwte//nUA2tvbOe+88wCWisgaEZkMIB7fEJEtIrJBRE7xzyUiV4jIK+7nisD4qSLyvDvmG35/xbFM3Cw+RgYmfKqEkLWsMPLgX9NtjdQukUiEr3zlK2zatIm1a9fyrW99i02bNnHrrbdyzjnnAGwEHgauc4e8B1jkfq4GvgMgIlPweimeBqwAbvTFktvnLwLHXViev274WOVmIxMTPlWCWIyPkQez+BizZs3ilFM8o01raysnnHACO3fu5N577+WKK1JGm9uBi93rlcCP1GMtMElEZgEXAGtUtV1VO4A1wIVuW5uqrnUNpH8UONeYJWGVm40MyiZ8ROQ2EdkrIhtzbP+IM7c+LyJ/FJGTAtu2ufH1IrKuXHMeS1iTUiMfAy0rbI0YsG3bNp599llOO+009uzZw6xZs/xNu4GZ7vUcYHvgsB1uLN/4jizjgxCRq0VknYis27dv34j/npGQsMrNRgbltPj8kPxm0deAd6nqW4AvAKsytp+tqstUdfkozW9ME/Kimw0jK9ary/Dp6uriAx/4AF/72tdoa2tL2+YsNaO+SlR1laouV9Xl06dPH+2Py8uAxce+HIZH2YSPqj4KtOfZ/kdnVgVYC8wty8SqBIvxMXLRH08G0tltjdQysViMD3zgA3zkIx
/h/e9/PwAzZ85k165dADh31V63+05gXuDwuW4s3/jcLONjmrhldRkZjNUYn6uA+wPvFfi1iDwtIlfnOmgsmVdLjRU
2021-06-07 19:58:49 +02:00
"text/plain": [
"<Figure size 648x1080 with 15 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
"import tensorflow as tf\n",
"\n",
"# One-hot encode the letter labels (one column per distinct label).\n",
"lb = LabelBinarizer()\n",
"yt_filter = lb.fit_transform(y_filter)\n",
"\n",
"# 80/20 split; the fixed random_state makes it reproducible.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X_filter, yt_filter, test_size=0.2, random_state=177013)\n",
"\n",
"for part in (X_train, X_test, y_train, y_test):\n",
"    print(part.shape)\n",
"\n",
"plot_data(X_filter[0].T)"
]
},
{
"cell_type": "markdown",
2021-07-14 10:15:52 +02:00
"id": "bb6724f3",
"metadata": {},
"source": [
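"Disabled snippet, kept for reference (copy it into a code cell to run it). It plots channel 13 of the first tenth of the training samples on a 13 x 2 grid keyed by letter, then saves the figure to `./all_forces.png`. It assumes channel 13 is the force signal, so verify this against the column order.\n",
"\n",
"```python\n",
"fig, axs = plt.subplots(13,2,figsize=(20, 60), sharey=True)\n",
"data_count = int(len(X_train)/10)\n",
"for i,j in zip(X_train[:data_count], lb.inverse_transform(y_train)[:data_count]):\n",
"    num = ord(j) - 64\n",
"    f = i.T[13]\n",
"    r = int((num-1)/2)%13\n",
"    c = (num-1)%2\n",
"    axs[r][c].title.set_text(f'{j}')\n",
"    axs[r][c].plot(f)\n",
"plt.savefig('./all_forces.png')\n",
"```"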
]
},
2021-06-07 19:58:49 +02:00
{
"cell_type": "code",
"execution_count": 16,
2021-07-14 10:15:52 +02:00
"id": "5493e919",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"# GPU selection. These variables only take effect if set before TensorFlow\n",
"# initialises the GPU, i.e. before the first model is built below.\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'  # allocate GPU memory on demand instead of all up front\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # '0' = GPU0, '1' = GPU1, '2' = GPU2; check \"gpustat\" in a terminal."
]
},
{
"cell_type": "code",
"execution_count": 17,
2021-07-14 10:15:52 +02:00
"id": "fad18a1d",
"metadata": {},
"outputs": [],
"source": [
"# (model, test accuracy) pairs appended by the training loop below.\n",
"accs = list()"
]
},
{
"cell_type": "code",
"execution_count": 18,
2021-07-14 10:15:52 +02:00
"id": "488c40fc",
"metadata": {},
"outputs": [],
2021-06-07 19:58:49 +02:00
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"def build_model(input_shape=None, n_classes=52, dropout_rate=0.2):\n",
"    \"\"\"Build and compile the fully-connected letter classifier.\n",
"\n",
"    Layers: BatchNormalization -> Flatten -> DENSE_COUNT x Dense(DENSE_NEURONS)\n",
"    -> Dropout -> DENSE2_COUNT x Dense(DENSE2_NEURONS) -> Dropout -> softmax.\n",
"\n",
"    Args:\n",
"        input_shape: shape of one sample; defaults to X_filter[0].shape.\n",
"        n_classes: width of the softmax output; 52 matches the one-hot label\n",
"            width produced by the LabelBinarizer in this notebook.\n",
"        dropout_rate: rate of the Dropout layer after each Dense stack.\n",
"\n",
"    Returns:\n",
"        A compiled Sequential model (Adam lr=1e-4, categorical cross-entropy,\n",
"        accuracy metric).\n",
"    \"\"\"\n",
"    if input_shape is None:\n",
"        input_shape = X_filter[0].shape\n",
"\n",
"    model = Sequential()\n",
"    model.add(BatchNormalization(input_shape=input_shape))\n",
"    model.add(Flatten())\n",
"\n",
"    for _ in range(DENSE_COUNT):\n",
"        model.add(Dense(DENSE_NEURONS, activation='relu'))\n",
"    # A layer must be passed to model.add(); a bare Dropout(...) call is a no-op.\n",
"    model.add(Dropout(dropout_rate))\n",
"\n",
"    for _ in range(DENSE2_COUNT):\n",
"        model.add(Dense(DENSE2_NEURONS, activation='relu'))\n",
"    model.add(Dropout(dropout_rate))\n",
"\n",
"    model.add(Dense(n_classes, activation='softmax'))\n",
"\n",
"    model.compile(\n",
"        optimizer=tf.keras.optimizers.Adam(0.0001),\n",
"        loss=\"categorical_crossentropy\",\n",
"        metrics=[\"acc\"],\n",
"    )\n",
"    return model"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 19,
2021-07-14 10:15:52 +02:00
"id": "feeafbd0",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stderr",
2021-06-07 19:58:49 +02:00
"output_type": "stream",
"text": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"162/162 [==============================] - 2s 5ms/step - loss: 3.1562 - acc: 0.1829 - val_loss: 2.5782 - val_acc: 0.3098\n",
"Epoch 2/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 2.0657 - acc: 0.4226 - val_loss: 1.9569 - val_acc: 0.4579\n",
"Epoch 3/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 1.5548 - acc: 0.5398 - val_loss: 1.7179 - val_acc: 0.5149\n",
"Epoch 4/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 1.2320 - acc: 0.6281 - val_loss: 1.5439 - val_acc: 0.5484\n",
"Epoch 5/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.9888 - acc: 0.6937 - val_loss: 1.4862 - val_acc: 0.5700\n",
"Epoch 6/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.7951 - acc: 0.7524 - val_loss: 1.4870 - val_acc: 0.5689\n",
"Epoch 7/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.6452 - acc: 0.7965 - val_loss: 1.3981 - val_acc: 0.6030\n",
"Epoch 8/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.5287 - acc: 0.8313 - val_loss: 1.3804 - val_acc: 0.6209\n",
"Epoch 9/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.4338 - acc: 0.8627 - val_loss: 1.4227 - val_acc: 0.6073\n",
"Epoch 10/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.3522 - acc: 0.8852 - val_loss: 1.4271 - val_acc: 0.6277\n",
"Epoch 11/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.3016 - acc: 0.9029 - val_loss: 1.4799 - val_acc: 0.6223\n",
"Epoch 12/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.2579 - acc: 0.9169 - val_loss: 1.5139 - val_acc: 0.6275\n",
"Epoch 13/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.2041 - acc: 0.9367 - val_loss: 1.5174 - val_acc: 0.6348\n",
"Epoch 14/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1761 - acc: 0.9452 - val_loss: 1.5331 - val_acc: 0.6437\n",
"Epoch 15/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1571 - acc: 0.9501 - val_loss: 1.6265 - val_acc: 0.6404\n",
"Epoch 16/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1334 - acc: 0.9580 - val_loss: 1.6154 - val_acc: 0.6476\n",
"Epoch 17/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1110 - acc: 0.9676 - val_loss: 1.7053 - val_acc: 0.6427\n",
"Epoch 18/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1231 - acc: 0.9628 - val_loss: 1.8015 - val_acc: 0.6117\n",
"Epoch 19/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1297 - acc: 0.9598 - val_loss: 1.7488 - val_acc: 0.6360\n",
"Epoch 20/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1451 - acc: 0.9543 - val_loss: 1.7453 - val_acc: 0.6285\n",
"Epoch 21/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.1098 - acc: 0.9681 - val_loss: 1.7955 - val_acc: 0.6418\n",
"Epoch 22/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0817 - acc: 0.9765 - val_loss: 1.7614 - val_acc: 0.6518\n",
"Epoch 23/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0729 - acc: 0.9790 - val_loss: 1.8825 - val_acc: 0.6337\n",
"Epoch 24/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0927 - acc: 0.9695 - val_loss: 1.8203 - val_acc: 0.6402\n",
"Epoch 25/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0796 - acc: 0.9764 - val_loss: 1.9148 - val_acc: 0.6360\n",
"Epoch 26/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0609 - acc: 0.9815 - val_loss: 1.9998 - val_acc: 0.6337\n",
"Epoch 27/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0589 - acc: 0.9830 - val_loss: 1.9563 - val_acc: 0.6439\n",
"Epoch 28/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0562 - acc: 0.9840 - val_loss: 1.9897 - val_acc: 0.6410\n",
"Epoch 29/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0861 - acc: 0.9730 - val_loss: 2.0437 - val_acc: 0.6329\n",
"Epoch 30/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0672 - acc: 0.9783 - val_loss: 2.0066 - val_acc: 0.6503\n",
"Epoch 31/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0713 - acc: 0.9789 - val_loss: 1.9843 - val_acc: 0.6453\n",
"Epoch 32/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0718 - acc: 0.9777 - val_loss: 1.9756 - val_acc: 0.6424\n",
"Epoch 33/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0501 - acc: 0.9847 - val_loss: 2.0316 - val_acc: 0.6472\n",
"Epoch 34/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0588 - acc: 0.9822 - val_loss: 1.9967 - val_acc: 0.6368\n",
"Epoch 35/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0422 - acc: 0.9885 - val_loss: 2.1204 - val_acc: 0.6522\n",
"Epoch 36/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0392 - acc: 0.9887 - val_loss: 2.1788 - val_acc: 0.6356\n",
"Epoch 37/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0835 - acc: 0.9755 - val_loss: 2.1036 - val_acc: 0.6321\n",
"Epoch 38/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0733 - acc: 0.9771 - val_loss: 2.1187 - val_acc: 0.6443\n",
"Epoch 39/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0362 - acc: 0.9899 - val_loss: 2.1870 - val_acc: 0.6451\n",
"Epoch 40/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0526 - acc: 0.9834 - val_loss: 2.0849 - val_acc: 0.6512\n",
"Epoch 41/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0462 - acc: 0.9865 - val_loss: 2.2498 - val_acc: 0.6366\n",
"Epoch 42/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0399 - acc: 0.9876 - val_loss: 2.1870 - val_acc: 0.6501\n",
"Epoch 43/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0492 - acc: 0.9850 - val_loss: 2.0921 - val_acc: 0.6620\n",
"Epoch 44/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0262 - acc: 0.9932 - val_loss: 2.1486 - val_acc: 0.6586\n",
"Epoch 45/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0337 - acc: 0.9908 - val_loss: 2.1384 - val_acc: 0.6570\n",
"Epoch 46/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0399 - acc: 0.9878 - val_loss: 2.1853 - val_acc: 0.6393\n",
"Epoch 47/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0487 - acc: 0.9860 - val_loss: 2.1394 - val_acc: 0.6607\n",
"Epoch 48/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0483 - acc: 0.9855 - val_loss: 2.1069 - val_acc: 0.6429\n",
"Epoch 49/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0494 - acc: 0.9855 - val_loss: 2.2427 - val_acc: 0.6435\n",
"Epoch 50/50\n",
"162/162 [==============================] - 1s 4ms/step - loss: 0.0567 - acc: 0.9832 - val_loss: 2.1715 - val_acc: 0.6375\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:36<00:00, 36.07s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test loss, test acc: [2.1715357303619385, 0.6375385522842407]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
2021-06-07 19:58:49 +02:00
]
}
],
"source": [
"for _run in tqdm(range(AVG_FROM)):\n",
"    model = build_model()\n",
"    # NOTE(review): the test split doubles as validation data, so the per-epoch\n",
"    # val_* metrics are computed on the same samples as the final test score.\n",
"    model.fit(\n",
"        X_train, y_train,\n",
"        epochs=EPOCH,\n",
"        batch_size=128,\n",
"        shuffle=True,\n",
"        validation_data=(X_test, y_test),\n",
"        verbose=1,\n",
"    )\n",
"    results = model.evaluate(X_test, y_test, batch_size=128, verbose=0)\n",
"    print(\"test loss, test acc:\", results)\n",
"    accs.append((model, results[1]))"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 20,
2021-07-14 10:15:52 +02:00
"id": "e81143f4",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6375385522842407"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
2021-06-07 19:58:49 +02:00
"source": [
"# Mean test accuracy over all training runs.\n",
"np.mean([acc for _model, acc in accs])"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 21,
2021-07-14 10:15:52 +02:00
"id": "2451e675",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Ends the kernel session (presumably to free the GPU for other users);\n",
"# remove this cell to keep the session alive for inspection.\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
2021-07-14 10:15:52 +02:00
"id": "d452d294",
"metadata": {},
"outputs": [],
"source": []
2021-06-07 19:58:49 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-07-14 10:15:52 +02:00
"version": "3.8.10"
2021-06-07 19:58:49 +02:00
}
},
"nbformat": 4,
"nbformat_minor": 5
}