iui-group-l-name-zensiert/1-first-project/tdt/NNwAll.ipynb

616 lines
424 KiB
Plaintext
Raw Normal View History

2021-06-10 19:18:36 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f7afce31",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4023ed0f",
"metadata": {},
"outputs": [],
"source": [
"# CSV field separator used by the raw recordings.\n",
"delim = ';'\n",
"\n",
"# Root of the data release; one numbered sub-directory per user.\n",
"base_path = '/opt/iui-datarelease1-sose2021/'\n",
"\n",
"# Local pickle cache of the parsed recordings (X) and their letter labels (y).\n",
"Xpickle_file, ypickle_file = './X.pickle', './y.pickle'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1c43ce06",
"metadata": {},
"outputs": [],
"source": [
"# Cropping (see shorten): keep rows whose Force exceeds THRESH * max(Force),\n",
"# with LEEWAY rows of slack at the edges of that window.\n",
"THRESH = 0.1\n",
"LEEWAY = 1\n",
"\n",
"# Training: epochs per model, and how many models are trained and averaged.\n",
"EPOCH = 20\n",
"AVG_FROM = 10\n",
"\n",
"# Network: DENSE_COUNT layers of DENSE_NEURONS units, then DENSE2_COUNT\n",
"# layers of DENSE2_NEURONS units.\n",
"DENSE_COUNT, DENSE_NEURONS = 2, 2400\n",
"DENSE2_COUNT, DENSE2_NEURONS = 3, 600"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "384cdc94",
"metadata": {},
"outputs": [],
"source": [
"def plot_pd(data):\n",
"    \"\"\"Plot the 12 sensor channels of a recording DataFrame in a 4x3 grid.\n",
"\n",
"    Each panel plots one channel against data['Millis'] and overlays the\n",
"    'Force' channel on the same axes.\n",
"    \"\"\"\n",
"    channel_grid = [\n",
"        ['Acc1 X', 'Acc1 Y', 'Acc1 Z'],\n",
"        ['Acc2 X', 'Acc2 Y', 'Acc2 Z'],\n",
"        ['Gyro X', 'Gyro Y', 'Gyro Z'],\n",
"        ['Mag X', 'Mag Y', 'Mag Z'],\n",
"    ]\n",
"    fig, axs = plt.subplots(4, 3, figsize=(3*3, 3*4))\n",
"    t = data['Millis']\n",
"    for row, names in enumerate(channel_grid):\n",
"        for col, name in enumerate(names):\n",
"            axs[row][col].plot(t, data[name])\n",
"\n",
"    for ax in axs.flat:\n",
"        ax.plot(t, data['Force'])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "28e58847",
"metadata": {},
"outputs": [],
"source": [
"def plot_data(data):\n",
"    \"\"\"Plot rows 1..12 of `data` in a 4x3 grid, overlaying row 13 on every panel.\n",
"\n",
"    `data` is indexed by channel first (callers pass a transposed\n",
"    (time, channel) sample); row 0 is not plotted.\n",
"    NOTE(review): row 13 is presumably the Force channel, as in plot_pd -\n",
"    confirm against the CSV column order.\n",
"    \"\"\"\n",
"    fig, axs = plt.subplots(4, 3, figsize=(3*3, 3*4))\n",
"    # axs.flat walks the grid row by row, matching channels 1..12.\n",
"    for channel, ax in enumerate(axs.flat, start=1):\n",
"        ax.plot(data[channel])\n",
"\n",
"    for ax in axs.flat:\n",
"        ax.plot(data[13])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dd1eea42",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"def load_pickles():\n",
"    \"\"\"Load the cached recordings (Xpickle_file) and labels (ypickle_file).\n",
"\n",
"    Returns:\n",
"        (X, y): a 1-D object array holding the unpickled recordings and an\n",
"        array of label strings.\n",
"\n",
"    NOTE: pickle.load can execute arbitrary code; only load cache files that\n",
"    this project wrote itself.\n",
"    \"\"\"\n",
"    # `with` closes each file even if unpickling raises.\n",
"    with open(Xpickle_file, 'rb') as f:\n",
"        X = pickle.load(f)\n",
"    with open(ypickle_file, 'rb') as f:\n",
"        y = pickle.load(f)\n",
"\n",
"    # Fill element-wise so equal-length recordings are never stacked into 3-D.\n",
"    recordings = np.empty(len(X), dtype=object)\n",
"    for i, rec in enumerate(X):\n",
"        recordings[i] = rec\n",
"    return (recordings, np.asarray(y, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "132c38fa",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"def load_data(user_count=None):\n",
"    \"\"\"Load all per-letter CSV recordings, or the pickle cache if it exists.\n",
"\n",
"    Reads every file in base_path/<user>/split_letters_csv/. A file's label\n",
"    is the first alphabetic character of its file name.\n",
"\n",
"    Args:\n",
"        user_count: read user directories 0 .. user_count-1. When None, every\n",
"            purely numeric sub-directory of base_path is read.\n",
"\n",
"    Returns:\n",
"        (X, y), the same pair load_pickles() returns: a 1-D object array of\n",
"        DataFrames and an array of single-letter label strings.\n",
"    \"\"\"\n",
"    if os.path.isfile(Xpickle_file) and os.path.isfile(ypickle_file):\n",
"        return load_pickles()\n",
"\n",
"    if user_count is None:\n",
"        # Discover the user directories instead of depending on a global.\n",
"        users = sorted(int(d) for d in os.listdir(base_path) if d.isdigit())\n",
"    else:\n",
"        users = range(user_count)\n",
"\n",
"    data = []\n",
"    label = []\n",
"    for user in users:\n",
"        user_path = os.path.join(base_path, str(user), 'split_letters_csv')\n",
"        # os.listdir order is arbitrary; sort so the sample order (and hence\n",
"        # the seeded train/test split later on) is reproducible.\n",
"        for file in sorted(os.listdir(user_path)):\n",
"            letter = [ch for ch in file if ch.isalpha()][0]\n",
"            # sep must be passed by keyword: positional use was removed in pandas 2.\n",
"            data.append(pd.read_csv(os.path.join(user_path, file), sep=delim))\n",
"            label.append(letter)\n",
"\n",
"    # Fill element-wise so equal-length recordings are never stacked into 3-D.\n",
"    X = np.empty(len(data), dtype=object)\n",
"    for i, frame in enumerate(data):\n",
"        X[i] = frame\n",
"    return (X, np.asarray(label, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "494e249c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.78 s, sys: 205 ms, total: 2.98 s\n",
"Wall time: 2.98 s\n"
]
},
{
"data": {
"text/plain": [
"(13102,)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# Restore the pickle cache if present, otherwise parse the raw CSVs.\n",
"(X, y) = load_data()\n",
"\n",
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "00400cb9",
"metadata": {},
"outputs": [],
"source": [
"def shorten(npList, thresh, leeway=None):\n",
"    \"\"\"Crop a recording to the rows around its active 'Force' region.\n",
"\n",
"    A row is active when its 'Force' value exceeds thresh * max(Force).\n",
"    The kept window runs from `leeway` rows before the first active row to\n",
"    `leeway` rows after the last active row, clipped to the recording.\n",
"\n",
"    Args:\n",
"        npList: recording DataFrame with a 'Force' column.\n",
"        thresh: threshold as a fraction of the peak force (e.g. THRESH).\n",
"        leeway: rows of context kept on each side; None means the global LEEWAY.\n",
"\n",
"    Returns:\n",
"        The kept rows as a NumPy array. If no row exceeds the threshold\n",
"        (e.g. an all-zero Force trace), the whole recording is returned\n",
"        instead of raising IndexError.\n",
"    \"\"\"\n",
"    if leeway is None:\n",
"        leeway = LEEWAY\n",
"    force = npList['Force']\n",
"    cutoff = force.max() * thresh\n",
"\n",
"    active = np.flatnonzero((force > cutoff).to_numpy())\n",
"    if active.size == 0:\n",
"        return npList.to_numpy()\n",
"\n",
"    start = max(active[0] - leeway, 0)\n",
"    # Slice ends are exclusive: +1 keeps the last active row plus `leeway`\n",
"    # trailing rows, and capping at len(npList) allows keeping the final row.\n",
"    stop = min(active[-1] + leeway + 1, len(npList))\n",
"    return npList.iloc[start:stop].to_numpy()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "13fe2978",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.47 s, sys: 35.2 ms, total: 3.51 s\n",
"Wall time: 3.47 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# Crop every recording to its active Force window. Fill an object array\n",
"# element-wise so NumPy never stacks equal-length samples into a 3-D array.\n",
"XX = np.empty(len(X), dtype=object)\n",
"for i, recording in enumerate(X):\n",
"    XX[i] = shorten(recording, THRESH)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "cbdc141a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAKtCAYAAADSPu/xAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydd3hb1f3/X0db3iOx4zh7b0IICXvvvWcZhUJpobsU+HXRwbd0LyiFAmWUEvaeCRCgkEASErKHM53EsR3voa3z++PcK8u25C1bls/refRYOrrjSL66930/U0gp0Wg0Go1GoxmsWAZ6AhqNRqPRaDS9QYsZjUaj0Wg0gxotZjQajUaj0QxqtJjRaDQajUYzqNFiRqPRaDQazaBGixmNRqPRaDSDmoSKGSHEaCHEB0KIjUKIDUKI7xjjdwsh9gkh1hiPs6LWuUsIUSKE2CKEOD1q/AxjrEQIcWci563RaDQajWbwIBJZZ0YIUQQUSSm/EEJkAquAC4DLgEYp5R/aLD8DeBpYAIwElgBTjLe3AqcCe4EVwJVSyo3x9j1s2DA5bty4Pv08mqHLqlWrDkoph/fnPvUxrOlLBuIYBn0ca/qOjo5hWyJ3LKUsA8qM5w1CiE1AcQernA8sklL6gJ1CiBKUsAEokVLuABBCLDKWjStmxo0bx8qVK/vgU2g0IITY3d/71Mewpi8ZiGMY9HGs6Ts6Oob7LWZGCDEOOBT4zBi6TQixVgjxqBAi1xgrBkqjVttrjMUb12g0Go1GM8TpFzEjhMgAXgC+K6WsBx4AJgJzUZabP/bRfm4WQqwUQqysrKzsi01qNBqNRqNJchIuZoQQdpSQeUpK+SKAlLJcShmSUoaBf9HiStoHjI5afZQxFm+8FVLKh6SU86WU84cP73fXsEaj0Wg0mgEg0dlMAngE2CSl/FPUeFHUYhcC643nrwJXCCGcQojxwGTgc1TA72QhxHghhAO4wlhWo9FoNBrNECehAcDA0cA1wDohxBpj7P8BVwoh5gIS2AV8HUBKuUEI8SwqsDcI3CqlDAEIIW4D3gGswKNSyg0JnnvXCPph61uw+U1IHwbDJkPBDPVwZgz07FKDgAc+exC2vg3Dp8HIuSCsEPJD3gQongeu7IGepWYAWFNay2Of7ORPl83FYhEDPR1NAvh0+0Fe+3I/v7lozkBPRZPEJDqb6X9ArDPMmx2scw9wT4zxNztar9+p2Axr/gNfLoKmSkjLB38zBD3GAgJGHgqzLoLZl0Fm4YBOd9AQDsO+VbD5dajZpcZKP4eG/VA4G9a/CKv+3WYlAfmTYMRsyBkDNhc40iB9OOSMhbFHgdAXulTkk5KDvLxmPz85ZwbDMpwDPR1NAvhwayVPf17Kry+YjVULVk0cEm2ZST28dfDGD2Ddc2CxwZQzYN51MOlkQEDdHijfAGVrlSXh3Z/Ap/fBzUshq6izrQ9tdn4Er38PqkrUd5s3ARAwfApc/C8Yd4wSO3V71LjFBpWblfgp+xL2rVQiKORvvd0xR8GZ90LRIQPxqTQJxBsIAVDvCWgxk6J4/ep/7A2ESHfqS5YmNvrI6A77VsFzX4W6vXDcj2DBzZDRJtA4d5x6TDsbTrxLrfPYufDstXD962BL4RNu00Fw5YC1m4dV0Adv3g5fPK6+uwv+CVPPAHdu+2UtFrWMSXaxISSjkBL8jcpitv0D+OAeePB4uPQxmHlB9+amSWo8xoWuzhMY4JloEoU3EDb+ajGjiY8+MrrKuufh5W9CRgF89S0Ys7Br6xUfBhfcD89dry7Y5/518Ls8vPWw8WVlfXLnKDGy9W04uBUsdsgbr/4GPZBeAEVzYNThMPFkSM9vvS1/MzzzFdj+Hhz1bTjhLuUi6g1CgDNTPfImwKyL4alL1f9v2BQonNG77WuSBo9pmfEGB3gmmkRh/o+9wfAAz0STzGgx0xlBHyy9F/73J+WuuPw/7S/InTHzQuUG+d+fIX8iHP2dxM
w1UQS8UL0D9iyDHR/AtiVKqDgyIdAECOUCmns1eGqUmwhU7Er9PljzX/j8IbXciFmQNxGyisFihd2fKuvVeffBvGsSM393Dlz2BDx0PCy6Cm7+ILbVRzPo8ES5mTSpielKNP9qNLHQYiYe4bC6cL91B1Rtg0O/Amf/qeduopN+pgJaF/9MXchnX9Kn000INbvhmavhwHpU4hmQNQrmXqmES/FhyqUTDnT8vYRDULYGtr4Le1dA+XrYthhkWFlhLn448d9HVhFc9iQ8dha8fRdc+M/E7k/TL/gMF4R2M6UuHi1mNF1Ai5m2+Jthyd3KjdJYrrJhrn4BJp/Su+1aLCoWpLECXroFMkcoa0ayUrsHHj9HuZROuEu5a0bOVVlD0W4yIcDSicCzWJXwKT4soVPulDELlSvrf39S4jSZv39Nl2hxM2kxk6r4omJmNJp49FtvpkHDW7crl8johXDRv+Cby3svZEzsLrjiKSUMFl0FlVv6Zrt9TUM5PH4ueOrg2pfhhDtgzqWqhs5gj/c57nbIHqMy0kL6AjjYMQOA6z06ZiZV8QZNy4yOmdHER4uZaNY8Dav/A8f+AC5/EuZc1vtg1La4c+Hq58DqhKcugcYk6yEV8CrXUmMFXPOSqpWTSjjS4KzfqZTu5Q8M9Gw0vcS0zGg3U+ri8XfRzRQOwa5PYP0LKg7PU5v4yQ01/M2w6nFY8TDsXakye6u2q79SDujUtJvJpGIzvPF9GHuMcqskktyxcNUz8O+z4PmvwjUvdz+dORFICa9/V8W1XPYEjBpgt1CimHomTD4NPvqDcjel5Q30jDQ9xKvdTClPh5aZcFjVl9rwkno0lLV+P71AZS9OPAmmnqWsy5quEfSra0HpcmXF9jXAl09Dc1Xs5d25UDgLCqZD/mRVAd/qUOU60vLUuN2dsOkmwRU0CfA3wXPXgd0IRu0PYVE8D875M7x8C7z/Kzj1F4nfZ2d89qA6WE+4C2acP9CzSSyn3A0PHA0f/xFOv4fS0lKuvfZaysvLEUJw8803853vfIfq6mouv/xygFlCiMXAZVLKGqPv2F+Bs4Bm4Hop5RcAQojrgJ8Ye/q1lPJxY/ww4DHAjapm/R0pB/h2ZpDj1dlMKY/Hr0SMp61lpmSJKndRvUNdNCedCrMvVi1PanaprMrKrbB/tUq8WPwzOP03cOQ3+/9DDCa89Squ8POHwd8Q9YZQN4HHfA+yR6mkjuZqlbXqq4cD61Ryx5r/qjpfbUkvgGO+C/NvSIio0WJGShU/UblFuVX6s0rv3Cth7+fwyV9g9AJVaG+g2LtKVSueepYqCJjqFM6EuVep+KiFX8dms/PHP/6RefPm0dDQwGGHHcapp57KY489xsknn8ySJUvWA+8BdwJ3AGeiGqFOBhYCDwALhRB5wM+B+agUsFVCiFellDXGMjcBn6HEzBnAW/38yVMKnZqd+vjaZjN5apWIWfesqht14YPqvOXKalmpcGbrjdSWwtt3wjt3qdYzh1zeP5Pvb6RUIs5bp4qLpuUbsYGya5m4G15W18Pmg6o+18wLYfxx4DS+2+iYyZzR8efQWAGBZlXaxFunLGYrH4F3/p8SO9e/oUpm9CFazKx5Slkjjr8TJp7Y//s/417l333lNpXtkzmi/+fgqYHnr4fMIrjgHyrzaihwwl2qGOL791B00YMUFSkhm5mZyfTp09m3bx+vvPIKS5cu5a677gJ4HFiKEjPnA08YlpXlQogcoxv8CcBiKWU1gGHNOUMIsRTIklIuN8afAC5Ai5leEQkA1kXzUpYWN1MISlfA8zeoXm3H3wnHfr9rF+mc0XDJo/Cfi+GVb6rGtFPPSPDME0z1Dtj0Gmx9B8JBJTgqN0NdadRCgkhZjZwxqgHymCNg3HGqTYwzU70XDsOHv4UP74Xi+Squs3hez+YlROxehDMvgC1vqWr4T1+hjAd9aKEZ2mKmcotS+OOOheMHyBphc6qsqQePg1duhauf7/+MoQ/+D+r3ww
3vDq1icjmj4YhvKMvY4V+D0YcDsGvXLlavXs3ChQspLy+PiBzgAGD+SouB6LPGXmOso/G9McY1vcCsCqsDgFOTYCh
"text/plain": [
"<Figure size 648x864 with 12 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAKrCAYAAAAjyY0EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydd3hb53m37xcAwQEC3HuIIrUly5IsW97b2YkzmjSrGU3qps1o+nV8SduvadIkTZs0adLsPZvRTCdxYjtesjwka1lbIkWKew8QBEhivd8fLw4IkhgHgwTHua9LF8gD8OAVeXDOc57n9/weIaXEwMDAwMDAwCBbmLK9AAMDAwMDA4P1jRGMGBgYGBgYGGQVIxgxMDAwMDAwyCpGMGJgYGBgYGCQVYxgxMDAwMDAwCCrWLK9gHiUl5fLpqambC/DYA1w7NixESllxXK/r3EMG2SSbBzHxjFskEliHcMrOhhpamri6NGj2V6GwRpACNGZjfc1jmGDTJKN49g4hg0ySaxj2CjTGBgYGBgYGGQVIxgxMDAwMDAwyCpGMGJgYGBgYGCQVYxgxMDAwMDAwCCrrOtgZHByhrd/6whjbm+2l2JgYLDSOPI1ePLT2V6FQQinx8fbv3WEnnFPtpdisASs62DkWOc4j10c5qm2kWwvxcDAYKXx7Bfh5A+yvQqDEE9fHuGxi8M82z6W7aUYLAHrOhgZDWVEzvdPZnklBgYGKwr3KIy1w9RwtldiEOJ0rxNQGW2Dtce6DkbGQ8HIOSMYMTAwiKQ35Ksx6wSfcfFbCZzpU+fpAafx91iLrOtgRNOKnOszghEDA4MIep6b+9ptZEeyjZSSM6HMSL8RjKxJjGAEGHLNMuyazfJqDAwMVgzzgpGh7K3DAIA+50z4fG2UadYm6zoYGfd4MZsEYOhGDAwMQgQD0HMMavao76eMYCTbaFmRrVV2IzOyRlnXwciY28uehmLA0I0YGBiEGL4IXhdse5n63ghGss6ZXicmAbdvq2DUPYsvEMz2kgwyzLoPRjaW26grzjd0I+scIcQ3hRBDQogzMZ4XQojPCSHahBCnhBD7lnuNBsuEVqLZ9hL1aJRpss6ZXiebK+1sLLMhpSqtG6wt1m0wIqVkzO2l1GZle43DKNMYfBt4UZznXwxsDv27D/jSMqzJIBv0PAf5JVC5A3IdRntvlpFScrp3kl11RVQV5QEw4JzO8qoMMs26DUamfQFm/UFKCqzsqHVweXiKGV8g28syyBJSyoNAPDele4HvSsWzQLEQomZ5Vpc+D54d4BO/u0AwKLO9lJVPz1GovxaEAFuFkRnJMoOTs4xMzbKrzkFNOBgxMiNrjXUbjGjK7FJbDjtqHAQlXBxwZXlVBiuYOqA74vue0LYVzbjby3t/eII//94xvvzEZdpHprK9pJXNjBOGL6hgBKCwMjOZEa8HJrrS3886RBOvXlVXRLUjFIwYHTUZ5VDrCG1D2b3+GcGILZedtQ7AELEaZAYhxH1CiKNCiKPDw9lL8f/+zAD3fOYJfn+mnz/e3wDAWUMbFZ/e44CcC0ZsFTA1mP5+n/5v+PItEDSEl8lyuteJELCj1kFRfg65FpNRpskw7//xCf7zoUtZXYMRjNhyqC/Jx55rMUSsmUSuuXJAL9AQ8X19aNsipJRflVLul1Lur6ioWJbFRTLu9vJXPzrBu75/jEp7Hve/52Y++qpdWM0m4xhPRM9RQEBdSJ9cWJWZMs3YZZiZAM9o+vtaZ5ztc9JSUUiB1YIQgpqiPAYmjTJNpnDP+hmZ8tIx4s7qOtZtMDLuUcFISYEVIQTbax1GZiRTBIPw+Wvh8FezvZJMcj/wllBXzfWAU0rZn+1FLeShswPc85mD/PZUP3999xZ+9Z6b2F7jIMdsYkt1oZEZSUTPEajYBnlF6vvCSlW68ad58XOFDhVXX3r7WYec7nVyVV1R+PsqR56RGckg3aEpyB0j7qxqytZtMDI6pYKRMl
suADtqHFzonzQEfplg8DSMtkLvsWyvRDdCiB8CzwBbhRA9Qoh3CCHeJYR4V+glDwDtQBvwNeAvs7TUqEx4vLz/Rye473vHqLDn8qv33MRf3b2ZHPPcR3xnTRHn+ieRKyxr9csTveE5UVlFStVJU79/bpstlNlK1xLeNaAeJ41gJBmGXDMMTs6GS+lAKDNiaEYyRfeYCuxm/UH6s/h7tWTtnbOM5r5qz1O/gh01DtzeAF1jHprKbVle3Sqn46B6XEV3gVLKNyR4XgLvXqblJMVjF4b4+5+dUuWZuzbz7js2YbUsvs/YWefgx0e7GZicoaYoPwsrXcyVETfv//FJ/vYFW3jPnZuzu5ixdpgen9OLgMqMgDI+K6pPfd9GMJISZ3tVJm9eZqQoj0HnLFJKhBDZWtrqwDcNI5eg5uqYL+ke84S/vjLipq44O+eGdZsZGXP7KCnIwRSyg99hiFgzhxaMGCfeJcc14+PPv3eM0gIrv3z3Tfz1PVuiBiKgAm6YO8GvBJ7vmQDgwkroZNPMziKDEVsoGEknMzI7BbOh37lrxVX2VjSnQ500OyOCkWpHHt5AMKz7M4jDU5+Fr92ljsEYdI970GK69izqRtZtMDLu9lJSYA1/v6myEItJGAK/dAn4oPNp9fVk/1oUsq4oLg648AaC/P2LtrIr4oQdje01DoRYWQH36R51sVkRbfU9z4HVDhVb57YVhso06XTURP6sEaAnxZleJ83lNgpz55L4Ya8Ro1STmPYnIOiLGwR3j02zubKQ/BwzHcNGMLLsaO6rGnk5ZjZVFq6oE/WqpO8EeKfU3aXPPXdHaLAkaBmFrdX2hK+15VrYWGbjbJ9zqZelm1OhO9/2ETez/vRNB49eGZuXdk6KnudUF43JPLfNFlGmSZXwhUAYwUiSnOl1Lgqyq0JeI8b03gT4pqH3qPo6TjDSM+6hsbSApnIbHVn0IVq/wYhnfjACKo29njMjM75A+uLGjifU41WvU4+TRlp6Kbk44MKea9Fd591e61gxHTWBoORsr5MKey6BoOTyUHp3ZcGg5O3ffo7PPdKa/A97PTBwBhqum7/dWqCyJemUaTS9SPkWo0yTBKNTs/Q5Z+bpRQCqQ5kRY3pvAnqOQiBUytKOwQVIKeke81BfUsDG8gKujKYYyGeAZQtGEg0iW27G3V5KFgYjtQ4GJmfWZS1ycHKGqz/8EIfaRtLbUcdBqL4Kqnaq71eRiHU1cmFgkq3Vdt1Cvp21DnrGp3FO+5Z4ZYnpGJnC7Q3w6r3KyPbiYHpBUve4B9eMP7X0fd8JkIH5ehGNworMZEbq9hmZkSQ4Ewqad9Y55m2vKMzFJGDQCEbi0/nU3NcxguBxjw+3N0BDaQEby210jXmyNhF5OTMj3yb+ILJlIxiUjHu8lC0IRraHBH7rcWjeia4JZv1BLvSnUbv3zUDXYdh4GzhCY1uMk++SIaXkwoBLV4lGY2etustcCRnAUyG9yCv21JJjFmmLWLWMz8hUCjcTmni1bv/i52yV6WdGcgqUFmV2Mq6Y0GAOzQZeO2Y1LGYTFfZcIzOSiCuHoHo3WAtjZka6QiXNhpJ8NpYXEgjK1MucabJswYiOQWTLhnPaR1AyT8AKc8HISjhRLzeaVmbIlcYHvOcIBGZh461gr1XbMlGmmZ5QxlMG8+h3zuCa8bMtiWAk3FGzAnQjp3qc5OeY2VbtoKWiMG0Rq/a5HU5lvHzPc1DaDLayxc9lIjNirwZH3dz3Bgk53eNkQ1kBRfk5i56rdhheI3Hxz6pjuulmdezFOOa0wEPLjABZc2JdcZqR5ZjrMebRrODnByOlNis1RXnrUsSqZYNSOpFrdBwEYYbGGyAnD/JLM1Om+d+3wi9XlMfYikC7eG+rcSR45RwV9lwq7bkr4hg/3etkV50Ds0mwrdqedjCiBVhj7lkCyZgXhs3OopRoIJQZSScYGQB7jfoHMBl1isCysdJK5rE407dYvKpRXZRnCFjj0XsM/DOw4S
Z13MXIjGjuq0YwEoXlmOsxN5fGuui59Spi1f7PQ+kGI3X7IC90cXTUZiYzMnQBhs6nv581xvkB9TfbUqU/MwJKN5L
"text/plain": [
"<Figure size 648x864 with 12 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.preprocessing import RobustScaler, MaxAbsScaler, MinMaxScaler, StandardScaler\n",
"scaler = StandardScaler()\n",
"data = XX[0]\n",
"\n",
"# plot_data indexes channels by row, so pass (channel, time) arrays.\n",
"plot_data(data.T)\n",
"\n",
"# fit_transform keeps the (time, channel) layout, so transpose it as well;\n",
"# without .T this figure would show time steps instead of channels.\n",
"plot_data(scaler.fit_transform(data).T)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8c8dde89",
"metadata": {},
"outputs": [],
"source": [
"# Standardize each cropped recording on its own (StandardScaler refit per\n",
"# sample, per column). Fill element-wise so equal-length samples are never\n",
"# stacked into a 3-D array.\n",
"XXX = np.empty(len(XX), dtype=object)\n",
"for i, sample in enumerate(XX):\n",
"    XXX[i] = scaler.fit_transform(sample)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f26acb5e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 13102.000000\n",
"mean 52.640666\n",
"std 30.235132\n",
"min 6.000000\n",
"50% 49.000000\n",
"90% 79.000000\n",
"91% 81.000000\n",
"92% 83.000000\n",
"93% 87.000000\n",
"94% 90.000000\n",
"95% 93.000000\n",
"96% 97.000000\n",
"97% 105.000000\n",
"98% 117.000000\n",
"99% 143.990000\n",
"max 1512.000000\n",
"dtype: float64"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgNUlEQVR4nO3deXxU9b3/8dcn+wIkgQQMhLAJClLZIovW5Ra8gu0VrVZRcSkUtLda7S5tH72tbW9vbW977aNWRUpRW1FU2h91KVUvKipbAEF2YtjXBBIICdlmvr8/ZuAOmJgBJjkzk/fz8chj5izJvD1m3jl85yzmnENERGJfgtcBREQkMlToIiJxQoUuIhInVOgiInFChS4iEieSvHrh3Nxc17t3b69eXkQkJq1cubLcOZfX1DLPCr13794UFxd79fIiIjHJzHY0t0xDLiIicUKFLiISJ1ToIiJxosVCN7PZZnbQzNY1s9zM7HdmVmJma81seORjiohIS8LZQ58DjP+U5ROA/sGv6cDj5x5LRETOVIuF7px7Fzj8KatMBJ5xAUuBbDPLj1RAEREJTyTG0HsAu0KmdwfnfYKZTTezYjMrLisri8BLi4jICW16HLpzbiYwE6CoqEjX7RWRmOSco67RT22Dr9nH2gY/dY2nPp5YPvbCrgzpmR3xXJEo9D1Az5DpguA8EZFW5Zyj3uc/WZp1DWderice60LXP/GzQh5rGwLr1Db6qW/0n1Purh1To7bQFwD3mdnzwCjgiHNuXwR+rojEkQafn8qaBipr6jlcXU/Fiec19RyrbQyU7GlFGijf5veC6xr9nMs9elISE0hNTiAtOZHUpFMf05IT6JSeTFpyAqlJiScfU5MTSPuUx7Tkpub93/emJiVgZpHbsCFaLHQzmwtcBeSa2W7gP4BkAOfcE8BrwLVACVADfLlVkopI1Kht8AVLuZ7KmgYqauqpCJb0iXmHq+uprAnOq66nqq6x2Z+XnGgnCzC1iSLskJl0SqmmhZRw6ikl/MliDi3oU4o5KYGEhNYpVq+0WOjOuVtbWO6Ar0UskYh4oqq2gX1HatlbeZx9R2rZf6SWw9WBPejKmnoqqoPFXVNPbUPzQw4dUpPIyUwmJyOF7IwU+uRmkp2RQk5GCp0zk08+P7FOTkYK6SmJbfhfGr88uziXiLSd2gbfyaI+8bjvyHH2VgYe91XWfmIP2gyy0k+UbjL5WWkMzO90SimfXtDZ6SmkJOkEdK+o0EXiRKPPz6b9VazcUUHJwWOnFHZFTcMn1u+SmUJ+dhq9umQypm8X8rPTyc9Ko3vwsVunNJITVc6xRIUuEqMqa+pZvbOSlTsqWLmjgg93VXK8wQdAp7Skk8U8tDCb7llp5Gelk5+dRvesdM7LSiMtWcMc8UaFLhIjjtQ08PaWg7xfUs7KHRV8XFYNQGKCcVH3TtxySU9G9MphRK8cumene5xWvKBCF4li28qreWvjAd7ceIAV2yvw+R3ZGcmMKMzhi8MLGNErh4sLsshI0VtZVOgiUcXnd6zaWcGbGwIlfmIv/IJuHbnnir6MHdiNoT2zSYyzw+0kMlToIh6rqm1g8dZy3txwgEWbD1JR00ByojGqTxcmj+7FuIHd6Nk5w+uYEgNU6CJtrK7Rx/q9R1m1o4J3tpSxtPQQDb7AUMrnLujK2IHduGJALh3Tkr2OKjFGhS7Sysqq6li1s4JVwaNR1u45cvJaIP3yMplyWR/GDuzG8MJsknSYoJwDFbpIhPn8jjc2HGDh+v2s2lnBjkM1QOC6IZ8pyOKuMb0Y0SuH4YU5dO2U5nFaiScqdJEIOV7v46WVu/jje9vYfqiG3A4pFPXqzORRvRjeK4fBPTqRmqRjv6X1qNBFzlFZVR3PLtnOs0t3UFHTwJCe2fxh/IVcc9F5OhpF2pQKXeQslRw8xh/fK+XlVXto8PkZe2E37rmyL0W9clrt8qgin0aFLnIGnHMs23aYp94t5a1NB0
lNSuCmEQVM/Wwf+uV18DqetHMqdJEwNPr8vL5uP08tLmXt7iN0zkzhwXH9uWN0L7p0SPU6ngigQhf5VMfqGpm3IvBB557K4/TJzeTnNwzmxuEFuriVRB0VukgT9h+pZc4H2/nLsh1U1TZySe8cfnzdRYy9sGvc3eVG4ocKXSTEpv1HeerdbSxYswef3zFhcD5fubwPwwpzvI4m0iIVurR7zjneKyln5rulLN5aTnpyIreP6sWUy/pQ2EXXUJHYoUKXdu2dLWX84rWNbNpfRW6HVL5zzQXcPqqQ7IwUr6OJnDEVurRbzy/fyff/+hG9u2TyyI0XM3FYd53JKTFNhS7tjnOOxxaV8Ot/buGKAXk8fvtwMlP1VpDYp99iaVd8fsfDf1/P00t2cMOwHjxy08W6EbLEDRW6tBt1jT6+OW8Nr67dx7TL+zBjwkAdgihxRYUu7UJVbQP3PLuSDz4+xPevvZDpV/TzOpJIxKnQJe4drKrl7tkr2HKgit/cPIQvDi/wOpJIq1ChS1zbXl7NHbOXUV5Vz6y7irjqgq5eRxJpNSp0iVvr9hzhrtnL8TvHc9NG6WxPiXsqdIlL720t555ni8nOSOGZqSN1aVtpF1ToEnf+vmYv35z3If3yOvD0lJF00307pZ1QoUtcmfP+Nn7yygYu6dWZp+4qIis92etIIm0mrDMqzGy8mW02sxIze6iJ5YVmtsjMVpvZWjO7NvJRRZrnnOORf2zix3/fwL8O6sYzU0eqzKXdaXEP3cwSgceAq4HdwAozW+Cc2xCy2g+Bec65x81sEPAa0LsV8op8QqPPz/f/+hHzindz68hCfnb9YN2cWdqlcIZcRgIlzrlSADN7HpgIhBa6AzoFn2cBeyMZUqQ5x+t93D93FW9uPMgDY/vz4Lj+ukGztFvhFHoPYFfI9G5g1Gnr/Bj4p5ndD2QC45r6QWY2HZgOUFhYeKZZRU5RWVPPlDkrWL2rkp9eP5g7RvfyOpKIpyJ1VaJbgTnOuQLgWuBZM/vEz3bOzXTOFTnnivLy8iL00tIeHalp4KYnlrBuz1Eeu224ylyE8Ap9D9AzZLogOC/UVGAegHNuCZAG5EYioEhTlpQeouTgMR6dNJRrP5PvdRyRqBBOoa8A+ptZHzNLASYBC05bZycwFsDMBhIo9LJIBhUJ1ej3A9A7N9PjJCLRo8VCd841AvcBC4GNBI5mWW9mD5vZdcHVvgVMM7M1wFzgbueca63Q0r5tOVDFz17ZSMe0JPKzdNKQyAlhnVjknHuNwKGIofN+FPJ8A3BZZKOJfFLx9sNMmbOC1OREXpg+Rvf+FAmhM0UlZry54QBfe24V3bPTeWbKSHp2zvA6kkhUUaFLTJhXvIsZ8z9icPdOzL77Erp0SPU6kkjUUaFLVHPO8Ye3P+ZXCzdzef9cnpg8Qjd0FmmG3hkS1X61cDN/ePtjrh/anUduGkJKkm7oLNIcFbpEtT8v3cG4gV35zc1DdUNnkRZod0eilt/v8PkdPTtnqMxFwqBCl6hU1+jj/udXU13vY+B5nVr+BhHRkItEH7/fMXVOMe+VlPP9ay/k5kt6tvxNIqI9dIk+Ow/X8F5JOd8YN4DpV/TzOo5IzFChS9Rp8AWu09Kri04cEjkTKnSJKmVVdTz4wockJRjnd+3gdRyRmKIxdIka28uruXP2csqq6ph1VxGDe2R5HUkkpqjQJSp8tPsIX56zHJ/f8dy0UQwrzPE6kkjMUaGL5z7cVcntTy0lOyOFZ6aOpF+ehlpEzoYKXTz3ypq9NPgc8//9Urp10vXNRc6WPhQVz9X7/KQkJajMRc6RCl08tWDNXuYu38mg7jobVORcqdDFM396fxtfn7uaYYU5PHVnkddxRGKextClzTnneGThZh5/+2Ouuagbj04aRlpyotexRGKeCl3a3I8XrOfpJTu4bVQhP504mERdSVEkIlTo0uZeKN7FFy7O5+fXD8ZMZS4SKRpDlzbV4PPj90OPnHSVuUiEqd
ClzRyv93HPsyup9/kZ3F2n9YtEmoZcpE00+vzcPmspH+6q5Oc3DObfhnT3OpJI3NEeurSJj8uqWbWzkhkTBnL7qF5
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Length distribution of the cropped, scaled samples: plot length vs.\n",
"# percentile, then summarize the upper tail (90th-99th percentiles).\n",
"X_len = np.asarray([len(x) for x in XXX])\n",
"sq_xlen = pd.Series(X_len)\n",
"ptiles = [x*0.01 for x in range(100)]\n",
"l = [sq_xlen.quantile(p) for p in ptiles]\n",
"plt.plot(l, ptiles)\n",
"sq_xlen.describe(percentiles=[x*0.01 for x in range(90,100)])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0e63825e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((12970,), (63, 15))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"threshold_p = 0.99\n",
"\n",
"# Drop samples longer than the threshold_p length quantile. Lengths are the\n",
"# same in XX and XXX (scaling keeps the shape), so compute them once.\n",
"sample_lengths = pd.Series([len(x) for x in XXX])\n",
"max_len = int(sample_lengths.quantile(threshold_p))\n",
"len_mask = (sample_lengths <= max_len).to_numpy()\n",
"\n",
"# Train on the standardized samples (XXX); XX holds the raw, unscaled\n",
"# sensor values and left XXX otherwise unused.\n",
"X_filter = XXX[len_mask]\n",
"y_filter = y[len_mask]\n",
"\n",
"X_filter.shape, X_filter[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8424e150",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAKtCAYAAADSPu/xAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydd5xcZb3/39+p27Mt2fRKAiQhBAhNypVeVAIiCHoVlCsW+F2Ve1XU68WGYsHCFVEUFFCJqCBRKYYunQAhpBDSSdlsNtnepp3n98dzzsyZ2Znd2d2Z3Z3d5/16zWtmntOebM458z2fbxOlFAaDwWAwGAyFimekJ2AwGAwGg8EwFIwxYzAYDAaDoaAxxozBYDAYDIaCxhgzBoPBYDAYChpjzBgMBoPBYChojDFjMBgMBoOhoMmrMSMiM0TkSRHZICLrReSz9vjXRWSPiKyxX+e7tvmyiGwRkU0ico5r/Fx7bIuIXJ/PeRsMBoPBYCgcJJ91ZkRkCjBFKfWaiJQDrwIXApcCHUqpH6asvxC4FzgOmAo8BiywF78NnAXsBl4BLldKbch07NraWjV79uyc/nsM45dXX331gFJq4nAe05zDhlwyEucwmPPYkDv6Ood9+TywUqoeqLc/t4vIRmBaH5ssB1YopULAdhHZgjZsALYopbYBiMgKe92Mxszs2bNZvXp1Dv4VBgOIyM7hPqY5hw25ZCTOYTDnsSF39HUOD1vMjIjMBo4CXrKHrhWRtSJyp4hU2WPTgF2uzXbbY5nGDQaDwWAwjHOGxZgRkTLgL8DnlFJtwG3APGApWrm5OUfHuVpEVovI6sbGxlzs0mAwGAwGwygn78aMiPjRhszvlVL3AyilGpRSMaWUBfyKhCtpDzDDtfl0eyzTeBJKqduVUsuUUssmThx217DBYDAYDIYRIN/ZTALcAWxUSv3INT7FtdpFwDr780rgMhEJisgcYD7wMjrgd76IzBGRAHCZva7BYDAYDIZxTr6VmZOAjwCnp6Rhf19E3hSRtcBpwOcBlFLrgfvQgb2PANfYCk4UuBZ4FNgI3Gevmx92vwp/vQbeeghi0bwdZlC89Eu493JY/RvY9hSYrueGYaCxPcR/3PUKrV2RkZ6KYYzy0Jv1fP+Rt0Z6GoYCJd/ZTM8CkmbRQ31scyNwY5rxh/raLmdsehj+9DGI9sCa30HlTPjAb2H6MXk/dN/zegTW/B42rgSPHzbZf4q6xXDo+VA9Bw57LxRVjOw8DWOSdXtbeWzjft7a18bxc2tGejqGMcjjG/fz7JZGvnjuYSM9FUMBYioAu9n+L1jxIZh0GFy3AT74ez1+z0UQ7hy5eb1yB9z7Qdj1MixcDl/YDJ9dCxfeppWZZ34Af/00/O5iiJknZ0PuicW0AhiOWSM8E8NYJWZZxCyjNBsGR16VmYIiFoGH/hsmzIAr/g7BMqiYCsVV8Nvz4a1/wJJL83f8aAg8PvB4k8cPboVHvwrzzoDL7wVfUI8XV0HVLFj6IbBisPY++Oun4IVb4eTP5W+ehnFJ1NJGTDhqjBlDfohaiqgxZgyDxCgzDq/dDY1vwbk3aUPGYeaJ2tX0xr25P6Zl6RiYh74AN82EHx0OPzsWHvkydB7U6zzzAxAPLL81Ycik4vHC0sth5rv0PE0cjSHHOD8yIWPMGPJEzFJxBdBgGChGmQEId2mjYcYJcOh5ycs8HlhyGfzrh9C2V6s1uaB5pzZiNj+qjZW5p0FxJfS0wcu3w9o/wuHvgzf/BMd+Aiqm9LtLjrgY/vFf0LAeJi/OzTwNBojL/0aZMeQLo8wYhoIxZgBWfQ3a6+EDvwFJE6985GXwzPe1YXHSZ4d+vEe/Ci/8DHzFcP4PYdnHk91LDevhiW/DuvuhpBZOvCa7/S68SBtIG1caY8aQU6IxY8wY8kvMUiZmxjBoxrcxoxT84zpYfSeceC3MOjH9ejXzYPqx8MaKoRszrbvhxdt0IO/ZN0LljN7r1C
3S8TEDpbQGph6tU7ZP+8rQ5mkwuHBiZkImANiQJyIxK36eGQwDZXzHzNSv0YbMsZ+AM7/R97pLPgj7N8C+dX2v1x+v3AEoOOtb6Q2ZoTL33bB7tXZXGQw5ImrcTIY8E7MUlgLLqDOGQTC+jZkND+oMotO+At5+RKpF79frrv3j4I9nWXr7Q87UmUj5YO67QcVgx7P52b9hXGJiZgz5xjGYY+4Eht2vwq5XRl9SQywCL90O77wEzTtG3/zGIePXzaSUNmZmnwIl1f2vX1qjjZA3/wxnfr13CnU27HoJ2vbo7fPFjOPAVwQ7/gWHnZ+/4xjGFRETM2PIM47BHLMUfqLw+Dfg+Vv0wgkz4ZTr4Jgr08c1DhfRsC5euulhnbzhMGkhHPXvMPEwmHR47hJFDFkzfpWZhvXQtE3HrmTLkkuhfS/sfG5wx1x/vzY0UjOmcokvqON7jDIzIHbt2sVpp53GwoULWbRoET/96U8BaGpq4qyzzgJYLCKrRKQKdN8xEblFRLaIyFoROdrZl4hcISKb7dcVrvFj7DYeW+xtR/CuPDBiTsxMNDbCMzGMVRxlJhqz4M9XakNm2VVw4S9gwjT4++d0KYuRItQBd5yl57H1cTjtq/D+X8N539f33Ue/Ar97P/xkCWx7euTmOU4Zv8bMxpU6Jfqw92a/zYLzIFA2OFdTLArrH4AF50CwfODbD4RZJ8G+N6G7Jb/HGUP4fD5uvvlmNmzYwIsvvsitt97Khg0buOmmmzjjjDNAN0N9HLje3uQ8dCPU+cDVwG0AIlIN3AAcj+4Gf4NjANnrfMK13bnD868bOiZmxpBvHIPZs/5+2Pg3OON/4b0/0jW0rnxIl8546Tbtrs8XoXYdFxmL6GKkbv7+Odi3Fi65C752AP7ti7DkEjj+k3D1U/D59XqeNYfAfR/Vxo9h2Bi/xsyGlbrIXNnE7LcJlMBh79HVgAfaNmDns9DZCIsvHth2g2H2SYCCd17M/7HGCFOmTOHoo7W4Ul5ezuGHH86ePXt48MEHueKKuLhyF3Ch/Xk5cLfSvAhU2t3gzwFWKaWalFLNwCrgXHtZhVLqRaWUAu527WvUY9oZGPJNNKaYQAdFj39FZ2We9LnEQo8HjvuEjk/Z8ljuD77mXrjzPPi/ZfCLk+BbtfD9ufDPr0F3s76XvvknOPWLsOjC9K6uCdP1vfeCW6CnBdb9JffzNGRkfBozjW9D48aBuZgcDr9An9wDdeO89Q/wl8D8swd+zIEy/VjwBrQBZRgwO3bs4PXXX+f444+noaGBKVPiBQv3AXX252nALtdmu+2xvsZ3pxkvCCJGmTHkmVjM4hv+3yLdzdogSI1LPPwCKJ0Er/w6NwdUSqvld1+oW8F0HYC6hfDeH8O/XQ/zToPn/w9+cAj87gP62Cf9Z//7nX6sjqF57a7czNOQFeMzAHjjg/r98AG4mBzmna6Nko1/0yd7ttS/AVOWgr944MccKP5imHYM7BhkbM84pqOjg4svvpif/OQnVFQkdyBXSikRyWvagohcjXZbMXPmzHweakDETG8mQz6Jhriq57dc6H2ethO+SMXkI3qv4wvoAOBnfgBN26F6ztCO+cYKbcSUToJzvgPHf6q3AXXKm9rgaW+Axe+HQGn/+xWBIy/XxVhb3tHtcAx5p6CUGRE5V0Q22QGU1/e/RQbWPwjTjxtcxHmgRGc1vfX37H23lqX9sOku0Hwx6yRtQIXah++YBU4kEuHiiy/mwx/+MO9///sBqKuro76+HgDbVbTfXn0P4C4UNN0e62t8eprxJJRStyullimllk2cOAAXaJ6J92YybiZDPnjsG1wWeYAV0XfTuqyPwqTLPqZjHV/9zdCOF4tqo6juCPivt3SV9XQZqpOP0LE7F94Kh5yR/f4Pe49+3/TI0OZpyJqCMWZExAvcig68XAhcLiILB7yjPa9Bw5tD64C9cDl0NMDuV7Jbv3k7RDqH15iZfbKpNzMAlFJcddVVHH744Vx33X
Xx8QsuuIC77orLxVcAtqzHSuCjdlbTCUCrUqoeeBQ4W0Sq7MDfs4FH7WVtInKCncX0Ude+Rj0xk5ptyBfRELzxBx7
"text/plain": [
"<Figure size 648x864 with 12 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"\n",
"# Zero-pad each (time, channel) sample at the end ('post') to a common length\n",
"# (Keras default maxlen=None: the longest sample), giving one dense float array.\n",
"X_filter = pad_sequences(\n",
"    X_filter,\n",
"    dtype=float,\n",
"    padding='post',\n",
")\n",
"\n",
"plot_data(X_filter[0].T)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "a5180e6e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(12970, 143, 14)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Remove channel 0 from every sample, presumably the 'Millis' timestamp\n",
"# (TODO confirm against the CSV column order). .copy() returns a standalone\n",
"# array rather than a view of the padded one.\n",
"X_filter = X_filter[:, :, 1:].copy()\n",
"X_filter.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "9d836c9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10376, 143, 14)\n",
"(2594, 143, 14)\n",
"(10376, 26)\n",
"(2594, 26)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
"import tensorflow as tf\n",
"\n",
"# One-hot encode the letter labels: one column per distinct letter.\n",
"lb = LabelBinarizer()\n",
"yt_filter = lb.fit_transform(y_filter)\n",
"\n",
"# 80/20 split; the fixed random_state makes it deterministic for a given sample order.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X_filter, yt_filter, test_size=0.2, random_state=177013\n",
")\n",
"\n",
"for part in (X_train, X_test, y_train, y_test):\n",
"    print(part.shape)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "cd7fad65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using [0] NVIDIA A40 \n"
]
}
],
"source": [
"import os\n",
"\n",
"def get_least_used_gpu():\n",
"    \"\"\"Return the index of the GPU that `gpustat` lists with the fewest processes.\n",
"\n",
"    Every `gpustat` output line after the first is split on '|'; the last\n",
"    field is taken as that GPU's process list and its whitespace-separated\n",
"    entries are counted. Ties go to the lowest index. If gpustat prints no\n",
"    GPU rows (e.g. it is not installed), GPU 0 is returned.\n",
"\n",
"    NOTE(review): assumes gpustat lists GPUs in the same order that\n",
"    CUDA_VISIBLE_DEVICES numbers them - confirm on the target machine.\n",
"    \"\"\"\n",
"    lines = os.popen('gpustat').read().split('\\n')[1:-1]\n",
"    rows = [line.split('|') for line in lines]\n",
"    if not rows:\n",
"        print('gpustat listed no GPUs; using GPU 0')\n",
"        return 0\n",
"    # Pick the row whose process field has the fewest entries, comparing\n",
"    # token counts directly rather than re-matching the raw field text.\n",
"    lui = min(range(len(rows)), key=lambda i: len(rows[i][-1].split()))\n",
"    print(f'Using {rows[lui][0]}')\n",
"    return lui\n",
"\n",
"lu_gpu = get_least_used_gpu()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "84402e45",
"metadata": {},
"outputs": [],
"source": [
"# Pin TensorFlow to the GPU picked above and let it allocate memory on demand.\n",
"# Both variables must be set before TensorFlow initializes the GPU (i.e. before\n",
"# the first model or tensor is created in this kernel).\n",
"os.environ.update({\n",
"    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',  # this is required\n",
"    'CUDA_VISIBLE_DEVICES': f'{lu_gpu}',  # '0' for GPU0, '1' for GPU1, ...; see `gpustat`\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "04b1528e",
"metadata": {},
"outputs": [],
"source": [
"# (model, test accuracy) pairs appended by the training loop.\n",
"accs = list()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "dd490b18",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"def build_model():\n",
"    \"\"\"Build and compile a fresh fully-connected letter classifier.\n",
"\n",
"    The input shape comes from the global X_filter (one padded sample). The\n",
"    input is flattened and fed through DENSE_COUNT ReLU layers of\n",
"    DENSE_NEURONS units, then DENSE2_COUNT ReLU layers of DENSE2_NEURONS\n",
"    units, and finally a 26-way softmax (one unit per letter).\n",
"    \"\"\"\n",
"    stack = [Flatten(input_shape=X_filter[0].shape)]\n",
"    stack += [Dense(DENSE_NEURONS, activation='relu') for _ in range(DENSE_COUNT)]\n",
"    stack += [Dense(DENSE2_NEURONS, activation='relu') for _ in range(DENSE2_COUNT)]\n",
"    stack.append(Dense(26, activation='softmax'))\n",
"\n",
"    model = Sequential(stack)\n",
"    model.compile(\n",
"        optimizer=tf.keras.optimizers.Adam(0.001),\n",
"        loss=\"categorical_crossentropy\",\n",
"        metrics=[\"acc\"],\n",
"    )\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "e9975558",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10/10 [00:55<00:00, 5.51s/it]\n"
]
}
],
"source": [
"for i in tqdm(range(AVG_FROM)):\n",
"    # Fresh model per run; the test accuracies are averaged afterwards.\n",
"    model = build_model()\n",
"    model.fit(\n",
"        X_train, y_train,\n",
"        epochs=EPOCH,\n",
"        batch_size=128,\n",
"        shuffle=True,\n",
"        validation_data=(X_test, y_test),\n",
"        verbose=0,\n",
"    )\n",
"\n",
"    # evaluate() returns [loss, *metrics]; with build_model's 'acc' metric,\n",
"    # index 1 is the test accuracy.\n",
"    results = model.evaluate(X_test, y_test, batch_size=128, verbose=0)\n",
"    accs.append((model, results[1]))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "8c2a1e54",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.172552040964365"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean test accuracy over all (model, accuracy) pairs collected in accs.\n",
"np.mean([pair[1] for pair in accs])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "6dc4d27d",
"metadata": {},
"outputs": [],
"source": [
"# Shut down the kernel once the experiment is done (presumably to free the\n",
"# GPU memory TensorFlow holds). Cells after this one will not run.\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57553c1f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}