iui-group-l-name-zensiert/1-first-project/ies/TestNetrwork.ipynb

548 lines
27 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "8c784b5a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"import pandas as pd\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization\n",
"\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
"\n",
"\n",
"\n",
"delim = ';'\n",
"user_count = 100\n",
"base_path = '/opt/iui-datarelease1-sose2021/'\n",
"\n",
"Xpickle_file = './X.pickle'\n",
"\n",
"ypickle_file = './y.pickle'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7b486d61",
"metadata": {},
"outputs": [],
"source": [
"def load_pickles():\n",
" _p = open(Xpickle_file, 'rb')\n",
" X = pickle.load(_p)\n",
" _p.close()\n",
" \n",
" _p = open(ypickle_file, 'rb')\n",
" y = pickle.load(_p)\n",
" _p.close()\n",
" \n",
" return (np.asarray(X, dtype=pd.DataFrame), np.asarray(y, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5ea384ea",
"metadata": {},
"outputs": [],
"source": [
"def shorten(npList):\n",
" temp = npList['Force']\n",
" thresh = 100\n",
" leeway = 5\n",
" \n",
" temps_over_T = np.where(temp > thresh)[0]\n",
" return npList[max(temps_over_T[0]-leeway,0):temps_over_T[-1]+leeway]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "09aad3f2",
"metadata": {},
"outputs": [],
"source": [
"def load_data():\n",
" if os.path.isfile(Xpickle_file) and os.path.isfile(ypickle_file):\n",
" return load_pickles()\n",
" data = []\n",
" label = []\n",
" for user in range(0, user_count):\n",
" user_path = base_path + str(user) + '/split_letters_csv/'\n",
" for file in os.listdir(user_path):\n",
" file_name = user_path + file\n",
" letter = ''.join(filter(lambda x: x.isalpha(), file))[0]\n",
" data.append(pd.read_csv(file_name, delim))\n",
" label.append(letter)\n",
" return (np.asarray(data, dtype=pd.DataFrame), np.asarray(label, dtype=str), np.asarray(file_name))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "37d66d26",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.76 s, sys: 205 ms, total: 2.97 s\n",
"Wall time: 2.97 s\n"
]
}
],
"source": [
"%%time\n",
"x, y = load_data()\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3178395b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.22 s, sys: 2.07 ms, total: 3.22 s\n",
"Wall time: 3.22 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"<timed exec>:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n"
]
}
],
"source": [
"%%time\n",
"f_data = np.array(list(map(shorten, x)))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dcbb85b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 13102.000000\n",
"mean 61.169058\n",
"std 30.698514\n",
"min 10.000000\n",
"50% 57.000000\n",
"95% 102.000000\n",
"96% 107.000000\n",
"97% 113.000000\n",
"98% 127.000000\n",
"99% 156.000000\n",
"max 1522.000000\n",
"dtype: float64"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD4CAYAAAD4k815AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAf/ElEQVR4nO3deXTU9b3/8ec7O0kIIQmRJWCCgBUQBSNgq9VWWtFaqT+thap1Q2xvvV3scu1tr23t7T1Ve7rdUhXRWrciWrXUUnEp1dsqS0BBCIthkyCQkIRAErLMzOf3xww4xoQMMMl3ltfjnJyZ75Lkdb7JvPLNZ76LOecQEZHEkeJ1ABERiS4Vu4hIglGxi4gkGBW7iEiCUbGLiCSYNK++cVFRkSstLfXq24uIxKVVq1btc84NOto6nhV7aWkpFRUVXn17EZG4ZGY7elpHQzEiIglGxS4ikmBU7CIiCabHYjezh8ysxszWdbPczOw3ZlZlZmvNbFL0Y4qISKQi2WN/GJh+lOUXA6NDH3OAe088loiIHK8ei9059xpQf5RVZgCPuKBlQL6ZDYlWQBEROTbRGGMfBuwMm64OzfsQM5tjZhVmVlFbWxuFby0iIp316XHszrl5wDyA8vJyXS9YROKSP+Bo8/lp6wjQ2vmxw0+bL0Cb7/3nnR8v/EgxZwzP77V80Sj2XcDwsOmS0DwRkV7l8wdo9QVo6/C//9gRoM334ce2ztNHKd73Hz/4uYcLvMN/Yvulxf0zY77YFwG3mtkCYArQ6JzbHYWvKyIJwDlHU5uP/S0dNB56/+Pw9MHWjtAe7uHS/XDBtnVTvP7A8ResGWSlpZKVnkJmF4+5mWkU5qSSmZ5CVlp3jylkpr//OZlpKWSld/2YefgxLQUzi+IW/rAei93M/ghcABSZWTXwQyAdwDl3H7AYuASoAlqAG3orrIh4p7XDz4FDHezvVMz7W9q7nB/+cbQCTkuxbkswKz2F/H7pZOVlHr04w8r2g5/f/WNaivV6wXqlx2J3zs3qYbkDvhq1RCLSZwIBx679h6iqbWJLTRN7GluPFHTj4eI+1E7joQ5aOwLdfh0zyMtKJz87nQH9gh/DC7IZ0C+N/H4ZwXmhZfmh54fnZ6X3/h5ssvHsImAi0nfafQG21zVTVRMs8KraJqpqmtha28yhDv+R9bLSUxiYnXGknE8uzOaM7AHBQg6bn3+kpIPz+melkZKico4VKnaRBBIIOLbXNbP+vQNs2H2Ad0JFvqO+5QPDIcPy+3FKcS5TygoZVZx75KMgJ8PD9BItKnaRONXuC7B570Eq3zvA+vcaj5R5c3twDzwtxSgtymHMSf255PQhR8p75KAcsjP00k9k+umKxImG5nZe3rCXFdvqWf/eAd6pOXjksLucjFROG5LHlWeVMG7oAMYOzWPMSf3JSNN1/pKRil0khu1pbOXFyj0sWb+HZVvr8QccBTkZjBuax8fHjGTc0DzGDc2jtDBHY9xyhIpdJIa0+fy8Xd3I8m31vLxhL2++ux+AUwbl8OXzRzJ93BDGD8vTUSRyVCp2EQ81tflYvaOBldvrWbGtnrd27qfNFzyscPywPL5z0alcNO4kRhX39zipxBMVu0gf8vkDrNzewCsb9rJie3Cs3B9wpKYY44bmcc3Ukzm7tICzSwdSmJvpdVyJUyp2kV7W5vPzr6p9vLBuDy9vqKG+uZ2MtBQmDs/n3y44hbNLC5h08kByM/VylOjQb5JIL3DO8fqWOhas3MnSjTU0tfnon5nGJ08rZvq4wZx/6iAdcii9Rr9ZIlEUCDhe2VjDb5dWsWbnfgpyMvjsGUO4aNxgPnpKkQ4/lD6hYheJAp8/wF/f3s3vlm5h096DDC/ox08vH88Vk0rISk/1Op4kGRW7yAlo9wV4ZnU19766hR11LYwqzuUXV53BZWcMJS1Ve+fiDRW7yHF4t66Fxet284fXt7O7sZXxw/K49+pJXDRusE4UEs+p2EUi4Jxj894mXlgXPAu0cvcBACaXFvCzKybw8dFFOmlIYoaKXeQo2nx+5v69ir+s3c22fc2YwVkjBvKDz5zGReMGM7wg2+uIIh+iYhfpRmuHn1seXcWrm2s5b3QRs88r41NjT6K4f5bX0USOSsUu0oUddc38x5/WsnxbPT/7f6czc/IIryOJREzFLhJm056D3PuPKhateY+01BR+edWZfG7iMK9jiRwTFbsIsLvxED/883perNxLdkYqs88byexzyyjO07CLxB8VuyS9HXXNfPGB5exvaedrF47mho+WMlC3iJM4pmKXpFZV08TV85fR5guwYM45nF4ywOtIIidMxS5Jq/K9A1z74HLMjCfnnMOpg3XNc0kMKnZJOs45XtlQw7eeWkN2RiqPz57CyEG5XscSiRoVuyQNf8Dxwro9zF1aReXuA4wsyuEPN07WSUaScFTskhRe2bCXny7ewNbaZkYW5XDPlRP43MRhpOtCXZKAVOyS8J5eVc13n17DqOJc5n5xEtPHDyZVF+qSBKZil4T22LId/OC5dZw7qoh5XzpLdy2SpKDfcklYD7y2lZ8u3sC004r57Rcn6YYXkjRU7JKQDpf6Z04fwq9mnqmxdEkqKnZJSI8u28Hk0gJ+PfNM3clIkk5Ev/FmNt3MNplZlZnd3sXyEWa21MzeNLO1ZnZJ9KOKRMY5hz/gKBnYT6UuSanH33ozSwXmAhcDY4FZZja202o/ABY65yYCM4HfRTuoSCScc/zshY3s2n+IUSfppCNJTpHszkwGqpxzW51z7cACYEandRyQF3o+AHgvehFFIhMIOH60aD33v7qVa6aO4MsfP8XrSCKeiGSMfRiwM2y6GpjSaZ0fAS+a2b8DOcC0rr6Qmc0B5gCMGKEbF0j0+AOO7z2zloUV1dx8Xhn/eclpugepJK1oDUDOAh52zpUAlwCPmtmHvrZzbp5zrtw5Vz5o0KAofWtJdh3+AN988i0WVlTz9QtHq9Ql6UWyx74LGB42XRKaF+4mYDqAc+4NM8sCioCaaIQU6U6bz8+tT7zJS5V7uf3ij/Dl8zX8IhLJHvtKYLSZlZlZBsE3Rxd1Wudd4EIAMzsNyAJqoxlUpLND7X5ufmQVL1Xu5c4Z41TqIiE97rE753xmdiuwBEgFHnLOrTezO4EK59wi4FvAA2b2TYJvpF7vnHO9GVySW1ObjxsfXsnK7fXcfcUErjp7eM+fJJIkIjpByTm3GFjcad4dYc8rgY9FN5pI19p9Aa6Zv5y3dzXyqy+cyYwzdbNpkXA6e0PizsY9B3hr537uuHSsSl2kCyp2iTsd/uAo3/CCfh4nEYlNKnaJK3VNbfzXc+vISE2htDDH6zgiMUkXAZO4sfdAK1fPX87O+hYeuK5c9ykV6YaKXeJCdUMLV89fzr6DbfzhxslMHVnodSSRmKVil5hXc6CVq+57g6Y2H4/NnsLEEQO9jiQS01TsEvNee2cf7zW2smDOVJW6SAT05qnEPJ8/AMCwfB0FIxIJFbvEtE17DvLzFzdTlJtBUW6m13FE4oKKXWLWul2NzJz3BikGC+ZMpV+GbkYtEgkVu8SkVTsamPXAMrIz0lh4yzmMKu7vdSSRuKE3TyXmrNm5n2sfXE5x/0wev3mqxtZFjpGKXWLOX9/eTYc/wMJbzqE4L8vrOCJxR0MxEnM6/AHSU1NU6iLHScUuMWXpxhqeWP4upw3J63llEemSil1ixgvrdjPn0QpGn5TLA18q9zqOSNzSGLvEhGffrObbT61lQskAHr5hMgP6pXsdSSRuqdjFc4vWvMdtC9cwtayQ+deVk5OpX0uRE6FXkHjumdXVDB+Yze9vOJusdJ2EJHKiNMYunuvwBxiYk6FSF4kSFbt46qF/buNfVXWM1VEwIlGjYhfP/OH17dz5fCUXjx/Mjy8b53UckYShMXbxzFOrdjKhZAD/O2siaanaxxCJFr2axDM+v2NQbqZKXSTK9IqSPuec4+4XNrJxz0HGDdXYuki0aShG+pRzjjufr+T3/9rOrMkj+Ma0MV5HEkk4KnbpM/6A4/vPvs2ClTu58WNl/Nelp2FmXscSSTgqdukzP3guWOq3fmIU3/r0GJW6SC9RsUuf+dPqXVw+cRjfvuhUr6OIJDS9eSp9wucP4JzjJF1jXaTXqdil17X7Anz1idV0+B3jh+koGJHeFlGxm9l0M9tkZlVmdns361xlZpVmtt7MnohuTIlXzjm++sRqlqzfyx2XjuXSCUO9jiSS8HocYzezVGAu8CmgGlhpZoucc5Vh64wGvgd8zDnXYGbFvRVY4ktTm4+XKvdy83ll3HhumddxRJJCJHvsk4Eq59xW51w7sACY0Wmdm4G5zrkGAOdcTXRjSrxyoUeNrYv0nUiKfRiwM2y6OjQv3BhgjJn9y8yWmdn0rr6Qmc0xswozq6itrT2+xBJXVm1vAGBofj+Pk4gkj2i9eZoGjAYuAGYBD5hZfueVnHPznHPlzrnyQYMGRelbSyz7/evbKe6fybTTTvI6ikjSiKTYdwHDw6ZLQvPCVQOLnHMdzrltwGaCRS9JbOmmGl7bXMsNHysjI00HYIn0lUhebSuB0WZWZmYZwExgUad1niO4t46ZFREcmtkavZgSb9p9AX745/WMHJTDjeeWeh1HJKn0WOzOOR9wK7AE2AAsdM6tN7M7zeyy0GpLgDozqwSWAt9xztX1VmiJfbsbD/FufQuzzx1JZppueSfSlyK6pIBzbjGwuNO8O8KeO+C20IcI7b4AAFnpGoIR6Wt61UnU1TW18fUFb5GRmsJpupepSJ/TRcAkqg61+/nCvGXsrG/hgevKVewiHlCxS1RV7j5AVU0TP//8GZw/Roe0inhBQzESVYfH1gf1z/Q4iUjyUrFL1FQ3tHD7M2vJyUhlVHGu13FEkpaGYiQq6prauOq+N2hq8/HY7CkM0yUERDyjYpeoWLWjgfcaW3n4hrOZOGKg13FEkpqGYiQq2kJj60W5GlsX8ZqKXU7Ypj0H+fFfKinMyWB4QbbXcUSSnopdTsi7dS3MnPcGKQZP3jKVAf3SvY4kkvQ0xi4n5PUt+2ho6eAvt57LqOL+XscREbTHLifoyNh6/wyPk4jIYSp2OW6rdtTz8yWbGFGQTWGO3jQViRUqdjku63Y1cu2DKyjqn8mCOVN1Iw2RGKIxdjkur26upaXdzxM3T2HIAJ2MJBJLtJslx6Wtww+gIRiRGKRil2O2dGMN9722ldOHDSA91byOIyKdqNjlmLy+ZR9zHq3g1JP688iNkzFTsYvEGo2xyzF5qXIvqSnG4zdPIS9LJyOJxCLtscsxae0IkJ6aolIXiWEqdonYM6ureXLlu5x1sq7eKBLLVOwSkcVv7+ZbT61h6shCfnf1JK/jiMhRaIxdIrJk/R6KcjN56PqzyUpP9TqOiByF9tglIq0dfnIyUlXqInFAxS49mv9/W1myfi/lpQVeRxGRCKjY5ageXbaD//7rBi45fTD/c/npXscRkQhojF2Oasm6PYwqzuU3MyeSlqr9AJF4oFeqdMs5x6EOPwP6pavUReKIXq3SJeccd72wiVU7Gigv1XHrIvFExS5d+uXL73Dfq1u4esoI/uOij3gdR0SOgcbYpUsvrt/D5NIC/vtz43WhL5E4E9Eeu5lNN7NNZlZlZrcfZb0rzMyZWXn0IkpfCwQcbb4A+dnpKnWRONRjsZtZKjAXuBgYC8wys7FdrNcf+DqwPNohpe/4A47bn1nLtn3NTC7Tcesi8SiSPfbJQJVzbqtzrh1YAMzoYr2fAHcBrVHMJ33sjj+vY2FFNV+7cDQ3nVvmdRwROQ6RFPswYGfYdHVo3hFmNgkY7pz769G+kJnNMbMKM6uora095rDS+16q3MvF4wdz26fGaBhGJE6d8FExZpYC/AL4Vk/rOufmOefKnXPlgwYNOtFvLb0kP1vXWheJZ5EU+y5geNh0SWjeYf2B8cA/zGw7MBVYpDdQRUS8EUmxrwRGm1mZmWUAM4FFhxc65xqdc0XOuVLnXCmwDLjMOVfRK4ml13T4A3T4A17HEJET1GOxO+d8wK3AEmADsNA5t97M7jSzy3o7oPSN1g4/X3lsNQ0tHUwdWeh1HBE5ARGdoOScWwws7jTvjm7WveDEY0lf+9of3+TlDXv5yYxxzDhzWM+fICIxS5cUEABe2VjDNVNHcO05pV5HEZETpGIXDrX7CThHfr8Mr6OISBSo2JPcwdYOrntoBQacrTNNRRKCLgKWxAIBxw2/X8lbO/fz65kTOX+Mzi0QSQTaY09iB1o7qNjRwK2fHMVnzxjqdRwRiRIVexJrbvcDMKCfzjQVSSQq9iS1r6mNmx5eSUZqCmedrDskiSQSjbEnoTafn5nzllHd0MKD15czoSTf60giEkXaY09C1Q2HqKpp4vufGct5o/WGqUiiUbEnoZa24Nh6Xpb+YRNJRCr2JLOzvoV/e2IVuZlpnKEhGJGEpF22JNJ4qIOr7n+DlnY/j8+eQmlRjteRRKQXqNiTyKY9B9nd2Mq9V0/ijOH5XscRkV6ioZgk0tzmAyBPx62LJDQVe5LYsPsA33l6DYU5GZw6uL/XcUSkF6nYk8Cu/YeYOW8ZaSkpPHnLORTlZnodSUR6kcbYk8DanftpPNTB/OvKGVWc63UcEell2mNPAgdDY+v9ddy6SFJQsSe4ldvrufMvlQwv6MfJBTq8USQZqNgT2MY9B/jSgysozstk4S3n0C8j1etIItIHVOwJbOX2Bg51+HnwurMZMqCf13FEpI+o2BNYU2twbD03U2PrIslExZ6gXtmwl1++vJlxQ/MoyNFNqkWSiYo9Aa3YVs8tj67iI4P78/jsKaSmmNeRRKQP6X/0BPTGljp8AcejN03Rbe9EkpD22BPQwdYOAPprbF0kKanYE8zTq6p56F/bOHdUEaYRGJGkpGJPIC+s28O3n1rDR08pYt6XzsLU7CJJSf+rJ5DXt+wjNzON+deVk5Wuk5FEkpX22BPIwVYf6ammUhdJcir2BDHvtS08++YuLji12OsoIuKxiIrdzKab2SYzqzKz27tYfpuZVZrZWjN7xcxOjn5U6Ypzjl+//A7/s3gjn5kwhLuvnOB1JBHxWI/FbmapwFzgYmAsMMvMxnZa7U2g3Dk3AXgauDvaQeXDnHP87IWN/PLlzVwxqYTfzJxIeqr+CRNJdpG0wGSgyjm31TnXDiwAZoSv4Jxb6pxrCU0uA0qiG1O6MndpFfe/upVrp57MPVdO0BmmIgJEVuzDgJ1h09Whed25CfhbVwvMbI6ZVZhZRW1tbeQppUv/rNrH+GF53DljHCkqdREJier/7WZ2DVAO3NPVcufcPOdcuXOufNCgQdH81kknEHAcbPWRnZGm49VF5AMiKfZdwPCw6ZLQvA8ws2nA94HLnHNt0YknXfEHHN/901rWv3eAT35ER8GIyAdFUuwrgdFmVmZmGcBMYFH4CmY2EbifYKnXRD+mhLvzL+t5elU135g2mls+PtLrOCISY3osduecD7gVWAJsABY659ab2Z1mdllotXuAXOApM3vLzBZ18+UkCl7fUsf5YwbxjWljNAwjIh8S0SUFnHOLgcWd5t0R9nxalHNJNzr8AVra/eRk6uxSEemaDnqOI60dfr7y2Cp27T/EJ3SGqYh0QxcBixOH2v3MebSC/3tnHz+ZMY7Plw/v+ZNEJCmp2OPAwdYObnq4good9dxz5QSVuogclYo9Dtz8SAWr323g1zMn8tkzhnodR0RinMbY48CyrfXceG6ZSl1EIqJij3H7W9oB6KdrrItIhFTsMWxfUxsz5y0jIzWFj48p8jqOiMQJjbHHqD2NrVw9fxm79h/iwevLOevkAq8jiUicULHHoJ31LVw9fzn1ze08cuMUJpep1EUkcir2GNPS7uML979BU5uPx2ZP4czh+V5HEpE4ozH2GLOz/hDvNbZyx2fHqdRF5Lio2GNMQ+gomOwMHQUjIsdHxR5DNu45wK1PrKYwJ4NJIwZ6HUdE4pSKPUasrd7PzHnLSE0xnrzlHAYPyPI6kojEKb15GgNW7ajn+odWMiA7nSdmT2VEYbbXkUQkjqnYPfZ61T5u+kMFgwdk8fjsKQzN7+d1JBGJcyp2Dy3dWMMtj62irDCHR2dPpri/hl9E5MSp2D3yt7d387UFb/KRwXk8cuNkBuZkeB1JRBKE3jz1wLNvVvPVJ1YzoSSfx2+eolIXkajSHnsf++OKd/nPZ9/mnJGFzL+unOwM/QhEJLrUKn3owX9u4yfPV/KJUwdx7zVnkaVL8YpIL1Cx95Hf/v0dfv7iZqaPG8xvZk0kI02jYCLSO1Tsvcw5x89f3MTcpVu4fOIw7rlyAmmpKnUR6T0q9l5UVdPEr17ezPNrdzNr8gh++rnxpKSY17FEJMGp2HvBul2NzF1axQvr95CZlsI3po3m6xeOxkylLiK9T8UeRa0dfr791BqeX7ub/plpfPWCUdzwsVIKczO9jiYiSUTFHiXNbT5ufqSCN7bW8bULRzP7vDLystK9jiUiSUjFfoKcc7z2zj7ufmEjG3Yf4BdXncHlE0u8jiUiSUzFfpwCAceS9XuY+48q1u06wJABWdx3zVl8etxgr6OJSJJTsR+DvQdaWbGtnpXb63ltcy3b61ooLczmritO5/KJJTo2XURigoq9G845dtS1sGJbPSu2B8t8R10LELxt3aQRA7nt06fymdOHkKpDGEUkhkRU7GY2Hfg1kArMd879rNPyTOAR4CygDviCc257dKOeOOccTW0+Gpo7qGtuo6GlnfrmDuqb2z70uLPhELUH2wAYmJ1OeWkB1049mbNLCxg7NI90nWQkIjGqx2I3s1RgLvApoBpYaWaLnHOVYavdBDQ450aZ2UzgLuALvRE4XIc/QENzO/Ut7dQ3BR8bmtupaw57bGmnrin42NDcQbs/0OXXSk81CnIyGJidQWFuBueNKmLSyQOZUlbAKYNydWKRiMSNSPbYJwNVzrmtAGa2AJgBhBf7DOBHoedPA781M3POuShmBeDJle9y7z+2UNfczsFWX7fr5WWlUZibycDsdEoGZjOhZAADczIoDCvvgdkZFOZkMjAnndzMNJ1AJCIJIZJiHwbsDJuuBqZ0t45zzmdmjUAhsC98JTObA8wBGDFixHEFLszJZEJJ/pG964LcDAqyMyjIef8jPztdQyUikrT69M1T59w8YB5AeXn5ce3NTxt7EtPGnhTVXCIiiSSS3dpdwPCw6ZLQvC7XMbM0YADBN1FFRKSPRVLsK4HRZlZmZhnATGBRp3UWAdeFnl8J/L03xtdFRKRnPQ7FhMbMbwWWEDzc8SHn3HozuxOocM4tAh4EHjWzKqCeYPmLiIgHIhpjd84tBhZ3mndH2PNW4PPRjSYiIsdDh46IiCQYFbuISIJRsYuIJBgVu4hIgjGvjko0s1pghyffvGtFdDpTNkYpZ/TEQ0ZQzmiKh4xw9JwnO+cGHe2TPSv2WGNmFc65cq9z9EQ5oyceMoJyRlM8ZIQTz6mhGBGRBKNiFxFJMCr2983zOkCElDN64iEjKGc0xUNGOMGcGmMXEUkw2mMXEUkwKnYRkQSTlMVuZsPNbKmZVZrZejP7emh+gZm9ZGbvhB4HxkDWVDN708yeD02XmdlyM6sysydDl1L2OmO+mT1tZhvNbIOZnROj2/KboZ/3OjP7o5llxcL2NLOHzKzGzNaFzety+1nQb0J515rZJA8z3hP6ma81s2fNLD9s2fdCGTeZ2UV9kbG7nGHLvmVmzsyKQtOebMuj5TSzfw9t0/VmdnfY/GPbns65pPsAhgCTQs/7A5uBscDdwO2h+bcDd8VA1tuAJ4DnQ9MLgZmh5/cBX4mBjH8AZoeeZwD5sbYtCd6+cRvQL2w7Xh8L2xP4ODAJWBc2r8vtB1wC/A0wYCqw3MOMnwbSQs/vCss4FlgDZAJlwBYg1aucofnDCV56fAdQ5OW2PMr2/ATwMpAZmi4+3u3Zp7/AsfoB/Bn4FLAJGBKaNwTY5HGuEuAV4JPA86FfwH1hL6ZzgCUeZxwQKkzrND/WtuXh+/IWELxc9fPARbGyPYHSTi/yLrcfcD8wq6v1+jpjp2WXA4+Hnn8P+F7YsiXAOV5ty9C8p4EzgO1hxe7ZtuzmZ74QmNbFese8PZNyKCacmZUCE4HlwEnOud2hRXsAr2+u+ivgu0AgNF0I7HfO+ULT1QQLy0tlQC3w+9CQ0XwzyyHGtqVzbhfwc+BdYDfQCKwi9rbnYd1tv65uLh8LmW8kuPcLMZbRzGYAu5xzazotiqmcwBjgvNDQ4KtmdnZo/jHnTOpiN7Nc4E/AN5xzB8KXueCfRs+OBTWzS4Ea59wqrzJEKI3gv5T3OucmAs0Ehw6O8HpbAoTGqGcQ/EM0FMgBpnuZKVKxsP2Oxsy+D/iAx73O0pmZZQP/CdzR07oxII3gf5RTge8AC83MjucLJW2xm1k6wVJ/3Dn3TGj2XjMbElo+BKjxKh/wMeAyM9sOLCA4HPNrIN+CNwyHrm8s3teqgWrn3PLQ9NMEiz6WtiXANGCbc67WOdcBPENwG8fa9jysu+0Xyc3l+4yZXQ9cClwd+gMEsZXxFIJ/zNeEXkslwGozG0xs5YTga+kZF7SC4H/qRRxHzqQs9tBfwQeBDc65X4QtCr8p93UEx9494Zz7nnOuxDlXSvAesn93zl0NLCV4w3DwOCOAc24PsNPMTg3NuhCoJIa2Zci7wFQzyw79/A/njKntGaa77bcI+FLoiI6pQGPYkE2fMrPpBIcKL3POtYQtWgTMNLNMMysDRgMrvMjonHvbOVfsnCsNvZaqCR44sYcY2pYhzxF8AxUzG0PwQIR9HM/27Ks3CmLpAziX4L+2a4G3Qh+XEBzDfgV4h+C70wVeZw3lvYD3j4oZGfqhVgFPEXoH3eN8ZwIVoe35HDAwFrcl8GNgI7AOeJTgUQaeb0/gjwTH/TsIFs9N3W0/gm+gzyV4ZMTbQLmHGasIjv0efg3dF7b+90MZNwEXe7ktOy3fzvtvnnqyLY+yPTOAx0K/n6uBTx7v9tQlBUREEkxSDsWIiCQyFbuISIJRsYuIJBgVu4hIglGxi4gkGBW7iEiCUbGLiCSY/w8RGwvaVOothQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x_len = np.asarray(list(map(len, f_data)))\n",
"l = []\n",
"sq_xlen = pd.Series(x_len)\n",
"ptiles = [x*0.01 for x in range(100)]\n",
"for i in ptiles:\n",
" l.append(sq_xlen.quantile(i))\n",
"plt.plot(l, ptiles)\n",
"sq_xlen.describe(percentiles=[x*0.01 for x in range(95,100)])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1878d067",
"metadata": {},
"outputs": [],
"source": [
"thresh_p = 0.99\n",
"thresh = int(sq_xlen.quantile(thresh_p))\n",
"len_mask = np.where(x_len <= thresh)\n",
"\n",
"x_filter = f_data[len_mask]\n",
"y_filter = y[len_mask]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3a01c1ad",
"metadata": {},
"outputs": [],
"source": [
"lb = LabelBinarizer()\n",
"a = [x.drop(labels='Millis', axis=1) for x in x_filter]\n",
"x_filter = pad_sequences(x_filter, dtype=float, padding='post')\n",
"yt_filter = lb.fit_transform(y_filter)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "634a024c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 34.7 ms, sys: 5.84 ms, total: 40.6 ms\n",
"Wall time: 39.2 ms\n"
]
}
],
"source": [
"%%time\n",
"x_train, x_test, y_train, y_test = train_test_split(x_filter, yt_filter, test_size=0.2, random_state=177013)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0109b9b6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten (Flatten) (None, 2340) 0 \n",
"_________________________________________________________________\n",
"batch_normalization (BatchNo (None, 2340) 9360 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 2200) 5150200 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1100) 2421100 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 550) 605550 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 225) 123975 \n",
"_________________________________________________________________\n",
"dense_4 (Dense) (None, 26) 5876 \n",
"=================================================================\n",
"Total params: 8,316,061\n",
"Trainable params: 8,311,381\n",
"Non-trainable params: 4,680\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = Sequential()\n",
"\n",
"model.add(Flatten(input_shape=x_filter[0].shape))\n",
"\n",
"model.add(BatchNormalization())\n",
"\n",
"model.add(Dense(2200, activation='relu'))\n",
"\n",
"model.add(Dense(1100, activation='relu'))\n",
"\n",
"model.add(Dense(550, activation='relu'))\n",
"\n",
"model.add(Dense(225, activation='relu'))\n",
"\n",
"model.add(Dense(26, activation='softmax'))\n",
"\n",
"model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" loss=\"categorical_crossentropy\", \n",
" metrics=[\"acc\"],\n",
")\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "204ed561",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/32\n",
"82/82 [==============================] - 1s 3ms/step - loss: 2.8553 - acc: 0.1745\n",
"Epoch 2/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 1.7793 - acc: 0.4480\n",
"Epoch 3/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 1.2391 - acc: 0.6070\n",
"Epoch 4/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.9623 - acc: 0.7021\n",
"Epoch 5/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.8489 - acc: 0.7336\n",
"Epoch 6/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.5827 - acc: 0.8169\n",
"Epoch 7/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.5208 - acc: 0.8313\n",
"Epoch 8/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.5864 - acc: 0.8147\n",
"Epoch 9/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.4101 - acc: 0.8710\n",
"Epoch 10/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.2856 - acc: 0.9087\n",
"Epoch 11/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.2789 - acc: 0.9126\n",
"Epoch 12/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.3118 - acc: 0.9027\n",
"Epoch 13/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.3337 - acc: 0.9054\n",
"Epoch 14/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.3052 - acc: 0.9049\n",
"Epoch 15/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.2052 - acc: 0.9403\n",
"Epoch 16/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.4292 - acc: 0.8907\n",
"Epoch 17/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1545 - acc: 0.9542\n",
"Epoch 18/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1401 - acc: 0.9575\n",
"Epoch 19/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1907 - acc: 0.9483\n",
"Epoch 20/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.2635 - acc: 0.9303\n",
"Epoch 21/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1116 - acc: 0.9671\n",
"Epoch 22/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.2453 - acc: 0.9317\n",
"Epoch 23/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1090 - acc: 0.9681\n",
"Epoch 24/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1578 - acc: 0.9541\n",
"Epoch 25/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1609 - acc: 0.9570\n",
"Epoch 26/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.0801 - acc: 0.9775\n",
"Epoch 27/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.1597 - acc: 0.9615\n",
"Epoch 28/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.0695 - acc: 0.9807\n",
"Epoch 29/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.0622 - acc: 0.9853\n",
"Epoch 30/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.0655 - acc: 0.9841\n",
"Epoch 31/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.0383 - acc: 0.9910\n",
"Epoch 32/32\n",
"82/82 [==============================] - 0s 3ms/step - loss: 0.0716 - acc: 0.9792\n",
"CPU times: user 14 s, sys: 3.02 s, total: 17 s\n",
"Wall time: 8.95 s\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f6dec4bf130>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"model.fit(x_train, y_train, \n",
" epochs=32,\n",
" batch_size=128,\n",
" shuffle=True,\n",
" verbose=1\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "10a0d074",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate on test data\n",
"82/82 [==============================] - 0s 2ms/step - loss: 1.7331 - acc: 0.7341\n",
"test loss, test acc: [1.7330855131149292, 0.7341040372848511]\n",
"Generate predictions for 3 samples\n",
"predictions shape: (3, 26)\n"
]
},
{
"data": {
"text/plain": [
"(array(['N', 'U', 'I'], dtype='<U1'), array(['N', 'U', 'I'], dtype='<U1'))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate the model on the test data using `evaluate`\n",
"print(\"Evaluate on test data\")\n",
"results = model.evaluate(x_test, y_test, batch_size=32)\n",
"print(\"test loss, test acc:\", results)\n",
"\n",
"# Generate predictions (probabilities -- the output of the last layer)\n",
"# on new data using `predict`\n",
"print(\"Generate predictions for 3 samples\")\n",
"predictions = model.predict(x_test[:3])\n",
"print(\"predictions shape:\", predictions.shape)\n",
"\n",
"lb.inverse_transform(y_test[:3]), lb.inverse_transform(predictions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63f89de5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a5d0352",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ebc323f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "445a8e54",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "32206e00",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe00f947",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0fbe763",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d18608c3",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6a0e538b",
"metadata": {},
"outputs": [],
"source": [
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "521b2be6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}