iui-group-l-name-zensiert/1-first-project/tdt/NNwAll.ipynb

708 lines
391 KiB
Plaintext
Raw Normal View History

2021-06-10 19:18:36 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e6f9c155",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f688c561",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Input locations. load_data() reads the raw CSVs only when the cached pickles are missing.\n",
"delim = ';'  # field separator of the per-letter CSV files\n",
"\n",
"base_path = '/opt/iui-datarelease1-sose2021/'  # contains one numbered sub-directory per user\n",
"\n",
"# Cached X / y pickles; load_data() loads these instead of parsing CSVs when both files exist.\n",
"Xpickle_file = '../X2.pickle'\n",
"ypickle_file = '../y2.pickle'"
2021-06-10 19:18:36 +02:00
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "24a30600",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Preprocessing and model hyper-parameters.\n",
"THRESH = 0.1  # fraction of a recording's peak Force that shorten() treats as active\n",
"LEEWAY = 1  # padding (in samples) kept around the active span by shorten()\n",
"EPOCH = 20  # training epochs per run\n",
"\n",
"# First stack of fully connected layers\n",
"DENSE_COUNT = 2\n",
"DENSE_NEURONS = 1000\n",
"\n",
"# Second stack of fully connected layers\n",
"DENSE2_COUNT = 2\n",
"DENSE2_NEURONS = 600\n",
"\n",
"# NOTE(review): not referenced by any cell; presumably the intended number of training\n",
"# runs to average the accuracy over -- TODO confirm.\n",
"AVG_FROM = 10"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "767b5a06",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"def plot_pd(data):\n",
"    \"\"\"Plot the 12 sensor channels of one raw recording against time.\n",
"\n",
"    Each subplot shows one channel ('Acc1', 'Acc2', 'Gyro', 'Mag' x X/Y/Z) with the\n",
"    'Force' column overlaid, all plotted against the 'Millis' column.\n",
"\n",
"    Args:\n",
"        data: DataFrame of one recording containing the columns named above.\n",
"    \"\"\"\n",
"    sensors = ['Acc1', 'Acc2', 'Gyro', 'Mag']\n",
"    axis_names = ['X', 'Y', 'Z']\n",
"    fig, axs = plt.subplots(len(sensors), len(axis_names), figsize=(3*3, 3*4))\n",
"    t = data['Millis']\n",
"    for row, sensor in enumerate(sensors):\n",
"        for col, axis_name in enumerate(axis_names):\n",
"            column = f'{sensor} {axis_name}'\n",
"            ax = axs[row][col]\n",
"            ax.plot(t, data[column])\n",
"            # Overlay Force so channel activity can be compared against it.\n",
"            ax.plot(t, data['Force'])\n",
"            ax.set_title(column)\n",
"    fig.tight_layout()\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f50de79f",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"def plot_data(data):\n",
"    \"\"\"Plot channels 1-12 of a channel-major array in a 4x3 grid.\n",
"\n",
"    Channel 13 is overlaid on every subplot (presumably Force, mirroring plot_pd --\n",
"    TODO confirm). Expects `data[i]` to be the i-th channel's time series, e.g. the\n",
"    transpose of a (time, channels) recording.\n",
"\n",
"    Args:\n",
"        data: indexable so that data[i] yields channel i; needs channels 1..13.\n",
"    \"\"\"\n",
"    fig, axs = plt.subplots(4, 3, figsize=(3*3, 3*4))\n",
"    # axs.flat walks the grid row by row, so channel 1 lands at [0][0], 4 at [1][0], ...\n",
"    for channel, ax in enumerate(axs.flat, start=1):\n",
"        ax.plot(data[channel])\n",
"        ax.plot(data[13])\n",
"        ax.set_title(f'channel {channel}')\n",
"    fig.tight_layout()\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "db8d52e9",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"def load_pickles():\n",
"    \"\"\"Load the cached recordings and labels from Xpickle_file / ypickle_file.\n",
"\n",
"    Returns:\n",
"        (X, y): X converted with dtype=pd.DataFrame (an object array of recordings),\n",
"        y as a str array of labels.\n",
"    \"\"\"\n",
"    # NOTE: pickle.load can execute arbitrary code -- only load trusted cache files.\n",
"    # `with` closes each file even if unpickling fails.\n",
"    with open(Xpickle_file, 'rb') as fh:\n",
"        X = pickle.load(fh)\n",
"    with open(ypickle_file, 'rb') as fh:\n",
"        y = pickle.load(fh)\n",
"\n",
"    return (np.asarray(X, dtype=pd.DataFrame), np.asarray(y, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dff0b4cc",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"def load_data(user_count=None):\n",
"    \"\"\"Load every recording and its letter label.\n",
"\n",
"    Uses the cached pickles when both exist. Otherwise it parses every file in\n",
"    base_path/<user>/split_letters_csv/ for users 0 .. user_count-1.\n",
"\n",
"    Args:\n",
"        user_count: number of users to read. When None, a global `user_count` is\n",
"            used if defined, otherwise the numeric sub-directories of base_path\n",
"            are counted.\n",
"\n",
"    Returns:\n",
"        (X, y): object array of DataFrames and str array of labels -- the same\n",
"        two-element result as load_pickles(), so `X, y = load_data()` always works.\n",
"    \"\"\"\n",
"    if os.path.isfile(Xpickle_file) and os.path.isfile(ypickle_file):\n",
"        return load_pickles()\n",
"    if user_count is None:\n",
"        user_count = globals().get('user_count')\n",
"    if user_count is None:\n",
"        # Assumes users are stored as consecutively numbered directories 0..N-1.\n",
"        user_count = sum(1 for entry in os.listdir(base_path) if entry.isdigit())\n",
"    data = []\n",
"    label = []\n",
"    for user in range(user_count):\n",
"        user_path = base_path + str(user) + '/split_letters_csv/'\n",
"        for file in os.listdir(user_path):\n",
"            # The label is the first alphabetic character of the file name.\n",
"            letter = ''.join(filter(str.isalpha, file))[0]\n",
"            data.append(pd.read_csv(user_path + file, sep=delim))\n",
"            label.append(letter)\n",
"    return (np.asarray(data, dtype=pd.DataFrame), np.asarray(label, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "acd83b6d",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 5.38 s, sys: 398 ms, total: 5.78 s\n",
"Wall time: 5.78 s\n"
2021-06-10 19:18:36 +02:00
]
},
{
"data": {
"text/plain": [
"(26179,)"
2021-06-10 19:18:36 +02:00
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"X, y = load_data()\n",
"assert len(X) == len(y), 'every recording needs exactly one label'\n",
"\n",
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6c69c01d",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"def shorten(npList, thresh, leeway=None):\n",
"    \"\"\"Trim a recording to the span where Force exceeds a fraction of its peak.\n",
"\n",
"    Args:\n",
"        npList: DataFrame of one recording with a 'Force' column.\n",
"        thresh: fraction of the recording's maximum Force; samples strictly above\n",
"            `thresh * max(Force)` count as active.\n",
"        leeway: extra rows kept on each side of the active span (defaults to the\n",
"            global LEEWAY).\n",
"\n",
"    Returns:\n",
"        numpy array holding the rows from the first active row minus `leeway` through\n",
"        the last active row plus `leeway` (inclusive, clipped to the recording).\n",
"        The whole recording is returned when no row is active.\n",
"    \"\"\"\n",
"    if leeway is None:\n",
"        leeway = LEEWAY\n",
"    force = npList['Force']\n",
"    cutoff = force.max() * thresh\n",
"\n",
"    active = np.flatnonzero(force.to_numpy() > cutoff)\n",
"    if active.size == 0:\n",
"        # e.g. an all-zero Force trace: nothing to trim against.\n",
"        return npList.to_numpy()\n",
"    start = max(active[0] - leeway, 0)\n",
"    # +1 because the slice end is exclusive; keeps `leeway` rows after the last active one\n",
"    # and never cuts the final row of the recording.\n",
"    stop = min(len(npList), active[-1] + leeway + 1)\n",
"    return npList.iloc[start:stop].to_numpy()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9c4b1131",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.94 s, sys: 72.8 ms, total: 7.01 s\n",
"Wall time: 6.98 s\n"
2021-06-10 19:18:36 +02:00
]
}
],
"source": [
"%%time\n",
"# Trim every recording to its active Force span (object array: lengths differ).\n",
"XX = np.array([shorten(recording, THRESH) for recording in X], dtype=object)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e60bd183",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAKuCAYAAABUqp1fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydd5xcVfn/32fq9p5sNrtppPdAliT0HgIooUuTAFEUwS+CfhUsPxFFQb+IhSIoTRRClwCBmIQaCCGB9L7pZXtv08/vj3Pv7GxvszM7u+f9es1rZ87cuffM7C3PfcrnEVJKNBqNRqPRaGIVS7QnoNFoNBqNRtMbtDGj0Wg0Go0mptHGjEaj0Wg0mphGGzMajUaj0WhiGm3MaDQajUajiWm0MaPRaDQajSam6XNjRghxQAixRQixUQix3hjLEEKsEELsMf6mG+NCCPEXIUSBEGKzEOKEkPUsMpbfI4RY1Nfz1mg0moGCPg9rBjqR8sycJaWcJaXMN17fDaySUo4HVhmvAS4AxhuPW4DHQR10wC+BucAc4JfmgafR9BYhxAghxAdCiO1CiG1CiDuM8XuFEEeNC8BGIcSFIZ+5xzjZ7xJCnB8yvsAYKxBC3B0yPkYIsdYYf0kI4Yjst9Ro9HlYM3ARfS2aJ4Q4AORLKctCxnYBZ0opC4UQOcCHUsqJQognjOcvhi5nPqSU3zHGmy3XFllZWXL06NF986U0Awqv14vX6yUhIQG/38+OHTsYO3YslZWVWCwWhg0bxpdfflkmpRwCIISYAryIOqEPB1YCE4zV7QbOA44A64BrpJTbhRAvA69LKZcIIf4GbJJSPt7RvPQ+rAkXW7ZswePxlEsps8wxfR7WxBqh5+GW2CKwfQn8VwghgSeklE8C2VLKQuP9IiDbeJ4LHA757BFjrL3xdhk9ejTr168Pw/Q1g42FCxdy++238+mnn5KUlMSPfvQjhBAHQxcBlkgp3cB+IUQByrABKJBS7gMQQiwBFgohdgBnA9cayzwH3Itxx9seeh/WhIsxY8Zw4MCBeCHEl+jzsCZGaXEebkYkwkynSilPQLkubxNCnB76plSuobC4h4QQtwgh1gsh1peWloZjlZpBxoEDB9iwYQNz584F4JFHHmHGjBkAo0Nc6t092WcCVVJKX4txjSYirF69GmAH+jysGaD0uTEjpTxq/C0B3kDdwRYbbk2MvyXG4keBESEfzzPG2htvua0npZT5Usr8IUPa9ERpNO1SV1fH5Zdfzp/+9CdSUlK49dZb2bt3Lxs3bgTwAg/19Rz0hUDTF+TmKttZn4c1A5U+NWaEEIlCiGTzOTAf2AosBcxM+EXAm8bzpcANRjb9PKDacIMuB+YLIdKNu+P5xphGExa8Xi+XX3451113HZdddhkA2dnZWK1WLBYLQClNoaTunuzLgTQhhK3FeCv0hUATburr66mtrQX0eVgzcOnrnJls4A0hhLmtF6SU7wkh1gEvCyEWAweBq4zllwEXAgVAA3ATgJSyQgjxa1RCJcB9UsqKPp67ZpAgpWTx4sVMnjyZu+66KzheWFhITk6O+TIN+Mx4vhR4QQjxR1QC8HjgC0AA44UQY1DGytXAtVJKKYT4ALgCWELzC4dG06cUFxdz6aWXAkxB7af6PKwZcPR5NVO0yM/Pl32VeCal5MevbubK/BHMGZPRJ9vQRI7Vq1dz2mmnMX36dNMLw29/+1tefPFFNm7ciBCCLVu2VAOTzYRJIcTPgJsBH/ADKeW7xviFwJ8AK/C0lPJ+Y/w4lCGTAWwArjcSiNulL/fhwcwTH+0lKc7GdXNHRXsqEUUI8WVIWXZE0Ptw7FFY3cjP39jKn66eRXKcPdrTaUZH+3AkqpkGHI1eP698eYThafHamBkAnHrqqbRl1F94YVBWBiFEQUjlB4aRcn/Lz0gpl6HubFuO76MpTKWJIm9sOEp6gmPQGTMaTVdYu6+CVTtL2F1cx+xRsSMjpNsZ9I
A6typK8QUCUZ6JRqPpLi6vn3qPr/MFNZpBSHm9B1DHSSyhPTM9oN6t/sk+/8AM0Wk0A5lGrx+LyuPTaDQtKK9T0e9GjzZmBjx1LnVX59XGjEYTczR6/AzQVEGNpteU1ynPTKP2zAx8dJhJo4ldGr1+/AFtzWg0bVFeb3hmtDEz8Kl3m54ZbcxoNLGE1x/A65d4/X4CAYnFosNNGk0oZXWxmTOjE4B7gJk8qMNMGk1sEXqC1knAGk1rgp6ZGMuZ0cZMDwiGmbRnRqOJKUJd52Yiv0ajaSJWc2a0MdMDggnAOu6u0cQULk/TDUid2xvFmWg0/Y8Gj48GwyOjPTODgHrtmdFoYpIGb1NoqU57ZjSaZpheGdCemUFBndaZ0WhiktC7TfOmRKPRKEzBPNCemUFBsJpJh5k0mpgi9G6z1qWNGY0mFFMwD2LPM6NLs3tAnVnN5NNhJo0mlmhWzaQ9MwMLbyMUbYXCjVC2B6RxfrY6wJEIcakwch4MPwEs+j6+LcwwU2q8velYkRJqjkHZbhAW9VsGH0lgi1OPKP+m2pjpAWYCsBbN02hii8aQBGBdmh3j1BZDwUo4sFoZMKW7QBoXYEcyWI2Oz34PeOqaPpeQBVMvgTPvgcSsSM+6/xLw4y3dzUWWzznFWcLI4hL4R50yYlxVnX/eFg9JQyApG1LzIGMsZI6FEXMh4zjo4xYi2pjpAU2ieTrMpNHEEjrMFON4G2HLq/DlM3D0SzWWOASGHw+TLoKcWTB8FqTkNr94BgLQUA77PoTd78GXz6r1nPP/YPaNYLFG/Kv0G4q3w8Z/w+aXuK6+lOscEHBZKLUMAftEmHopZE+FIROVZ8bToIxDbwN46tVfr0uN1ZVAXREc2wjblzYZl8nDYeICOOUOSB/dJ19DGzM9QLcz0Ghik8YQb4wOM8UQNYWw7u+w/hlorIChU+DsX8CE8yF7Wud3/RaL8hrMuFI9Sv4Xlv0I3rlLGUULH+1zz0G/QEqoL4OKfVCwAna+AyXbwWKHiQt4sXIyK6uG4ciZzIEqP+8uOq3n2/J7oXwvHPwU9n8MG/4FX/0TZl4NZ/0cUnLC973QxkyPMN3TraqZGiqgsbLptTNFxWltjgjOTqPRtIfpmbFbhTZmYoH6MvjvL2DLKxDwKe/L3O/C6FN7Z3wMnQSL3oL3fwOf/J8Kg5z+o/DNuz8gJexapgzAuiJorIL6UvC51PvCAiNPhgUPwvQrIDGLZU+tpS7FxwhHPC5vde+2b7Wr33noJDhxscq7Wf0n5RXb+Q58/c8wZWEvv2QT2pjpAWbOTKK3Ara9AQc+VdZnyfa2P5A2EqZfCTOuhiETIjjTGKWxCkp2wJF1cOQLFRt3Jjc94lLVX3sCOBIgcaiK0aaNgsTMaM9e048xc2YyE53UamOmf1O6G164Unll8m+Ged9VRke4EALO/jlUHYT3fw0ZY2Da5eFbfzTZ/wn89+cqlyhtJAydCtnT1fkxdaQ6X46Y0ypnqKzOQ25aPPF2a/hLs1OGw4W/hznfhte/DS/fAMd/Ey56CGzOXq9eGzPdpeow13tf40LHp0yuPwyvAPZEGDlXHQhpI9VyAb+KIbqq4OAaWP0wfPKQis8ueADs8VH8Ev2IgB8OfQ6H1sDhtVC0BWoLm95PH61+U1cVVB0Cdy24a1Scti0ShygX9Ig5MPFCFUsfDO5jTZdo9PpxWC2kxtu1Z6Y/s+8jePmbqhLppmWQl9832xECLn4Eqg7Da9+GglVw2g9V4mqssv8T+NdlkJwDCx+DGd8Aa9cu9eV1bmbmpRJnt9LQVwnyWeNh8Qr48Hfqmlh9GL7xL3WD2gu0MdNVSnfBqvuQO9/hh1bJusAEHrXdwG03LoKcmU2Z8+1RWwRrHoHP/gpHvoSrnovtA6a3+DyweYlyO1bsVWNZE2HMGcotOWQy5J4ASUPb/rzfZySeNUBdMVQfgcoDyjtWtBU++SN8/AdIyVPu4xMWdVg6ePjwYW644QaKi4
sRQnDLLbdwxx13UFFRwTe+8Q2AaUKIFcBVUspKIYQA/gxcCDQAN0opvwIQQiwCfm6s+jdSyueM8dnAs0A8sAy4Q0q
2021-06-10 19:18:36 +02:00
"text/plain": [
"<Figure size 648x864 with 12 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhoAAAKrCAYAAAC3LyT2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOyddXhb59n/P0cyM8jMjh3bscPYNFhMymu3dV07+HVrRx3v7ejd3jF1zOvWrdsKW2nt2qaQUpImaZic2I6ZmdmC8/vjkWTZFlqSZSnnc125FB9J5zxxjo7uc9/f+3tLsiyjoKCgoKCgoOANVL5egIKCgoKCgkLgogQaCgoKCgoKCl5DCTQUFBQUFBQUvIYSaCgoKCgoKCh4DSXQUFBQUFBQUPAaQb44qEajkXNzc31xaIUA5MSJEz2yLCct5DGVc1jBkyjnsIK/Y+8c9kmgkZuby/Hjx31xaIUARJKkxoU+pnIOK3gS5RxW8HfsncNK6URBQUFBQUHBayiBhoKCgoKCgoLXUAINBQUFBQUFBa+hBBoKCgoKCgoKXkMJNPyYmq5hPvLwMUYmdb5eioKCV3m9opMv/Pu0r5ehoDBvzrcNcu8/jtM1POHrpSw4SqDhx7x6oZPXK7s4cLHb10tRUPAaeoPMd1+4wDOnWukbnfL1chQChJquESZ1+gU5VufQBB95+DivXujkmZOtC3LMxYQSaPgxVR3DAByo6fHxShQUvMfL5R009I4BUN8z4uPVKAQCfaNT7P7Vfh4+2OD1Y41P6bnnH8cZmtCSkxjBi2fbvX7MxYYSaPgxpkDj7Wol0FAITGRZ5o/7aomPCAagtnvUxytSCARONvaj1cscre/z6nEMBpkvPnmac62D/Op9q/nAphzOtQ7S0HNpncdKoOGnaPUGartHSIgMoalvjCbjHZ+CQiBxsKaXc62DfPGaIoLVEnVKoKHgAU409QNwqnkAWZa9dpxfvHaRPec6+OruYq5elsLu5WkAvHju0spqKIGGn1LfM4pWL3PXxmwADtQoOg2FwOOP+2pJjg7lPesyyUmMpK5bKZ0ouM+JRhFo9I1O0dTnnZu0Z0+18ps3anjvukzu2ZoPQEZcOGuy4y658onbgYYkSVmSJL0pSdIFSZLOS5L0WU8sTME+prLJrrI00mPDlPKJQsBxrmWQt2t6+MiWPEKD1ORrIqm7xFLOCp5HqzdwpnmAy/ITATjVNODxY5xo7OP+p86yMS+B792yHEmSzM9dvyKdC+1Dl1TQ7ImMhg74oizLy4BNwKckSVrmgf0q2KGqYxi1SmJJciSXF2g4VNuL3uC9FKCCwkLzx321RIcF8X5j1i4/KYrG3lF0eoOPV6bgz1xoG2JSZ+B9G7KICFFzylhG8RTNfWPc+48TpMWF8ce71hISNPNr9rrlqQDsuYTKJ24HGrIst8uyfNL492GgAshwd78K9qnsGCZfE0lokJothRoGx7Wcax309bIUFDxCfc8oe8rb+cCmHKLDhBA0PykSrV6mpX/cx6tT8GdOGgOLDXkJrMyM41TzgMf2PTyh5aN/P86U3sBDH1pPfGTInNekxYazLieeFy6h8olHNRqSJOUCq4EjVp67V5Kk45IkHe/uVvQE7nKxc5ilqdEAXF6gAeDtauX3qhAYPLi/jmC1iv93eZ55W74mEhBBiILCfDnR2E96bBhpseGszo7jQtsQE1r3/TT0BpnPPH6Kmu4R/nDnWgqSo2y+9voVaVR2DFPTdWmUTzwWaEiSFAU8DXxOluWh2c/LsvygLMvrZFlel5RkdWS9gpOMTupo6hujOEUEGpqoUJalxXBA0WkoBABdQxM8faKF96zNJCk61Lw9P0lcuGvt1baPPAjNR729RAU/5mRjP2ty4gFYkx2PziBT7oFs8PdfrODNqm6+dVMpWwo1dl+7uywNSbp0yiceCTQkSQpGBBmPyrL8jCf2qWCbi51CCFpkzGgAbC3UcLKpn1HFjl
zBz/nrwQZ0BgP3bsufsT0hMoS4iGDbgtCpUXj5y3DgZwuwSgV/pG1gnLbBCdZki0BjVXYc4L4g9PGjTfz1YD0f3pzLBzblOHx9amwY63MSLpnuE090nUjAQ0CFLMs/d39JCo6wFmhsKdQsiAFNoCJJ0l8lSeqSJKnc12u5lBma0PLoO43sXp5GTmLknOfzNXZaXNvPgmyAxkOgVwJuhbmY9BlrjRkNTVQo2QkRnGqevyBUpzfw45cruSw/kf+9vsTp912/Io2qzmGqjdfzQMYTGY3LgQ8AV0iSdNr45zoP7FfBBpUdw4QHq8mKjzBvW5+bQEiQSimfzJ+HgV2+XsSlzqPvNDE8qeMT25dYfT4/Kcq2aVfbSfE4OQQdZ7y0QgV/5mTjAGHBKpalx5i3rc6O42TjwLz3eaZlkIExLXduyiZI7fxX6u6yVCTp0jDv8kTXyduyLEuyLK+QZXmV8c8eTyxOwToXO4dZmhKFSjXdmx0WrGZDbgJvK8Zd80KW5f2Akg7yIRNaPQ+9Xc/WQg1lGbFWX5OfFEnX8CTDE9q5T7aehDDj+xre9uJKFfyVE039rMiMI9giIFidFUfH0ATtg/PrZtp3sRuVBFsK7OsyZpMcE8aGXCvlkze+D0/fM6+1LFYUZ1A/pKpjeEbZxMSWQg0XO0foHLr0xhAvBErnlHd55mQrPSOTNrMZAPkaIQi12nnSdhJyt4JmKdQf8NYyFfyUCa2e862D5rKJidVGvcZ8dRr7qrpYnR1PXMTcVlZH3LAijequEXM5HIDyp6D8aZiY01PhtyiBhp/RMzJJz8gURakxc57bYm5zVcon3kDpnHKNrzx9lit++hZPHGtG68BkS2+QeXB/LSsyY7lsSaLN1+Un2WhxHe+HvjrIWCOCjabDoLeS9QhgFJ2Rfc62DKIzyGYhqImStBhCglTzMu7qHZnkbOsg25fO73pwbVkqKolpT42xPnEey3qhNQoQlEDDz7hotB4vSpmb0ViWFkNiZAhvK2PjFXxMfc8o/z7eTN/YFPc/fZadP32Lx482MaWzHnCYRsF/YvuSGXbNs8lJjEAlWZni2nZaPKavgbytMDUyvW2+tJ2GmtfAi0O3PMzDKDojm5jmm6wxdpqYCAlSsTwjdl4ZjbdrepBl5h1oJEeHsTEvkRfPtonhbq0np5+s3zevfS5GlEDDz6jsmNtxYkKlkthcoDGe/H5zcVQIQB7cX0uwWsWrn9/G3z68nsSoUL76zDl2/vQtHnmnkUndtEGSLMv8YV8NeZpIrilNtbvf0CA1mfERcztPTELQ9FUiowHQsH/+/wBZhqfuhkdug7/thuZj89/XAqHojOxzsqmfPE0kiVGhc55bnRXHudZBm4GwLd6q6iYhMoTlNjRFznD9ijRqu0ep6hyG1hOAJALmOiXQUPARVR3DJEaGzDAysmRrgYbu4Ulx0io4jSRJjwOHgSJJklokSfqIr9fkrwjDrVbeszaT5OgwdhYn8+wnN/Pw/1tPckwo//tsOTseeIt/HG5gQqvnYE0v5a1DfGxbPmqV7WyGifykyLmdJ60nISEfwuMhUgPJy9zTaXRXQV8tFN8AvbXw0FXwxAfF3/2YS1VnJMuyMOqaVTYxsTo7nkmdgcoO53URBoPM/ovdbCvUzBDmu8ouY/nkxbPt0Hockoqg+HroOg8jgfF/pAQafkZV5zBLrZRNTJgc6RSdhmvIsnyHLMtpsiwHy7KcKcvyQ75ek79izXBLkiR2FCXzzCc288+PbCAjLpxvPnee7Q+8ybeeP09ydCjvWuPciKR8TRT1PaMYLIcItp0Sd4EmcrdC8xHQTc3vH1H5gni87gH4zCnY8TWofg1+twFe/JLffgFcqjqjxt4xeken5ghBTayeh3FXedsgvaNT7ChKdmttmqhQNi/R8OKZNuTWE5CxDvJ3iCfdycotIpRAw48wGGQudlrvODGRHhdOflKk4qeh4BMcGW5JksTWwiSe/PhlPPbRjeQkRl
LTNcK92/IJDVI7dYz8pEjGtXo6TN1Vw50w1CqEoCZyt4B2bLqk4iqVL0LGWohJh9Ao2PFl+OxpWPMhOP5X+PUq2Pc
2021-06-10 19:18:36 +02:00
"text/plain": [
"<Figure size 648x864 with 12 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.preprocessing import RobustScaler, MaxAbsScaler, MinMaxScaler, StandardScaler\n",
"\n",
"# Compare one trimmed recording before and after standardisation.\n",
"scaler = StandardScaler()\n",
"data = XX[0]\n",
"plot_data(data.T)\n",
"\n",
"# fit_transform returns (time, channels) like its input, so it must be transposed too;\n",
"# without .T plot_data would plot individual time steps instead of channels.\n",
"plot_data(scaler.fit_transform(data).T)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6e98530f",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Standardise each recording independently (the scaler is re-fit on every recording).\n",
"XXX = np.array([scaler.fit_transform(recording) for recording in XX], dtype=object)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9049dc4b",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 26179.000000\n",
"mean 48.405172\n",
"std 30.709627\n",
2021-06-10 19:18:36 +02:00
"min 6.000000\n",
"50% 44.000000\n",
"90% 72.000000\n",
"91% 74.000000\n",
"92% 76.000000\n",
"93% 79.000000\n",
"94% 82.000000\n",
"95% 87.000000\n",
"96% 91.000000\n",
"97% 98.000000\n",
"98% 109.000000\n",
"99% 137.000000\n",
"max 1718.000000\n",
2021-06-10 19:18:36 +02:00
"dtype: float64"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD4CAYAAAATpHZ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfC0lEQVR4nO3de3xU9Z3/8deHkIQkBBIgQSTc5KaoKBgVsa1o7RZt1XbrjVpvdUu9td2uj+1D267tutt9bLXbbi9KxdaiaGutrZZtaalabX+Vi4SKyMVAuCYRSAyE3MhlMp/fHzNAiISMYZIzM3k/H488MnPOIXlzyLw5+Z5zvmPujoiIpIYBQQcQEZH4UamLiKQQlbqISApRqYuIpBCVuohIChkY1DceMWKEjx8/PqhvLyKSlNasWfOuuxd0tT6wUh8/fjwlJSVBfXsRkaRkZjuPt17DLyIiKUSlLiKSQlTqIiIppNtSN7PHzazKzNZ3sd7M7AdmVmZm68xsZvxjiohILGI5Ul8EzD3O+suAydGP+cCCE48lIiI90W2pu/tfgX3H2eQq4EmPWAnkmdmoeAUUEZHYxWNMfTRQ3uF5RXTZe5jZfDMrMbOS6urqOHxrERHpqE+vU3f3hcBCgOLiYs35KyJJpz3stITaaWkL0xIK0xoKR54f+hxdfmRZ9KPtyPMPn1rIWWPyeiVfPEq9EhjT4XlRdJmISFy5e6RI28PR8owWZYfHraFwlwXb2lXxtkW/5vHWRUs5FD7x49HC3MyELvUlwN1m9gxwPnDA3XfH4euKSJJqbmtnX2Mr+xpb2d8U/dzYyv6mNprbOh3Vdi7o6FHtkeI+upRPVHqakTkwjYyBA8g8/JFGZnrk8aD0NIZkpR9z3Xv+XHrakW0GDjiy7vDyI48PrctIG4CZxWEvH1u3pW5mvwDmACPMrAL4BpAO4O4/BpYClwNlQBNwa2+FFZG+1xoKU9vUyr7D5dzGvqZISR9V2k3RdY2tHGxr7/LrDUo/UoKZ0ccZaQMOF2deVjqZuZnvKcPOf+Y969I7bddFIacN6L1CTQTdlrq7z+tmvQN3xS2RiPQ6d6emsZXK/Qep2H+QytomqupajpR1U1vkyLqxlfqWUJdfJ3fQQIblZJCfnUFh7iCmjhzCsJx0huVkMiwnnfzsjMj6nAyGZWcwJCs95Us1aIFN6CUivSccdqrqW6isbaLicHFHP+9vorL2IM1tRw9lZKWnMSznSAlPGJ59uIzzDy0/XNKRwk5P003piUalLpLE6pvbWFdxgDcratnxbuPh8t5d20xr+9GlPSwng9F5WUwZmcvFUwspys9idH42o/OyGJ2fxdCs9ID+FhJPKnWRJBFqD1O6t5615bWs3VXL2vJayqob8OjFGIW5mYzOz2J6UR6XnREp6qL8LIqipZ2doZd7f6B/ZZEE1NzWTllVA5v31rNpdx1ry2t5q/LA4SGT/Ox0ZozN54qzTubsMXmcVZTH0GwdaYtKXSRQbe1hdtY0UrqngdK99WzeU8/mvfXsqGnk0OXQGWkDmHbyEK4/dywzxuZx9pg8xg7L7tXL4iR5qdRF+oi7U1bVwF82V/NW5QFK99Szrbrx8Nj3AIPxw3OYMjKXj591MlNH5jL1pMGMG56jE5ISM5W6SC9qbAmxfGsNr5ZW8WppNZW1BwE4eeggpp6Uy0VTC5g6MpcpI3OZVDiYQelpASeWZKdSF4kjd2drdePhEn99+z5a28PkZKQxe9II7rx4InOmFjI6LyvoqJKiVOoiJ6ipNcSKrTW8Ei3yiv2Ro/HJhYO5efY4Lp5aSPH4YWQM1BCK9D6VukgPNLaEeP6NSpZt2MOqbZGj8eyMNGZPHMHtF01kztQCivKzg44p/ZBKXeR92FnTyJMrdvLs6nLqW0KcUpDDTReMY87UQs6dkE/mQI2JS7BU6iLdcHdeK6th0fLtvPx2FWlmXH7mKG6ePZ6ZY/N0aaEkFJW6SBeaWk
P85u+VPLF8B1uqGhiek8EXLp7EDbPGMXLIoKDjiRyTSl2kk/J9TTy5Yge/XF1OXXOIM0YP4X+uOYuPTR+lSw4l4anURYgMsazYWsPPlu/gpU17GWDGZWecxK0Xjmfm2HwNsUjSUKlLv/eHt3bzvy9toXRvPcNyMrhrziRumDWWUUN1LbkkH5W69Fvuzvde2sIPXt7CqSfl8tDV07nirJM1xCJJTaUu/VL5viYeXFbK/735DlefU8R/ffJM3RwkKUGlLv2Gu7Ny2z4WLd/Oixv3Ymb860encueciRozl5ShUpd+YdmGPXzvxc28vaee/Ox07pgzkc/MGqdxc0k5KnVJae7Ow6+U8Z0/bWbKyME8+KnpXHm2xs0ldanUJWW1hsJ89fm3eG5NBZ+cMZr//tSZuo1fUp5KXVLSgaY2bn9qDSu21fDPl07mSx+erHFz6RdU6pJydtU0ceui19m1r4nvXnsW/zizKOhIIn1GpS4pZc3O/cx/soR2d5667XzOP2V40JFE+pRKXVLG79ft5svPrmXU0EH87JZzOaVgcNCRRPqcSl2Snruz4C9befCPpRSPy2fhTcUMy8kIOpZIIFTqktTa2sN8/fn1/LKknCvOOpmHrp6uyxWlX1OpS9I6cLCNO59ew2tlNXzhkkl8+dIpDBigK1ykf1OpS1Iq39fEZxetZvu7jTx09XSuKR4TdCSRhKBSl6SztryWf3piNS2hME9+9jxmTxoRdCSRhBHTtHRmNtfMSs2szMzuPcb6sWb2ipm9YWbrzOzy+EcVgT+u38P1C1cwKD2N5++crUIX6aTbUjezNOBh4DJgGjDPzKZ12uzrwLPuPgO4Hngk3kGlf3N3HvvrNu54eg2njRrCC3ddyKTC3KBjiSScWIZfzgPK3H0bgJk9A1wFbOywjQNDoo+HAu/EM6TIt36/iZ/8bTsfO3MU/3PtWbrCRaQLsZT6aKC8w/MK4PxO23wT+JOZfQHIAS491hcys/nAfICxY8e+36zSjz3+WqTQfzhvhq5wETmOeL3VyzxgkbsXAZcDi83sPV/b3Re6e7G7FxcUFMTpW0uqa2sPE3aYWJCjQhfpRiylXgl0vF6sKLqso9uAZwHcfQUwCNAZLDlhdc1tfHbRagCmnjSkm61FJJZSXw1MNrMJZpZB5ETokk7b7AI+DGBmpxEp9ep4BpX+p7L2INcsWMGKrTV8+1Nn8rHpo4KOJJLwuh1Td/eQmd0NLAPSgMfdfYOZPQCUuPsS4B7gMTP7MpGTpre4u/dmcElt6ypque2JEppb21l063l8YLJ+8ROJRUw3H7n7UmBpp2X3d3i8EbgwvtGkv/rThj186Zm1DMvJ4Ok7z2fKSF26KBIr3VEqCcPdefy1Hfzn7zcyvSiPn9xUTEFuZtCxRJKKSl0SQqg9zAO/28iTK3Yy9/ST+N51Z5OVoWvRRd4vlboErqElxN0//zuvllYz/0OncO/cU3XpokgPqdQlUA0tIa5esJwtVQ1865NncMP544KOJJLUVOoSqNU79vH2nnq9QbRInMTrjlKRHmkNhQH0fqIicaJSl8CsrzzA/b9dz+DMgRTlZwUdRyQlqNQlEC9t3Mu1j64gzYzn7riAEYN16aJIPGhMXfrcote288DvNnL6yUP56c3FFA4ZFHQkkZShUpc+9Z1lpfzolTI+Mm0k37/+bLIz9CMoEk96RUmfWrxyJ5ecWsiPP3MOaboWXSTuNKYufaY97LSHnbHDslXoIr1EpS59orElxPwnS2hoCXH6yZoXXaS3aPhFel1Ta4hrH13Bpt11/McnzuCa4jHd/yER6RGVuvS6teW1bHinjgc/NZ1rz1Whi/QmDb9Irzt01+i44dkBJxFJfSp16VUb3jnAvb9+i5yMNMaPyAk6jkjKU6lLr1mzcx/X/ngFZvDcHbMZqZuMRHqdxtSl1/x27TuEHV6460IVukgf0ZG69JqWtjBZGWkqdJE+pFKXXvHE8h38ak05Z4weGnQUkX
5Fwy8Sd999cTM/eHkLl542kh/MOzvoOCL9ikpd4u7nq3bxoSkFPHqj5ncR6WsafpG4ag87oXCYovwsFbpIAFTqEje
2021-06-10 19:18:36 +02:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X_len = np.asarray(list(map(len, XX)))\n",
"sq_xlen = pd.Series(X_len)\n",
"\n",
"# Empirical CDF of the trimmed recording lengths.\n",
"ptiles = [x*0.01 for x in range(100)]\n",
"len_quantiles = sq_xlen.quantile(ptiles).to_list()\n",
"fig, ax = plt.subplots()\n",
"ax.plot(len_quantiles, ptiles)\n",
"ax.set(xlabel='recording length (rows)', ylabel='quantile', title='Trimmed recording lengths')\n",
"\n",
"sq_xlen.describe(percentiles=[x*0.01 for x in range(90,100)])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "8db3d74b",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((25918,), (53, 15))"
2021-06-10 19:18:36 +02:00
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"threshold_p = 0.99\n",
"# Drop recordings longer than the `threshold_p` length quantile so padding stays small.\n",
"# Lengths are computed once on XX; XXX has the same lengths because scaling keeps the row count.\n",
"lengths = pd.Series([len(recording) for recording in XX])\n",
"max_len = int(lengths.quantile(threshold_p))\n",
"len_mask = (lengths <= max_len).to_numpy()\n",
"\n",
"# NOTE(review): the unscaled XX is filtered here, so the standardised XXX never reaches\n",
"# the model -- confirm whether XXX[len_mask] was intended.\n",
"X_filter = XX[len_mask]\n",
"y_filter = y[len_mask]\n",
"\n",
"X_filter.shape, X_filter[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5c7afe06",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAKuCAYAAABUqp1fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydd5hkVZn/P2+Frs5xenpmuiczTASGYYiS44DKiKCCKKCsGGBVWH+Kui6oi6K7JlbUxQUJK0mFBZE05DjAAMMkGCbHntQ5Vjy/P869Vbeqq7uruqu6p7rO53n6qepz7617qvuG937fJEopDAaDwWAwGHIV12hPwGAwGAwGg2E4GGPGYDAYDAZDTmOMGYPBYDAYDDmNMWYMBoPBYDDkNMaYMRgMBoPBkNMYY8ZgMBgMBkNOk3VjRkS2ishqEVkpIiussWoRWSYiG6zXKmtcROQWEdkoIqtEZJHjcy631t8gIpdne94Gg8EwVjDXYcNYZ6SUmdOUUguVUout368HnlVKzQKetX4HOBeYZf1cBfwe9EkH3AAcCxwD3GCfeAbDcBGRySLyvIisE5G1IvINa/xGEdll3QBWish5jm2+a13s14vIOY7xJdbYRhG53jE+XUTesMYfEJGCkf2WBoO5DhvGLpLtonkishVYrJQ64BhbD5yqlGoUkYnAC0qp2SLy39b7+5zr2T9KqS9b43HrJWPcuHFq2rRp2flShjFFMBgkGAxSXFxMOBzm/fffZ+bMmbS0tOByuZgwYQJvv/32AaVULYCIzAPuQ1/QJwHPAIdaH/chcBawE3gLuEQptU5EHgQeUkrdLyJ/AN5TSv1+oHmZY9iQKVavXk0gEGhSSo2zx8x12JBrOK/DiXhGYP8KeFpEFPDfSqnbgDqlVKO1fA9QZ72vB3Y4tt1pjfU33i/Tpk1jxYoVGZi+Id9YunQp11xzDa+++iqlpaV861vfQkS2OVcB7ldK+YEtIrIRbdgAbFRKbQYQkfuBpSLyPnA68FlrnbuAG7GeePvDHMOGTDF9+nS2bt1aJCJvY67Dhhwl4Tocx0i4mU5USi1CS5dXi8jJzoVKS0MZkYdE5CoRWSEiK/bv35+JjzTkGVu3buXdd9/l2GOPBeC3v/0thx9+OMA0h6Se7sW+BmhVSoUSxg2GEeGVV14BeB9zHTaMUbJuzCildlmv+4CH0U+wey1ZE+t1n7X6LmCyY/MGa6y/8cR93aaUWqyUWlxbm1SJMhj6pbOzkwsvvJBf//rXlJeX89WvfpVNmzaxcuVKgCDwi2zPwdwIDNmgvl7bzuY6bBirZNWYEZESESmz3wNnA2uARwE7Ev5y4BHr/aPAZVY0/XFAmyWDPgWcLSJV1tPx2daYwZARgsEgF154IZdeeimf/OQnAairq8PtduNyuQD2E3MlpXuxbwIqRcSTMN4HcyMwZJquri46OjoAcx02jF2yHTNTBzwsIva+7lVKPSkibwEPisiVwDbg09b6jwPnARuBbuALAEqpZhH5MTqgEuBHSqnmLM/dkCcopbjyyiuZO3cu1113XXS8sbGRiRMn2r9WAq9Z7x8F7hWRX6IDgGcBbwICzBKR6Whj5WLgs0opJSLPAxcB9xN/4zAYssrevXu54IILAOahj1NzHTaMObKezTRaLF68WKUTeNbRG+TKO1fwH586nKk1JVmcmeFg45VXXuGkk07isMMOs1UYfvKTn3DfffexcuVKRITVq1e3AXPtgEkR+T7wRSAEfFMp9YQ1fh7wa8AN3KGUuskan4E2ZKqBd4HPWQHE/ZLuMZxt2rqDXPfgSn520eGMK/WN9nQMaSIibzvSskeEg+0YTgV/KMw/3bWC7yyZw4L6itGejsHBQMfwSGQz5QTL1u3lza3N/GrZh/z64iNHezqGEeTEE08kmVF/3nnRsjKIyEZH5geWkXJT4jZKqcfRT7aJ45uJualykg/2tPPsB/tYs6uNU2ePH+3pGAxZ4f3GDl
7ecIC2ntU8es2Joz0dQ4qYdgYW4Yi+mblcMsozMRgOTuxzJBQem2quwQAQCkcAcJt7QU5hjBmLiPVk7hZzABsMyQjZxkwkMsozMRiyR9Ay1r0uc3vMJcx/y8Iyxo01bjD0g63MBI0yYxjD2Me5x23uBbmEMWYswhEjLRoMAxE2yowhD7CPb4/b3B5zCfPfsrAv1MaYMRiSEzLKjCEPsGPCPOZekFMYY8bCvj67TMyMwZCUmJvJKDOGsUvIqPQ5iTFmLCJGmTEYBsS+yJtsJsNYxh/Sx7nXxMzkFMaYsQgrY8wYDANhlBlDPuAP2sqMuT3mEua/ZRGtM2PcTAZDUmKp2UaZMYxd/KEwAF7zYJtTGGPGIuZmGuWJGAwHKbGieUaZMYxdeoMmZiYXMbdui7ApmmcwDIjJZjLkA7YyY1KzcwvTm8kiYtoZGAwDErYUGVNnxpB1wkFY9QDsXgm+MjjkDJg2Mn2SbGXGRQS2Lwe3F8rroaQWXO4RmYMhfYwxY2ErM6a2gMGQnJDpzWQYCT54HJ76HrRsAV8FBLvglV/CWT+Cj3wj67t39TbzNff/8YUPXoNVOx0LvFC/COZ9Ao66AgqKsz4XQ+oYY8bCDgMwyozBkBy7f5lxMxmyxvM/hRdvhvHz4ZIH4NBzINQLD10Fz9wI00+BSQuzt/8db/LVVZ+iyNvBZs/h1H70B+Arh/Zd0LIVtrwET30X3vojXPpXqJmZvbkY0sIYMxam0aTBMDCm0aQhq7x1uzZkFl4KH/s1eAr0uLcIzv8v7fJZ9m9w+aOZ37dS8MFj8PBX6HJX8Inu7zNv7vH86oiFfdfd8hI8eBk88Hn4p2eMQnOQYCKcLOxMDRGgvRGat0AoMLqTMhgOIsJho8wYssSGZ+Dxb8Gsc+Djt8QMGZuiSu3a2foydO7P3H6DvfDQl+E/D4UHPgc1h/Cbyb9hvZrSfwmC6SfDJ/8I+9bC67dmbi6GYWGMGYtwRFFOJye//0P45Ry4ZSHcPAVW/WVkJ9LTCstugDs/Bv93Nbz0H7D+SQj2jOw8DIYEQqZoniEbRCLw5Hdg3KFw0R3g7sdhMG8pqIhWUDLFk9+BVffDjFPgE7+HLz7JPqqAWPPhpMw6C+Z8DF79DXQdyNx8DEPGuJks6jrfZ5nv29Q2dsDx18D4ubDyXnjoS1A+CaZ9JLsTCHTDy/8Jy/8AwW6YeARsXAade/Xyoio48To9N1OZ0jAKmDozhqyw4Wlo2ggX3g6+0v7Xq5sP1TO1O+qIS8BbOLz9Nm+Bt++E466GJT+JDtvtDMKDFYc849/gg2P0Z5z8reHNxTBszF0R4MAGvrDx6wTx8JdFd8E5N8GRn9MBXmUT4PmbtE81W4RDcN/F8PIvYPa58OWX4Msvwrc+hO/ths89BA1Hw7IfwN+/nr155DE7duzgtNNOY968ecyfP5/f/OY3ADQ3N3PWWWcBLBCRZSJSBSCaW0Rko4isEpFF9meJyOUissH6udwxfpSIrLa2uUUktwK0osqMqQBsyBSRCLz4Myhv0MrLQIjA2f8Oe1fD098f/r7XPqRfj/tK3LDdzmBQY6Z2Nkw7Cd65S38Pw6hijJlwEB76EmFx8xn/D9hfOie2rKBYqyHbXoWdK7Kzf6W0r3jLi3D+b+Gi22Hi4Y45lOgaC599UKclvnsPbFiWnbnkMR6Ph1/84hesW7eO5cuXc+utt7Ju3TpuvvlmzjjjDIA1wLPA9dYm5wKzrJ+rgN8DiEg1cANwLHAMcINtAFnrfMmx3ZKR+XaZIRxtNGku3IYMseavsPsdOP1fdT2XwZhznlZS3vof2Pzi0PerFKx5CBqOgcopcYt6raJ5KbXtWPwFaN0Om54b+lwMGcEYMy/9B+x+l79N/Ba7qKXP8XvYRYDApmczu9/eNu1SuvOj8Paf4MRrYdHn+19fBE77PtTMgqe+b54EMszEiRNZtEiLK2VlZcydO5ddu3bxyCOPcPnlUXHlLuAT1vulwN1Ksx
yoFJGJwDnAMqVUs1KqBVgGLLGWlSulliulFHC347NyAlNnxpBRgj3wzA+1S/3wz6S+3en/CtUz4NF/hkDX0Pb9+q2
2021-06-10 19:18:36 +02:00
"text/plain": [
"<Figure size 648x864 with 12 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"\n",
"# Pad every recording at the end (with pad_sequences' default value 0) up to the longest one.\n",
"X_filter = pad_sequences(X_filter, padding='post', dtype=float)\n",
"\n",
"first_padded = X_filter[0]\n",
"plot_data(first_padded.T)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "d9f6b07e",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(25918, 137, 14)"
2021-06-10 19:18:36 +02:00
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop column 0 (presumably the 'Millis' timestamp, cf. plot_pd) from every padded recording.\n",
"# NOTE: not idempotent -- re-running this cell removes another column.\n",
"X_filter = np.delete(X_filter, 0, axis=2)\n",
"X_filter.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1218eb18",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20734, 137, 14)\n",
"(5184, 137, 14)\n",
"(20734, 52)\n",
"(5184, 52)\n"
2021-06-10 19:18:36 +02:00
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
"import tensorflow as tf\n",
"\n",
"# One-hot encode the letter labels.\n",
"lb = LabelBinarizer()\n",
"yt_filter = lb.fit_transform(y_filter)\n",
"\n",
"# 80/20 split with a fixed seed for reproducibility.\n",
"# NOTE(review): the test split is also used as validation data during training.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X_filter, yt_filter, test_size=0.2, random_state=177013\n",
")\n",
"\n",
"for split_part in (X_train, X_test, y_train, y_test):\n",
"    print(split_part.shape)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "d0aae270",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
2021-06-10 19:18:36 +02:00
"source": [
"# Disabled helper: pick the GPU with the fewest users according to `gpustat`.\n",
"# The GPU is currently chosen by hand via CUDA_VISIBLE_DEVICES instead.\n",
"# import os\n",
"\n",
"# # By users on card, could do by memory, but too lazy\n",
"# def get_least_used_gpu():\n",
"#     a = [x.split('|') for x in os.popen('gpustat').read().split('\\n')[1:-1]]\n",
"#     lu = min([x[-1].split() for x in a], key=len)\n",
"#     lui = [i for i in range(len(a)) if a[i][3] == ''.join(lu)][0]\n",
"#     print(f'Using {a[lui][0]}')\n",
"#     return lui\n",
"\n",
"# lu_gpu = get_least_used_gpu()"
2021-06-10 19:18:36 +02:00
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e202c8cf",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Limit GPU usage. These variables only take effect if set before TensorFlow initialises\n",
"# the GPU. TensorFlow is already imported by earlier cells, so keep this cell ahead of any\n",
"# GPU work (ideally move it to the top of the notebook).\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '2' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
2021-06-10 19:18:36 +02:00
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "7b09f8fe",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"# (model, test accuracy) pairs collected from each training run.\n",
"accs = list()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "09a037ed",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"def build_model(input_shape=None, num_classes=52):\n",
"    \"\"\"Build and compile the fully connected letter classifier.\n",
"\n",
"    Layers: BatchNormalization -> Dropout(0.2) -> Flatten ->\n",
"    DENSE_COUNT x Dense(DENSE_NEURONS, relu) -> DENSE2_COUNT x Dense(DENSE2_NEURONS, relu)\n",
"    -> Dropout(0.2) -> Dense(num_classes, softmax).\n",
"\n",
"    Args:\n",
"        input_shape: shape of one padded recording; defaults to X_filter[0].shape.\n",
"        num_classes: number of output units; 52 matches the LabelBinarizer output here.\n",
"\n",
"    Returns:\n",
"        A compiled Sequential model (Adam lr=0.001, categorical cross-entropy, 'acc').\n",
"    \"\"\"\n",
"    if input_shape is None:\n",
"        input_shape = X_filter[0].shape\n",
"\n",
"    model = Sequential()\n",
"    model.add(BatchNormalization(input_shape=input_shape))\n",
"    model.add(Dropout(0.2))\n",
"    model.add(Flatten())\n",
"\n",
"    for _ in range(DENSE_COUNT):\n",
"        model.add(Dense(DENSE_NEURONS, activation='relu'))\n",
"    for _ in range(DENSE2_COUNT):\n",
"        model.add(Dense(DENSE2_NEURONS, activation='relu'))\n",
"    # A single Dropout after the second stack (matches the recorded model summary).\n",
"    model.add(Dropout(0.2))\n",
"\n",
"    model.add(Dense(num_classes, activation='softmax'))\n",
"\n",
"    model.compile(\n",
"        optimizer=tf.keras.optimizers.Adam(0.001),\n",
"        loss=\"categorical_crossentropy\",\n",
"        metrics=[\"acc\"],\n",
"    )\n",
"    return model\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "061a0bdb",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"81/81 [==============================] - 1s 6ms/step - loss: 3.1932 - acc: 0.1521 - val_loss: 2.4668 - val_acc: 0.3038\n",
"Epoch 2/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 2.1157 - acc: 0.3769 - val_loss: 1.7419 - val_acc: 0.4670\n",
"Epoch 3/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 1.6080 - acc: 0.4976 - val_loss: 1.4454 - val_acc: 0.5484\n",
"Epoch 4/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 1.3078 - acc: 0.5800 - val_loss: 1.3221 - val_acc: 0.5799\n",
"Epoch 5/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 1.1233 - acc: 0.6282 - val_loss: 1.2188 - val_acc: 0.6173\n",
"Epoch 6/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.9627 - acc: 0.6661 - val_loss: 1.1781 - val_acc: 0.6296\n",
"Epoch 7/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.8519 - acc: 0.7032 - val_loss: 1.1871 - val_acc: 0.6354\n",
"Epoch 8/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.7557 - acc: 0.7320 - val_loss: 1.1691 - val_acc: 0.6433\n",
"Epoch 9/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.6753 - acc: 0.7536 - val_loss: 1.1882 - val_acc: 0.6412\n",
"Epoch 10/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.6153 - acc: 0.7766 - val_loss: 1.2020 - val_acc: 0.6508\n",
"Epoch 11/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.5711 - acc: 0.7933 - val_loss: 1.2155 - val_acc: 0.6397\n",
"Epoch 12/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.5196 - acc: 0.8081 - val_loss: 1.2665 - val_acc: 0.6512\n",
"Epoch 13/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.4745 - acc: 0.8239 - val_loss: 1.2465 - val_acc: 0.6644\n",
"Epoch 14/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.4222 - acc: 0.8450 - val_loss: 1.3260 - val_acc: 0.6568\n",
"Epoch 15/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.4110 - acc: 0.8471 - val_loss: 1.3274 - val_acc: 0.6593\n",
"Epoch 16/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.3950 - acc: 0.8558 - val_loss: 1.3014 - val_acc: 0.6707\n",
"Epoch 17/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.3736 - acc: 0.8633 - val_loss: 1.3430 - val_acc: 0.6561\n",
"Epoch 18/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.3454 - acc: 0.8739 - val_loss: 1.4568 - val_acc: 0.6433\n",
"Epoch 19/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.3520 - acc: 0.8699 - val_loss: 1.4099 - val_acc: 0.6516\n",
"Epoch 20/20\n",
"81/81 [==============================] - 0s 4ms/step - loss: 0.3378 - acc: 0.8832 - val_loss: 1.3834 - val_acc: 0.6609\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:08<00:00, 8.67s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test loss, test acc: [1.383405089378357, 0.6608796119689941]\n",
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"batch_normalization (BatchNo (None, 137, 14) 56 \n",
"_________________________________________________________________\n",
"dropout (Dropout) (None, 137, 14) 0 \n",
"_________________________________________________________________\n",
"flatten (Flatten) (None, 1918) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 1000) 1919000 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1000) 1001000 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 600) 600600 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 600) 360600 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 600) 0 \n",
"_________________________________________________________________\n",
"dense_4 (Dense) (None, 52) 31252 \n",
"=================================================================\n",
"Total params: 3,912,508\n",
"Trainable params: 3,912,480\n",
"Non-trainable params: 28\n",
"_________________________________________________________________\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
2021-06-10 19:18:36 +02:00
]
}
],
"source": [
"# NOTE(review): a single run; AVG_FROM looks like the intended run count for averaging --\n",
"# TODO confirm before changing.\n",
"n_runs = 1\n",
"for _ in tqdm(range(n_runs)):\n",
"    model = build_model()\n",
"\n",
"    # NOTE(review): the test split doubles as validation data, so val_acc is not an\n",
"    # independent estimate of generalisation.\n",
"    model.fit(X_train, y_train,\n",
"              epochs=EPOCH,\n",
"              batch_size=256,\n",
"              shuffle=True,\n",
"              validation_data=(X_test, y_test),\n",
"              verbose=1,\n",
"              )\n",
"    # Evaluate the model on the test data using `evaluate`\n",
"    results = model.evaluate(X_test, y_test, batch_size=128, verbose=0)\n",
"    print(\"test loss, test acc:\", results)\n",
"    accs.append((model, results[1]))\n",
"\n",
"# Architecture is identical for every run, so summarise the last model once.\n",
"model.summary()"
2021-06-10 19:18:36 +02:00
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "69673754",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6608796119689941"
2021-06-10 19:18:36 +02:00
]
},
"execution_count": 23,
2021-06-10 19:18:36 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean test accuracy over all runs; avoids building a numpy object array of Keras models.\n",
"np.mean([acc for _, acc in accs])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "389c41e3",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": [
"# NOTE: exit() shuts down this kernel -- presumably to release the shared GPU; TODO confirm.\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "478e3241",
2021-06-10 19:18:36 +02:00
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}