{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "8c784b5a", "metadata": {}, "outputs": [], "source": [
 "import os\n",
 "\n",
 "# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is\n",
 "# initialised, so configure the GPU before importing tensorflow.\n",
 "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n",
 "os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
 "\n",
 "import pickle\n",
 "import pandas as pd\n",
 "import numpy as np\n",
 "import tensorflow as tf\n",
 "import matplotlib.pyplot as plt\n",
 "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
 "from sklearn.model_selection import train_test_split\n",
 "from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
 "from tensorflow.keras.models import Sequential\n",
 "from tensorflow.keras.layers import Dense, Flatten, BatchNormalization\n",
 "\n",
 "# Raw data: one ';'-separated CSV per recorded letter under\n",
 "# base_path/{user}/split_letters_csv/ for user in range(user_count).\n",
 "delim = ';'\n",
 "user_count = 100\n",
 "base_path = '/opt/iui-datarelease1-sose2021/'\n",
 "\n",
 "# Pickle cache of the parsed dataset (written by load_data, read by load_pickles).\n",
 "Xpickle_file = './X.pickle'\n",
 "ypickle_file = './y.pickle'\n",
 "\n",
 "# Seed the global NumPy and TensorFlow RNGs (model weight init, shuffling\n",
 "# in fit) so a full re-run is repeatable.\n",
 "SEED = 177013\n",
 "np.random.seed(SEED)\n",
 "tf.random.set_seed(SEED)"
 ] }, { "cell_type": "code", "execution_count": 2, "id": "7b486d61", "metadata": {}, "outputs": [], "source": [
 "def _as_object_array(frames):\n",
 "    \"\"\"Pack a sequence of DataFrames into a 1-D object ndarray.\n",
 "\n",
 "    np.asarray(list_of_frames, dtype=object) tries to build a nested array\n",
 "    from the frames' values: if all frames happen to share a shape it returns\n",
 "    a 3-D array instead of one DataFrame per element. Filling a\n",
 "    pre-allocated object array always keeps exactly one frame per slot.\n",
 "    \"\"\"\n",
 "    packed = np.empty(len(frames), dtype=object)\n",
 "    for i, frame in enumerate(frames):\n",
 "        packed[i] = frame\n",
 "    return packed\n",
 "\n",
 "\n",
 "def load_pickles():\n",
 "    \"\"\"Return the cached dataset (X, y) from Xpickle_file / ypickle_file.\n",
 "\n",
 "    Assumes the cache holds a list of DataFrames and a list of labels, as\n",
 "    load_data writes them. X is returned as a 1-D object array of\n",
 "    DataFrames, y as a str array.\n",
 "\n",
 "    NOTE: pickle.load can execute arbitrary code; only load cache files\n",
 "    this notebook wrote itself.\n",
 "    \"\"\"\n",
 "    with open(Xpickle_file, 'rb') as f:\n",
 "        X = pickle.load(f)\n",
 "    with open(ypickle_file, 'rb') as f:\n",
 "        y = pickle.load(f)\n",
 "    return _as_object_array(X), np.asarray(y, dtype=str)"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "5ea384ea", "metadata": {}, "outputs": [], "source": [
 "def shorten(npList, thresh=100, leeway=5):\n",
 "    \"\"\"Trim one recording to the window where 'Force' exceeds `thresh`.\n",
 "\n",
 "    npList is a DataFrame with a 'Force' column. The kept rows run from\n",
 "    `leeway` rows before the first row with Force > thresh (clamped at 0)\n",
 "    up to, but excluding, row (last such row + leeway).\n",
 "\n",
 "    thresh and leeway default to the previously hard-coded 100 and 5.\n",
 "    If Force never exceeds thresh, the recording is returned unchanged\n",
 "    instead of raising IndexError.\n",
 "    \"\"\"\n",
 "    over = np.flatnonzero(npList['Force'].to_numpy() > thresh)\n",
 "    if over.size == 0:\n",
 "        return npList\n",
 "    start = max(over[0] - leeway, 0)\n",
 "    stop = over[-1] + leeway\n",
 "    return npList.iloc[start:stop]"
 ] }, { "cell_type": "code", "execution_count": 4, "id": "09aad3f2", "metadata": {}, "outputs": [], "source": [
 "def load_data():\n",
 "    \"\"\"Return the dataset as (X, y), reading the pickle cache if it exists.\n",
 "\n",
 "    Otherwise every CSV in base_path/{user}/split_letters_csv/ for user in\n",
 "    range(user_count) is parsed, and its label is the first alphabetic\n",
 "    character of the file name. The parsed lists are then written to\n",
 "    Xpickle_file / ypickle_file so later runs take the cached path.\n",
 "\n",
 "    Both paths return the same 2-tuple (X: 1-D object array of DataFrames,\n",
 "    y: str array), so `x, y = load_data()` also works on a fresh checkout.\n",
 "    \"\"\"\n",
 "    if os.path.isfile(Xpickle_file) and os.path.isfile(ypickle_file):\n",
 "        return load_pickles()\n",
 "    data = []\n",
 "    label = []\n",
 "    for user in range(user_count):\n",
 "        user_path = os.path.join(base_path, str(user), 'split_letters_csv')\n",
 "        # os.listdir order is arbitrary, and sample order feeds the seeded\n",
 "        # train/test split, so sort it for reproducibility.\n",
 "        for file_name in sorted(os.listdir(user_path)):\n",
 "            letter = ''.join(filter(str.isalpha, file_name))[0]\n",
 "            data.append(pd.read_csv(os.path.join(user_path, file_name), sep=delim))\n",
 "            label.append(letter)\n",
 "    with open(Xpickle_file, 'wb') as f:\n",
 "        pickle.dump(data, f)\n",
 "    with open(ypickle_file, 'wb') as f:\n",
 "        pickle.dump(label, f)\n",
 "    return _as_object_array(data), np.asarray(label, dtype=str)"
 ] }, { "cell_type": "code", "execution_count": 5, "id": "37d66d26", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.76 s, sys: 205 ms, total: 2.97 s\n", "Wall time: 2.97 s\n" ] } ], "source": [ "%%time\n", "x, y = load_data()\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "3178395b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.22 s, sys: 2.07 ms, total: 3.22 s\n", "Wall time: 3.22 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n" ] } ], "source": [ "%%time\n", "f_data = np.array(list(map(shorten, x)))" ] }, { "cell_type": "code", "execution_count": 7, "id": "dcbb85b7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "count 13102.000000\n", "mean 61.169058\n", "std 30.698514\n", "min 10.000000\n", "50% 57.000000\n", "95% 102.000000\n", "96% 107.000000\n", "97% 113.000000\n", "98% 127.000000\n", "99% 156.000000\n", "max 1522.000000\n", "dtype: float64" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD4CAYAAAD4k815AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAf/ElEQVR4nO3deXTU9b3/8ec7O0kIIQmRJWCCgBUQBSNgq9VWWtFaqT+thap1Q2xvvV3scu1tr23t7T1Ve7rdUhXRWrciWrXUUnEp1dsqS0BBCIthkyCQkIRAErLMzOf3xww4xoQMMMl3ltfjnJyZ75Lkdb7JvPLNZ76LOecQEZHEkeJ1ABERiS4Vu4hIglGxi4gkGBW7iEiCUbGLiCSYNK++cVFRkSstLfXq24uIxKVVq1btc84NOto6nhV7aWkpFRUVXn17EZG4ZGY7elpHQzEiIglGxS4ikmBU7CIiCabHYjezh8ysxszWdbPczOw3ZlZlZmvNbFL0Y4qISKQi2WN/GJh+lOUXA6NDH3OAe088loiIHK8ei9059xpQf5RVZgCPuKBlQL6ZDYlWQBEROTbRGGMfBuwMm64OzfsQM5tjZhVmVlFbWxuFby0iIp316XHszrl5wDyA8vJyXS9YROKSP+Bo8/lp6wjQ2vmxw0+bL0Cb7/3nnR8v/EgxZwzP77V80Sj2XcDwsOmS0DwRkV7l8wdo9QVo6/C//9gRoM334ce2ztNHKd73Hz/4uYcLvMN/Yvulxf0zY77YFwG3mtkCYArQ6JzbHYWvKyIJwDlHU5uP/S0dNB56/+Pw9MHWjtAe7uHS/XDBtnVTvP7A8ResGWSlpZKVnkJmF4+5mWkU5qSSmZ5CVlp3jylkpr//OZlpKWSld/2YefgxLQUzi+IW/rAei93M/ghcABSZWTXwQyAdwDl3H7AYuASoAlqAG3orrIh4p7XDz4FDHezvVMz7W9q7nB/+cbQCTkuxbkswKz2F/H7pZOVlHr04w8r2g5/f/WNaivV6wXqlx2J3zs3qYbkDvhq1RCLSZwIBx679h6iqbWJLTRN7GluPFHTj4eI+1E7joQ5aOwLdfh0zyMtKJz87nQH9gh/DC7IZ0C+N/H4ZwXmhZfmh54fnZ6X3/h5ssvHsImAi0nfafQG21zVTVRMs8KraJqpqmtha28yhDv+R9bLSUxiYnXGknE8uzOaM7AHBQg6bn3+kpIPz+melkZKico4VKnaRBBIIOLbXNbP+vQNs2H2Ad0JFvqO+5QPDIcPy+3FKcS5TygoZVZx75KMgJ8PD9BItKnaRONXuC7B570Eq3zvA+vcaj5R5c3twDzwtxSgtymHMSf255PQhR8p75KAcsjP00k9k+umKxImG5nZe
3rCXFdvqWf/eAd6pOXjksLucjFROG5LHlWeVMG7oAMYOzWPMSf3JSNN1/pKRil0khu1pbOXFyj0sWb+HZVvr8QccBTkZjBuax8fHjGTc0DzGDc2jtDBHY9xyhIpdJIa0+fy8Xd3I8m31vLxhL2++ux+AUwbl8OXzRzJ93BDGD8vTUSRyVCp2EQ81tflYvaOBldvrWbGtnrd27qfNFzyscPywPL5z0alcNO4kRhX39zipxBMVu0gf8vkDrNzewCsb9rJie3Cs3B9wpKYY44bmcc3Ukzm7tICzSwdSmJvpdVyJUyp2kV7W5vPzr6p9vLBuDy9vqKG+uZ2MtBQmDs/n3y44hbNLC5h08kByM/VylOjQb5JIL3DO8fqWOhas3MnSjTU0tfnon5nGJ08rZvq4wZx/6iAdcii9Rr9ZIlEUCDhe2VjDb5dWsWbnfgpyMvjsGUO4aNxgPnpKkQ4/lD6hYheJAp8/wF/f3s3vlm5h096DDC/ox08vH88Vk0rISk/1Op4kGRW7yAlo9wV4ZnU19766hR11LYwqzuUXV53BZWcMJS1Ve+fiDRW7yHF4t66Fxet284fXt7O7sZXxw/K49+pJXDRusE4UEs+p2EUi4Jxj894mXlgXPAu0cvcBACaXFvCzKybw8dFFOmlIYoaKXeQo2nx+5v69ir+s3c22fc2YwVkjBvKDz5zGReMGM7wg2+uIIh+iYhfpRmuHn1seXcWrm2s5b3QRs88r41NjT6K4f5bX0USOSsUu0oUddc38x5/WsnxbPT/7f6czc/IIryOJREzFLhJm056D3PuPKhateY+01BR+edWZfG7iMK9jiRwTFbsIsLvxED/883perNxLdkYqs88byexzyyjO07CLxB8VuyS9HXXNfPGB5exvaedrF47mho+WMlC3iJM4pmKXpFZV08TV85fR5guwYM45nF4ywOtIIidMxS5Jq/K9A1z74HLMjCfnnMOpg3XNc0kMKnZJOs45XtlQw7eeWkN2RiqPz57CyEG5XscSiRoVuyQNf8Dxwro9zF1aReXuA4wsyuEPN07WSUaScFTskhRe2bCXny7ewNbaZkYW5XDPlRP43MRhpOtCXZKAVOyS8J5eVc13n17DqOJc5n5xEtPHDyZVF+qSBKZil4T22LId/OC5dZw7qoh5XzpLdy2SpKDfcklYD7y2lZ8u3sC004r57Rcn6YYXkjRU7JKQDpf6Z04fwq9mnqmxdEkqKnZJSI8u28Hk0gJ+PfNM3clIkk5Ev/FmNt3MNplZlZnd3sXyEWa21MzeNLO1ZnZJ9KOKRMY5hz/gKBnYT6UuSanH33ozSwXmAhcDY4FZZja202o/ABY65yYCM4HfRTuoSCScc/zshY3s2n+IUSfppCNJTpHszkwGqpxzW51z7cACYEandRyQF3o+AHgvehFFIhMIOH60aD33v7qVa6aO4MsfP8XrSCKeiGSMfRiwM2y6GpjSaZ0fAS+a2b8DOcC0rr6Qmc0B5gCMGKEbF0j0+AOO7z2zloUV1dx8Xhn/eclpugepJK1oDUDOAh52zpUAlwCPmtmHvrZzbp5zrtw5Vz5o0KAofWtJdh3+AN988i0WVlTz9QtHq9Ql6UWyx74LGB42XRKaF+4mYDqAc+4NM8sCioCaaIQU6U6bz8+tT7zJS5V7uf3ij/Dl8zX8IhLJHvtKYLSZlZlZBsE3Rxd1Wudd4EIAMzsNyAJqoxlUpLND7X5ufmQVL1Xu5c4Z41TqIiE97rE753xmdiuwBEgFHnLOrTezO4EK59wi4FvAA2b2TYJvpF7vnHO9GVySW1ObjxsfXsnK7fXcfcUErjp7eM+fJJIkIjpByTm3GFjcad4dYc8rgY9FN5pI19p9Aa6Zv5y3dzXyqy+cyYwzdbNpkXA6e0PizsY9B3hr537uuHSsSl2kCyp2iTsd/uAo3/CCfh4nEYlNKnaJK3VNbfzXc+vISE2htDDH6zgiMUkXAZO4sfdAK1fPX87O+hYeuK5c
9ykV6YaKXeJCdUMLV89fzr6DbfzhxslMHVnodSSRmKVil5hXc6CVq+57g6Y2H4/NnsLEEQO9jiQS01TsEvNee2cf7zW2smDOVJW6SAT05qnEPJ8/AMCwfB0FIxIJFbvEtE17DvLzFzdTlJtBUW6m13FE4oKKXWLWul2NzJz3BikGC+ZMpV+GbkYtEgkVu8SkVTsamPXAMrIz0lh4yzmMKu7vdSSRuKE3TyXmrNm5n2sfXE5x/0wev3mqxtZFjpGKXWLOX9/eTYc/wMJbzqE4L8vrOCJxR0MxEnM6/AHSU1NU6iLHScUuMWXpxhqeWP4upw3J63llEemSil1ixgvrdjPn0QpGn5TLA18q9zqOSNzSGLvEhGffrObbT61lQskAHr5hMgP6pXsdSSRuqdjFc4vWvMdtC9cwtayQ+deVk5OpX0uRE6FXkHjumdXVDB+Yze9vOJusdJ2EJHKiNMYunuvwBxiYk6FSF4kSFbt46qF/buNfVXWM1VEwIlGjYhfP/OH17dz5fCUXjx/Mjy8b53UckYShMXbxzFOrdjKhZAD/O2siaanaxxCJFr2axDM+v2NQbqZKXSTK9IqSPuec4+4XNrJxz0HGDdXYuki0aShG+pRzjjufr+T3/9rOrMkj+Ma0MV5HEkk4KnbpM/6A4/vPvs2ClTu58WNl/Nelp2FmXscSSTgqdukzP3guWOq3fmIU3/r0GJW6SC9RsUuf+dPqXVw+cRjfvuhUr6OIJDS9eSp9wucP4JzjJF1jXaTXqdil17X7Anz1idV0+B3jh+koGJHeFlGxm9l0M9tkZlVmdns361xlZpVmtt7MnohuTIlXzjm++sRqlqzfyx2XjuXSCUO9jiSS8HocYzezVGAu8CmgGlhpZoucc5Vh64wGvgd8zDnXYGbFvRVY4ktTm4+XKvdy83ll3HhumddxRJJCJHvsk4Eq59xW51w7sACY0Wmdm4G5zrkGAOdcTXRjSrxyoUeNrYv0nUiKfRiwM2y6OjQv3BhgjJn9y8yWmdn0rr6Qmc0xswozq6itrT2+xBJXVm1vAGBofj+Pk4gkj2i9eZoGjAYuAGYBD5hZfueVnHPznHPlzrnyQYMGRelbSyz7/evbKe6fybTTTvI6ikjSiKTYdwHDw6ZLQvPCVQOLnHMdzrltwGaCRS9JbOmmGl7bXMsNHysjI00HYIn0lUhebSuB0WZWZmYZwExgUad1niO4t46ZFREcmtkavZgSb9p9AX745/WMHJTDjeeWeh1HJKn0WOzOOR9wK7AE2AAsdM6tN7M7zeyy0GpLgDozqwSWAt9xztX1VmiJfbsbD/FufQuzzx1JZppueSfSlyK6pIBzbjGwuNO8O8KeO+C20IcI7b4AAFnpGoIR6Wt61UnU1TW18fUFb5GRmsJpupepSJ/TRcAkqg61+/nCvGXsrG/hgevKVewiHlCxS1RV7j5AVU0TP//8GZw/Roe0inhBQzESVYfH1gf1z/Q4iUjyUrFL1FQ3tHD7M2vJyUhlVHGu13FEkpaGYiQq6prauOq+N2hq8/HY7CkM0yUERDyjYpeoWLWjgfcaW3n4hrOZOGKg13FEkpqGYiQq2kJj60W5GlsX8ZqKXU7Ypj0H+fFfKinMyWB4QbbXcUSSnopdTsi7dS3MnPcGKQZP3jKVAf3SvY4kkvQ0xi4n5PUt+2ho6eAvt57LqOL+XscREbTHLifoyNh6/wyPk4jIYSp2OW6rdtTz8yWbGFGQTWGO3jQViRUqdjku63Y1cu2DKyjqn8mCOVN1Iw2RGKIxdjkur26upaXdzxM3T2HIAJ2MJBJLtJslx6Wtww+gIRiRGKRil2O2dGMN9722ldOHDSA91byOIyKdqNjlmLy+ZR9zHq3g1JP688iNkzFTsYvEGo2xyzF5qXIvqSnG4zdPIS9LJyOJxCLtscsxae0IkJ6aolIXiWEqdonYM6ureXLlu5x1sq7eKBLLVOwSkcVv7+ZbT61h6shCfnf1JK/jiMhR
aIxdIrJk/R6KcjN56PqzyUpP9TqOiByF9tglIq0dfnIyUlXqInFAxS49mv9/W1myfi/lpQVeRxGRCKjY5ageXbaD//7rBi45fTD/c/npXscRkQhojF2Oasm6PYwqzuU3MyeSlqr9AJF4oFeqdMs5x6EOPwP6pavUReKIXq3SJeccd72wiVU7Gigv1XHrIvFExS5d+uXL73Dfq1u4esoI/uOij3gdR0SOgcbYpUsvrt/D5NIC/vtz43WhL5E4E9Eeu5lNN7NNZlZlZrcfZb0rzMyZWXn0IkpfCwQcbb4A+dnpKnWRONRjsZtZKjAXuBgYC8wys7FdrNcf+DqwPNohpe/4A47bn1nLtn3NTC7Tcesi8SiSPfbJQJVzbqtzrh1YAMzoYr2fAHcBrVHMJ33sjj+vY2FFNV+7cDQ3nVvmdRwROQ6RFPswYGfYdHVo3hFmNgkY7pz769G+kJnNMbMKM6uora095rDS+16q3MvF4wdz26fGaBhGJE6d8FExZpYC/AL4Vk/rOufmOefKnXPlgwYNOtFvLb0kP1vXWheJZ5EU+y5geNh0SWjeYf2B8cA/zGw7MBVYpDdQRUS8EUmxrwRGm1mZmWUAM4FFhxc65xqdc0XOuVLnXCmwDLjMOVfRK4ml13T4A3T4A17HEJET1GOxO+d8wK3AEmADsNA5t97M7jSzy3o7oPSN1g4/X3lsNQ0tHUwdWeh1HBE5ARGdoOScWwws7jTvjm7WveDEY0lf+9of3+TlDXv5yYxxzDhzWM+fICIxS5cUEABe2VjDNVNHcO05pV5HEZETpGIXDrX7CThHfr8Mr6OISBSo2JPcwdYOrntoBQacrTNNRRKCLgKWxAIBxw2/X8lbO/fz65kTOX+Mzi0QSQTaY09iB1o7qNjRwK2fHMVnzxjqdRwRiRIVexJrbvcDMKCfzjQVSSQq9iS1r6mNmx5eSUZqCmedrDskiSQSjbEnoTafn5nzllHd0MKD15czoSTf60giEkXaY09C1Q2HqKpp4vufGct5o/WGqUiiUbEnoZa24Nh6Xpb+YRNJRCr2JLOzvoV/e2IVuZlpnKEhGJGEpF22JNJ4qIOr7n+DlnY/j8+eQmlRjteRRKQXqNiTyKY9B9nd2Mq9V0/ijOH5XscRkV6ioZgk0tzmAyBPx62LJDQVe5LYsPsA33l6DYU5GZw6uL/XcUSkF6nYk8Cu/YeYOW8ZaSkpPHnLORTlZnodSUR6kcbYk8DanftpPNTB/OvKGVWc63UcEell2mNPAgdDY+v9ddy6SFJQsSe4ldvrufMvlQwv6MfJBTq8USQZqNgT2MY9B/jSgysozstk4S3n0C8j1etIItIHVOwJbOX2Bg51+HnwurMZMqCf13FEpI+o2BNYU2twbD03U2PrIslExZ6gXtmwl1++vJlxQ/MoyNFNqkWSiYo9Aa3YVs8tj67iI4P78/jsKaSmmNeRRKQP6X/0BPTGljp8AcejN03Rbe9EkpD22BPQwdYOAPprbF0kKanYE8zTq6p56F/bOHdUEaYRGJGkpGJPIC+s28O3n1rDR08pYt6XzsLU7CJJSf+rJ5DXt+wjNzON+deVk5Wuk5FEkpX22BPIwVYf6ammUhdJcir2BDHvtS08++YuLji12OsoIuKxiIrdzKab2SYzqzKz27tYfpuZVZrZWjN7xcxOjn5U6Ypzjl+//A7/s3gjn5kwhLuvnOB1JBHxWI/FbmapwFzgYmAsMMvMxnZa7U2g3Dk3AXgauDvaQeXDnHP87IWN/PLlzVwxqYTfzJxIeqr+CRNJdpG0wGSgyjm31TnXDiwAZoSv4Jxb6pxrCU0uA0qiG1O6MndpFfe/upVrp57MPVdO0BmmIgJEVuzDgJ1h09Whed25CfhbVwvMbI6ZVZhZRW1tbeQppUv/rNrH+GF53DljHCkqdREJier/7WZ2DVAO3NPVcufcPOdcuXOufNCgQdH81kknEHAcbPWRnZGm49VF5AMiKfZdwPCw6ZLQ
vA8ws2nA94HLnHNt0YknXfEHHN/901rWv3eAT35ER8GIyAdFUuwrgdFmVmZmGcBMYFH4CmY2EbifYKnXRD+mhLvzL+t5elU135g2mls+PtLrOCISY3osduecD7gVWAJsABY659ab2Z1mdllotXuAXOApM3vLzBZ18+UkCl7fUsf5YwbxjWljNAwjIh8S0SUFnHOLgcWd5t0R9nxalHNJNzr8AVra/eRk6uxSEemaDnqOI60dfr7y2Cp27T/EJ3SGqYh0QxcBixOH2v3MebSC/3tnHz+ZMY7Plw/v+ZNEJCmp2OPAwdYObnq4good9dxz5QSVuogclYo9Dtz8SAWr323g1zMn8tkzhnodR0RinMbY48CyrfXceG6ZSl1EIqJij3H7W9oB6KdrrItIhFTsMWxfUxsz5y0jIzWFj48p8jqOiMQJjbHHqD2NrVw9fxm79h/iwevLOevkAq8jiUicULHHoJ31LVw9fzn1ze08cuMUJpep1EUkcir2GNPS7uML979BU5uPx2ZP4czh+V5HEpE4ozH2GLOz/hDvNbZyx2fHqdRF5Lio2GNMQ+gomOwMHQUjIsdHxR5DNu45wK1PrKYwJ4NJIwZ6HUdE4pSKPUasrd7PzHnLSE0xnrzlHAYPyPI6kojEKb15GgNW7ajn+odWMiA7nSdmT2VEYbbXkUQkjqnYPfZ61T5u+kMFgwdk8fjsKQzN7+d1JBGJcyp2Dy3dWMMtj62irDCHR2dPpri/hl9E5MSp2D3yt7d387UFb/KRwXk8cuNkBuZkeB1JRBKE3jz1wLNvVvPVJ1YzoSSfx2+eolIXkajSHnsf++OKd/nPZ9/mnJGFzL+unOwM/QhEJLrUKn3owX9u4yfPV/KJUwdx7zVnkaVL8YpIL1Cx95Hf/v0dfv7iZqaPG8xvZk0kI02jYCLSO1Tsvcw5x89f3MTcpVu4fOIw7rlyAmmpKnUR6T0q9l5UVdPEr17ezPNrdzNr8gh++rnxpKSY17FEJMGp2HvBul2NzF1axQvr95CZlsI3po3m6xeOxkylLiK9T8UeRa0dfr791BqeX7ub/plpfPWCUdzwsVIKczO9jiYiSUTFHiXNbT5ufqSCN7bW8bULRzP7vDLystK9jiUiSUjFfoKcc7z2zj7ufmEjG3Yf4BdXncHlE0u8jiUiSUzFfpwCAceS9XuY+48q1u06wJABWdx3zVl8etxgr6OJSJJTsR+DvQdaWbGtnpXb63ltcy3b61ooLczmritO5/KJJTo2XURigoq9G845dtS1sGJbPSu2B8t8R10LELxt3aQRA7nt06fymdOHkKpDGEUkhkRU7GY2Hfg1kArMd879rNPyTOAR4CygDviCc257dKOeOOccTW0+Gpo7qGtuo6GlnfrmDuqb2z70uLPhELUH2wAYmJ1OeWkB1049mbNLCxg7NI90nWQkIjGqx2I3s1RgLvApoBpYaWaLnHOVYavdBDQ450aZ2UzgLuALvRE4XIc/QENzO/Ut7dQ3BR8bmtupaw57bGmnrin42NDcQbs/0OXXSk81CnIyGJidQWFuBueNKmLSyQOZUlbAKYNydWKRiMSNSPbYJwNVzrmtAGa2AJgBhBf7DOBHoedPA781M3POuShmBeDJle9y7z+2UNfczsFWX7fr5WWlUZibycDsdEoGZjOhZAADczIoDCvvgdkZFOZkMjAnndzMNJ1AJCIJIZJiHwbsDJuuBqZ0t45zzmdmjUAhsC98JTObA8wBGDFixHEFLszJZEJJ/pG964LcDAqyMyjIef8jPztdQyUikrT69M1T59w8YB5AeXn5ce3NTxt7EtPGnhTVXCIiiSSS3dpdwPCw6ZLQvC7XMbM0YADBN1FFRKSPRVLsK4HRZlZmZhnATGBRp3UWAdeFnl8J/L03xtdFRKRnPQ7FhMbMbwWWEDzc8SHn3HozuxOocM4tAh4EHjWzKqCeYPmLiIgHIhpjd84tBhZ3mndH2PNW4PPRjSYiIsdDh46I
iCQYFbuISIJRsYuIJBgVu4hIgjGvjko0s1pghyffvGtFdDpTNkYpZ/TEQ0ZQzmiKh4xw9JwnO+cGHe2TPSv2WGNmFc65cq9z9EQ5oyceMoJyRlM8ZIQTz6mhGBGRBKNiFxFJMCr2983zOkCElDN64iEjKGc0xUNGOMGcGmMXEUkw2mMXEUkwKnYRkQSTlMVuZsPNbKmZVZrZejP7emh+gZm9ZGbvhB4HxkDWVDN708yeD02XmdlyM6sysydDl1L2OmO+mT1tZhvNbIOZnROj2/KboZ/3OjP7o5llxcL2NLOHzKzGzNaFzety+1nQb0J515rZJA8z3hP6ma81s2fNLD9s2fdCGTeZ2UV9kbG7nGHLvmVmzsyKQtOebMuj5TSzfw9t0/VmdnfY/GPbns65pPsAhgCTQs/7A5uBscDdwO2h+bcDd8VA1tuAJ4DnQ9MLgZmh5/cBX4mBjH8AZoeeZwD5sbYtCd6+cRvQL2w7Xh8L2xP4ODAJWBc2r8vtB1wC/A0wYCqw3MOMnwbSQs/vCss4FlgDZAJlwBYg1aucofnDCV56fAdQ5OW2PMr2/ATwMpAZmi4+3u3Zp7/AsfoB/Bn4FLAJGBKaNwTY5HGuEuAV4JPA86FfwH1hL6ZzgCUeZxwQKkzrND/WtuXh+/IWELxc9fPARbGyPYHSTi/yLrcfcD8wq6v1+jpjp2WXA4+Hnn8P+F7YsiXAOV5ty9C8p4EzgO1hxe7ZtuzmZ74QmNbFese8PZNyKCacmZUCE4HlwEnOud2hRXsAr2+u+ivgu0AgNF0I7HfO+ULT1QQLy0tlQC3w+9CQ0XwzyyHGtqVzbhfwc+BdYDfQCKwi9rbnYd1tv65uLh8LmW8kuPcLMZbRzGYAu5xzazotiqmcwBjgvNDQ4KtmdnZo/jHnTOpiN7Nc4E/AN5xzB8KXueCfRs+OBTWzS4Ea59wqrzJEKI3gv5T3OucmAs0Ehw6O8HpbAoTGqGcQ/EM0FMgBpnuZKVKxsP2Oxsy+D/iAx73O0pmZZQP/CdzR07oxII3gf5RTge8AC83MjucLJW2xm1k6wVJ/3Dn3TGj2XjMbElo+BKjxKh/wMeAyM9sOLCA4HPNrIN+CNwyHrm8s3teqgWrn3PLQ9NMEiz6WtiXANGCbc67WOdcBPENwG8fa9jysu+0Xyc3l+4yZXQ9cClwd+gMEsZXxFIJ/zNeEXkslwGozG0xs5YTga+kZF7SC4H/qRRxHzqQs9tBfwQeBDc65X4QtCr8p93UEx9494Zz7nnOuxDlXSvAesn93zl0NLCV4w3DwOCOAc24PsNPMTg3NuhCoJIa2Zci7wFQzyw79/A/njKntGaa77bcI+FLoiI6pQGPYkE2fMrPpBIcKL3POtYQtWgTMNLNMMysDRgMrvMjonHvbOVfsnCsNvZaqCR44sYcY2pYhzxF8AxUzG0PwQIR9HM/27Ks3CmLpAziX4L+2a4G3Qh+XEBzDfgV4h+C70wVeZw3lvYD3j4oZGfqhVgFPEXoH3eN8ZwIVoe35HDAwFrcl8GNgI7AOeJTgUQaeb0/gjwTH/TsIFs9N3W0/gm+gzyV4ZMTbQLmHGasIjv0efg3dF7b+90MZNwEXe7ktOy3fzvtvnnqyLY+yPTOAx0K/n6uBTx7v9tQlBUREEkxSDsWIiCQyFbuISIJRsYuIJBgVu4hIglGxi4gkGBW7iEiCUbGLiCSY/w8RGwvaVOothQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x_len = np.asarray(list(map(len, f_data)))\n", "l = []\n", "sq_xlen = pd.Series(x_len)\n", "ptiles = [x*0.01 for x in range(100)]\n", "for i in ptiles:\n", " l.append(sq_xlen.quantile(i))\n", "plt.plot(l, ptiles)\n", "sq_xlen.describe(percentiles=[x*0.01 for x in range(95,100)])" ] }, { "cell_type": "code", "execution_count": 8, "id": "1878d067", "metadata": {}, "outputs": [], "source": [ "thresh_p = 0.99\n", "thresh = int(sq_xlen.quantile(thresh_p))\n", "len_mask = np.where(x_len <= thresh)\n", "\n", "x_filter = f_data[len_mask]\n", "y_filter = y[len_mask]\n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "3a01c1ad", "metadata": {}, "outputs": [], "source": [ "lb = LabelBinarizer()\n", "a = [x.drop(labels='Millis', axis=1) for x in x_filter]\n", "x_filter = pad_sequences(x_filter, dtype=float, padding='post')\n", "yt_filter = lb.fit_transform(y_filter)" ] }, { "cell_type": "code", "execution_count": 10, "id": "634a024c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 34.7 ms, sys: 5.84 ms, total: 40.6 ms\n", "Wall time: 39.2 ms\n" ] } ], "source": [ "%%time\n", "x_train, x_test, y_train, y_test = train_test_split(x_filter, yt_filter, test_size=0.2, random_state=177013)\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "0109b9b6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "flatten (Flatten) (None, 2340) 0 \n", "_________________________________________________________________\n", "batch_normalization (BatchNo (None, 2340) 9360 \n", "_________________________________________________________________\n", "dense (Dense) (None, 2200) 5150200 \n", 
"_________________________________________________________________\n", "dense_1 (Dense) (None, 1100) 2421100 \n", "_________________________________________________________________\n", "dense_2 (Dense) (None, 550) 605550 \n", "_________________________________________________________________\n", "dense_3 (Dense) (None, 225) 123975 \n", "_________________________________________________________________\n", "dense_4 (Dense) (None, 26) 5876 \n", "=================================================================\n", "Total params: 8,316,061\n", "Trainable params: 8,311,381\n", "Non-trainable params: 4,680\n", "_________________________________________________________________\n" ] } ], "source": [ "model = Sequential()\n", "\n", "model.add(Flatten(input_shape=x_filter[0].shape))\n", "\n", "model.add(BatchNormalization())\n", "\n", "model.add(Dense(2200, activation='relu'))\n", "\n", "model.add(Dense(1100, activation='relu'))\n", "\n", "model.add(Dense(550, activation='relu'))\n", "\n", "model.add(Dense(225, activation='relu'))\n", "\n", "model.add(Dense(26, activation='softmax'))\n", "\n", "model.compile(\n", " optimizer=tf.keras.optimizers.Adam(0.001),\n", " loss=\"categorical_crossentropy\", \n", " metrics=[\"acc\"],\n", ")\n", "\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 12, "id": "204ed561", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/32\n", "82/82 [==============================] - 1s 3ms/step - loss: 2.8553 - acc: 0.1745\n", "Epoch 2/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 1.7793 - acc: 0.4480\n", "Epoch 3/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 1.2391 - acc: 0.6070\n", "Epoch 4/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.9623 - acc: 0.7021\n", "Epoch 5/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.8489 - acc: 0.7336\n", "Epoch 6/32\n", "82/82 [==============================] - 0s 
3ms/step - loss: 0.5827 - acc: 0.8169\n", "Epoch 7/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.5208 - acc: 0.8313\n", "Epoch 8/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.5864 - acc: 0.8147\n", "Epoch 9/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.4101 - acc: 0.8710\n", "Epoch 10/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.2856 - acc: 0.9087\n", "Epoch 11/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.2789 - acc: 0.9126\n", "Epoch 12/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.3118 - acc: 0.9027\n", "Epoch 13/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.3337 - acc: 0.9054\n", "Epoch 14/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.3052 - acc: 0.9049\n", "Epoch 15/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.2052 - acc: 0.9403\n", "Epoch 16/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.4292 - acc: 0.8907\n", "Epoch 17/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1545 - acc: 0.9542\n", "Epoch 18/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1401 - acc: 0.9575\n", "Epoch 19/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1907 - acc: 0.9483\n", "Epoch 20/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.2635 - acc: 0.9303\n", "Epoch 21/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1116 - acc: 0.9671\n", "Epoch 22/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.2453 - acc: 0.9317\n", "Epoch 23/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1090 - acc: 0.9681\n", "Epoch 24/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1578 - acc: 0.9541\n", "Epoch 25/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1609 - 
acc: 0.9570\n", "Epoch 26/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.0801 - acc: 0.9775\n", "Epoch 27/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.1597 - acc: 0.9615\n", "Epoch 28/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.0695 - acc: 0.9807\n", "Epoch 29/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.0622 - acc: 0.9853\n", "Epoch 30/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.0655 - acc: 0.9841\n", "Epoch 31/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.0383 - acc: 0.9910\n", "Epoch 32/32\n", "82/82 [==============================] - 0s 3ms/step - loss: 0.0716 - acc: 0.9792\n", "CPU times: user 14 s, sys: 3.02 s, total: 17 s\n", "Wall time: 8.95 s\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "model.fit(x_train, y_train, \n", " epochs=32,\n", " batch_size=128,\n", " shuffle=True,\n", " verbose=1\n", " )" ] }, { "cell_type": "code", "execution_count": 13, "id": "10a0d074", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluate on test data\n", "82/82 [==============================] - 0s 2ms/step - loss: 1.7331 - acc: 0.7341\n", "test loss, test acc: [1.7330855131149292, 0.7341040372848511]\n", "Generate predictions for 3 samples\n", "predictions shape: (3, 26)\n" ] }, { "data": { "text/plain": [ "(array(['N', 'U', 'I'], dtype='