iui-group-l-name-zensiert/1-first-project/tdt/NeuralNetwork.ipynb

633 lines
160 KiB
Plaintext
Raw Normal View History

2021-06-07 19:58:49 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f9261918",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a22e78f7",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"delim = ';'\n",
"\n",
"base_path = '/opt/iui-datarelease1-sose2021/'\n",
"\n",
"Xpickle_file = './X.pickle'\n",
"\n",
"ypickle_file = './y.pickle'"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "592e5107",
"metadata": {},
"outputs": [],
"source": [
"THRESH = 70\n",
"LEEWAY = 2\n",
"EPOCH = 30\n",
"\n",
"DENSE_COUNT = 3\n",
"DENSE_NEURONS = 1800\n",
"\n",
"DENSE2_COUNT = 2\n",
"DENSE2_NEURONS = 0\n",
"\n",
"AVG_FROM = 20"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "63671cad",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"def shorten(npList):\n",
" temp = npList['Force']\n",
" thresh = THRESH\n",
" leeway = LEEWAY\n",
2021-06-07 19:58:49 +02:00
" \n",
" temps_over_T = np.where(temp > thresh)[0]\n",
" print(temps_over_T)\n",
" return npList[max(temps_over_T[0]-leeway,0):min(len(npList)-1,temps_over_T[-1]+leeway)]"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "166cc6b8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161\n",
" 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 196 197 198\n",
" 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214]\n"
]
},
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f33513f2eb0>]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA0/0lEQVR4nO2de4xkV33nP796V1c/Z7o9Mx6PPTYYHJOA8c7aZMNGELKAYTcGiUUQCbyIlaMN7CbSPkQSCUiirLLRJiuhJSCyWBjEQtgAwRs5gGNISNgAHhtj/AB7/BjPjGc8z35Vddfz7B/nnuqame7qqlv3UX3v7yO1qvrW69ypnm9963t+53fEGIOiKIqSDjJxD0BRFEWJDhV9RVGUFKGiryiKkiJU9BVFUVKEir6iKEqKyMU9gH7Mz8+bgwcPxj0MRVGUHcWDDz541hizsNltYy36Bw8e5PDhw3EPQ1EUZUchIke3uk3jHUVRlBShoq8oipIiVPQVRVFShIq+oihKilDRVxRFSREq+oqiKClCRV9RFCVFqOj3sFRr8qUHjqHtppVtqZ6DH34e9G9F2WGo6Pfwye88zX/58iN875nzcQ9FGXf+/r/D134djuviQWVnoaLv0e4YvvrQCQDu/fHJmEejjDXtJjzyJXv98b+MdSiKMiwq+h7ff/Ycp5bX2V0p8NePnqLdufhr+9/+9DQf/tqjMY1OGSue+VuonYWJeXjsLy+PeH5yL3z9t+MYmaJsi4q+x+MvLANw5y9ex9nVOscv1Lq3rTfb/PZXfsxn//Eox87XtnoKJS2cesRe/sJ/gOXjsNLzzbBRhb/6Tfjex2FZvzEq44eKvsfxC2tMFXPcsG8agNMr9e5tX37oOC8srQPw3SNnYxmfMkZcOAqVBZh/mf195dTGbQ99FlZftNef/U70Y1OUbVDR9zh+ocb+uTJ7p0sAnPJEHuDJUytMlXJcMVXku0+fi2uIyriw+DzMXg1Te+3vvU7/9OP2A6E8B8/+XTzjU5Q+jHVr5Sg5fmGNq+YmuqL/4vKG6J9dbbAwVeSV+2f4uyfPsNZoUy5k4xqqEjeLz8O+V8LUPvt7r+ivnoHJvbDrWjjyN9CqQ64YzzgVZRPU6QPGGI6dr3HVXJnpco5iLnNRvHNmtc58pciv3noNF2pNPvP/notvsNvwzcdO8fY//S71VjvuoSSTTgeWjlmnX1kAyVwc71RPQ2UeDr3PxjyH74pvrNvx6Ffg02+CdivukSgRoqIPLNaaVBttDuyaQETYM126xOnXmZ8qcMu1u3jdyxf4X3//DJ3O+C3KaXcMd37uQX74/CKPnliOezjJZPUUtBswew1ksjC5ZxOnfwVc93o4+M/hux8bzwVc7Sb8xfvg2Pfg9GNxj0aJEBV9bLQDcNVcGYA908WLRX+lzvyk/Yr+1p/bx7lqg6fPrEY/0G345mMbjvOHz1+IcSQJZvF5ezl7jb2c2rvh9I3xnP4CiMAr3g4rL8CFZ+MZaz8e+8uN68d+ENswlOhR0QeeOWsF3In+FdMlTi/beKfearO83uqK/j89uAuAB54bP1F9+NgihWyGfTMlHlLRD4ezT9nL2avt5dSVG6JfX4HWunX6AFf/vL18/vvRjnEQjj8A+QpUrrDXldSgog986yenmZvI8/I9UwDsmSpx/MIa9Vab89UGQFf0r9k9wfxkgcPPjV+rhqfPrHJwfoJDB3fx0NHFuIeTTJ76hp2o3f1S+/vUXuv+W3WonrHHKp7oL9wApRl4/h/jGWs/zj4J89fDgVvg2Bh+KCmhsa3oi8gBEfm2iDwuIo+JyG94x3eJyH0i8pR3OecdFxH5mIgcEZFHROTmnue6w7v/UyJyR3inNTiNVodv/eQ0v/wze8hl7T/HvpkSjXaHV//efRw9ZxdjzU8WABARDl2zayyd9DNnqlw3P8nP7Z/m1PI6i7VG3ENKFs01OHI/3PBWyHj/dab3QX0Z/vjldoIXYHLBXmYycODW8XTS545Y0b/yJrjwnF1UpqSCQZx+C/iPxpgbgdcAHxCRG4EPAfcbY64H7vd+B7gNuN77uRP4BNgPCeAjwK3ALcBH3AdFnBw+ep6V9RZvfMXe7rG337yf1750nlqjzbd+chqA+amNsrurd0/w4nL9sueKk2a7w/Pna1y3UOGquQkAXlhc3+ZRylA89w/QrFnRd7z6PTbGWbtgSzRhw+kDzB2E5RciHea2NGr2A2r39RtzE+M2RiU0thV9Y8xJY8xD3vUV4AlgP3A7cLd3t7uBt3nXbwc+ayzfA2ZFZB/wJuA+Y8x5Y8wF4D7gzUGejB/cJK6LdsBGOf/zV18NwH2P29WV85UN0Z8p51lrtllvjk9Z5PPna7Q6hpcsTLJvxq41eGFxLeZRJQw3iXvFjRvHpvbCv/6Mvf6Te+1lZWHj9vIcrC9BZ3z+Vjj/tL2cfylM77fX3bcUJfEMlemLyEHg1cD3gT3GGFerdgrY413fD/T+BR33jm11/NLXuFNEDovI4TNnzgwzPF+4CGS2kr/o+OxEgesWKjx71n7tnZ8q9Nxm77u81gx9fIPyzBk7zusWKuyftRPSJ5dU9ANlzZvHmdh18fGpvTBztSemYuv0HeU5wFjhHxfcZPT8y2DmKnt96UR841EiZWDRF5FJ4MvAbxpjLioCN3bXkUCKkY0xnzLGHDLGHFpYWNj+ASNyodYklxGmipcvTj7gxSRvfeU+Jgobt8+WC93HjgtHz1nRP7i7wvxkkXxWOKHxTrCsLdqKl81W2M54/uVV74Jsj4Eoewnm2hjNAbkS0l3XwfSVgMDS8ViHpETHQG0YRCSPFfzPG2O+4h1+UUT2GWNOevHNae/4CeBAz8Ov8o6dAF53yfG/9T/0YFisNZidyCMil932b37hIOvNNv/1bT930XHn9MdpovTU0jqlfKZ7LnumS+r0g6Z2/nKX7/hn/96K/W1/dPHxrugvhjq0oVg+CaVZKFTs75N7bLdQJRUMUr0jwKeBJ4wxf9Jz0z2Aq8C5A/haz/H3elU8rwGWvBjoG8AbRWTOm8B9o3csVhZrTWYnCpve9vqXX8Gf/9rPMzNxcfQzU/ZEf4zinVPL6+ybKXc/vK6cLWumHzRr56E8u/ltN7wV7vi/UJq++HhX9MeoxHf5Bc/he8xcpU4/RQzi9H8BeA/wYxF52Dv228AfAl8SkfcDR4F3erfdC7wFOALUgPcBGGPOi8jvA65+7feMMbH/T7hQazB3iahvx1zFfkgsjVG8c2ppnT3TG7HDlTOlsVxAtqOpnYfyFk5/K9z9xyneWT5xiejvhxcfj288SqRsK/rGmH8ALs8+LG/Y5P4G+MAWz3UXMFYdqBZrTQ7smhjqMbNdpz8+8c7JpXVuuXZDkK6cLXNq+STtjiGb2ertU4Zi7cJGdj8o45jpr5y0XUIdMwfgyW/aNhKbxJxKskj9ilw/Tn+ikCWflbGZyO10DC8ur7PXK9UEuGKqSLtjuDBG8w47njUfTr804z12TES/1YDV07Z9hGNyD7TWbBsJJfGkWvSNMVyoNZnbItPfChFhplxgcUxE/1y1QatjuvX5ALu9thFnV8drEdmOpdOxwl0ecj1hNgfFmfER/dVTgLk43nHrCqrhl0gr8ZNq0V9rtmm0OltO5PZjdiLP0pjEO26XL7cBDGz0Cjq7Mh5j3PHUl8B0tq7e6Ud5dnxE3+3bO90TU7m2EaunL7+/kjhSLfounhk23gGb64+L03elmb3xzoK3mOxcVZ1+INS8moNh4x2w3w7GRvS9RVjT+zaOubYR6vRTQbpF3+ugOetH9CfGR/TPrtrzWOjpD+Sc/pkVFf1AcHX2vpz+GIm+E/bJPRvHXCvoqjr9NJBq0V/y6uz9xDu7KoWLtlSMk9W6PY/JnlXFM+U8+ax0PxCUEXF19sNm+mDbMqy8GOx4/FL3FtMXN3pNMTEPiN31S0k8qRZ9V9ky7EQuwA17pzm7Wr9oh624WK3bZl6VnlYRIsLuSlEncoNilHhnzytg6fmN54iT+ipk8he3ksjm7DcYdfqpIOWi7z/Tf9WBWQB+dGwxwBH5o1pvMVHIkrm
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"idata = shorten(X[5])['Force']\n",
"plt.plot(range(len(idata)), idata)\n",
"plt.plot(range(len(X[5]['Force'])),X[5]['Force'])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "9d829440",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
2021-06-07 19:58:49 +02:00
"def load_pickles():\n",
" _p = open(Xpickle_file, 'rb')\n",
" X = pickle.load(_p)\n",
" _p.close()\n",
" \n",
" _p = open(ypickle_file, 'rb')\n",
" y = pickle.load(_p)\n",
" _p.close()\n",
" \n",
" return (np.asarray(X, dtype=pd.DataFrame), np.asarray(y, dtype=str))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "0cd6bffc",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
2021-06-07 19:58:49 +02:00
"def load_data():\n",
" if os.path.isfile(Xpickle_file) and os.path.isfile(ypickle_file):\n",
" return load_pickles()\n",
" data = []\n",
" label = []\n",
" for user in range(0, user_count):\n",
" user_path = base_path + str(user) + '/split_letters_csv/'\n",
" for file in os.listdir(user_path):\n",
" file_name = user_path + file\n",
" letter = ''.join(filter(lambda x: x.isalpha(), file))[0]\n",
" data.append(pd.read_csv(file_name, delim))\n",
" label.append(letter)\n",
" return (np.asarray(data, dtype=pd.DataFrame), np.asarray(label, dtype=str), np.asarray(file_name))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "0455518d",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.58 s, sys: 74.8 ms, total: 3.65 s\n",
"Wall time: 3.65 s\n"
2021-06-07 19:58:49 +02:00
]
},
{
"data": {
"text/plain": [
"(13102,)"
2021-06-07 19:58:49 +02:00
]
},
"execution_count": 25,
2021-06-07 19:58:49 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"X, y = load_data()\n",
"\n",
"X.shape"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c96adaf7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f33537e0be0>]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAzjklEQVR4nO2de4xkV33nP796Vz9mema6ZzyeGT9jHoaA8Y5sINkIQgLY2ayJRBCsBBZi16sEdhNldyUnkQKbKFI2u8lKRCwREQ4mIhCWQHCyXsDxkrCwMXgMxi+wPX7OjOfpmenuenQ9z/5xzq3umelH1a17z+2+9/eRWlV1q7rq3Kmeb33re37nd8QYg6IoipINckkPQFEURfGHir6iKEqGUNFXFEXJECr6iqIoGUJFX1EUJUMUkh7AeszOzpqrrroq6WEoiqJsKR566KEzxpi51e7b1KJ/1VVXcejQoaSHoSiKsqUQkRfWuk/jHUVRlAyhoq8oipIhVPQVRVEyhIq+oihKhlDRVxRFyRAq+oqiKBlCRV9RFCVDqOgDtVaXz3/vRXp9bTOtDEn9ZfjB50BbkytbDBV94M+//Ry/+eVH+cenTiU9FGWr8H//G3z1V+GoLh5UthaZF31jDF/+wTEA/tcjJxIejbIl6HXgkS/a60/8TaJDUZRRybzoP3zkPM+dqbNrssQ3njhBu9u/4P5/euZl7vzrR9AdxpQBz/4DNM7AxCw8/jeXRjw/vhe+9ltJjExRNiTzov/E8QUA7viZa1hc6vLM6drgvk6vz2995VG+8OARnjpZW+splKxx4hF7+VP/HhaOwuLx5fvadfi7X4cHPgELx1f9dUVJksyL/tFzTYp54YYDMwCcWmwN7vvbH77Ec2fqAHzn8JkkhqdsRs69AJNzMPsKe3txRSz4/c9C7aS9/ty3/I9NUTYg86J/5GyDy2eqXD5TBeDk/NLgvidPLlLK57hy1wT/7xkVfcVx/kWYuQKmL7O3Vzr9U0/YD4TqDnjuH5MZn6KsQ+ZF/+i5Jvt3VNm9rQzAyYVl0T+z2GZ2qsSbr53lgWfPsrjUSWqYymZiIPp77e2Vol87DVOXwVX/HA7/PXRbqz+HoiSEiv65JvtnJigX8uyYKHJycYXo11rsmirzvpsOUGt1+bNvPZvgSNfnH586zS/+ybept7pJDyXd9Pswf8SK/uQcSO7CeKd+CiZn4eAHbcxz6K7kxroRj30ZPv0O6OnfTJbItOgvdXqcqbU4sNNGO3u2VTi5sOzMXq63mJ0q8br9M9z6k5dx13eep9Prr/V0idHvG26/63s8emyeH7x4PunhpJvaCei1YeZKyOVhas8qTn83XPNW6/a/8/HNuYCr14EvfRCOPACnHk96NIpHMi36R881ANi/YwKA3dsqnLok3rGxzy/85OXUWl2eeGnB/0A34FtPnx5c//6L5xIcSQY4/6K9nLnSXk5ftuz0jXFOfw5E4DW/BIsvwbnnkhnrejz+N8vXj3wvsWEo/sm06D972lbm7N/hnP50eeD0jTHW6U9b0T941Q4AHnz+bAIjXZ+Hj5xHBK7YOaGiHzdnnraXM1fYy+nLl0W/tQjdJev0Aa54k7188bt+xzgMRx+E4iRM7rbXlcyQadH/5pOnmCzlee2+7YCNd04sLNFod5lvduj0zMDp79lW4cDOKoee33yi+szpOvtmqrzpml384MXz9LWHUHw8/XU7UbvrJ+zt6cus+++2oO6+cU060Z97FVS2w4v/lMxY1+PMUzB7HRy4CY5swg8lJTY2FH0ROSAi3xSRJ0TkcRH5NXd8p4jcJyJPu8sd7riIyMdF5LCIPCIiN654rtvd458WkdvjO62N6fUN9z1xkre8ajeVYh6Ay7ZXALj+d77O4VN2MdbsVGnwOwev3LkpnfSzp2tcMzfFa/dvZ77Z4cSKiEqJkE4TDt8Pr/oFyLn/Otv2QmsB/uiVdoIXYGrOXuZycODmzemkXz5sRf/yG+Dc83ZRmZIJhnH6XeA/GGOuB94IfFhErgfuBO43xlwH3O9uA9wCXOd+7gA+CfZDAvgocDNwE/DR4IMiCR49Ns+ZWpu3X79ncOxfvG4vP/dqe/sbT9gFNnPO6QMc2DnB6VprU3XjNMbw3Jk618xODmKq4/PNhEeVUp7/NnQaVvQD3vB+G+M0z9kSTVh2+gA7roKFl7wOc0PaDfsBteu65bmJzTZGJTY2FH1jzHFjzPfd9UXgR8A+4Dbgbvewu4F3ueu3AZ81lgeAGRHZC7wDuM8Yc9YYcw64D3hnlCczCsEk7isvmx4cm5ko8SfvewP5nHCfE/1dK0R/plrEGDZVvb6No3pcu3uKy7db0T92Xp1+LASTuLuvXz42fRn88mfs9R/fay8n55bvr+6ApXno97wMcSjOPmMvZ38Ctu2z14NvKUrqGSnTF5GrgDcA3wX2GGOCWrUTQGCZ9wEr/4KOumNrHb/4Ne4QkUMicuj06dMX3x0Z5xpWuHdMlC44Xi3lefXe6UH7hbnpFaI/UQTgfGPziH4wGX3t7CSXz9h46vh5dfqx0HST+BM7Lzw+fRlsv8KJqdg6/YDqDsBY4d8sBJPRs6+A7fvt9fljyY1H8crQoi8iU8BfA79ujLmgbtHYFpSRZB7GmE8ZYw4aYw7Ozc1t/AshOV9vA8tCvpKrZ6cAeOsr59g5ufyhEDz2XKMd27hG5YWX7TeWK2cnma4UmS4XOD6vTj8WmudtxUuhfOl9251/ef17Ib/ib6rqEszmJpoLCkpId14D2y4HBOaPJjokxR+FYR4kIkWs4H/OGPNld/ikiOw1xhx38U2wA8kx4MCKX9/vjh0D3nLR8X8IP/TxONfoMFHKUy7kL7nvX910BcfPN/mvv/z6C47PuG8F55ubx+mfmG+SE9jtvpHsnalwTJ1+PDTOXuryA97876zY3/KHFx4fiP75WIc2EgvHoTIDpUl7e2qP7RaqZIJhqncE+DTwI2PMH6+46x4gqMC5HfjqiuMfcFU8bwTmXQz0deDtIrLDTeC+3R1LhPPN9iXRTsCbrt3Fl37lzYNyzYCZqnVw85so3jmxsMTcdJli3r6Vl89UdSI3LppnoTqz+n2v+gW4/W+hsu3C4wPR30TrOxZecg7fsX2/Ov0MMYzT/yng/cCjIvKwO/ZbwB8AXxSRDwEvAO9x990L3AocBhrABwGMMWdF5PeAoH7td40xif1PON/orBrtrMfA6W+ieOf4/BKXbasMbu/dXuWRo5soP04TjbNQXcPpr0Xw+M0U7ywcu0j098HJJ5Ibj+KVDUXfGPNtQNa4+22rPN4AH17jue4CNkUHqnONtZ3+Wmyr2H+uzRXvLHHN3OTg9r6ZCmfrbZY6vcH6AyUimueWs/th2YyZ/uJx2Pu65dvbD8BT37BtJGSt/+pKWsjsitwwTr+QzzFdKWyq6p0T80vsdaWaALunres/vagtfSOnGcLpV7a7390kot9tQ+2UbR8RMLUHuk3bRkJJPZkV/TBOH2yJ52aJd2qtLout7mAlMcAut4L4TE1FP1L6fSvc1RHXE+YLUN6+eUS/dgIwF8Y7wbqCenwl0srmIZOi3+sb5psddozo9MGWbW6WeOeEK83cu0L0g8nnM7XN8cGUGlrzYPprV++sR3Vm84h+sG/vthUxVdA2onbq0scrqSOTor/Q7GDM8sTsKGyvFjdNvBOI/p4VE7lBV1B1+hHTcDUHo8Y7YL8dbBrRd4uwtu1dPha0jVCnnwkyKfrB4qpRM337OyXmN4nTD4R95arhXW4x2RnN9KMlqLMP5fQ3kegHwj613HNq0Aq6rk4/C2RS9IN4Jkymv3OiyOnFFmYT7Ia06LZGnC4vF2FVinmmKwV1+lET1NmPmumDbcuweDLa8YSl5RbTl5d7TjExC4jd9UtJPdkU/TGc/qv3bqPW6g7aHyRJsB/uZPnCytu5qbJm+lEzTryz5zUw/+LycyRJqwa54oWtJPIF+w1GnX4myKTon6uHd/qvPzADwA+Pno9wROGot7qIwETpwnr82amyOv2oWavZ2jDs+2f28qXvRzeesLRrUJ669Pjkbp3IzQi
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"idata = shorten(X[5])['Force']\n",
"plt.plot(range(len(idata)), idata)\n",
"plt.plot(range(len(X[5]['Force'])),X[5]['Force'])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2512addb",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.14 s, sys: 9.72 ms, total: 3.15 s\n",
"Wall time: 3.15 s\n"
2021-06-07 19:58:49 +02:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"<timed exec>:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n"
]
}
],
"source": [
"%%time\n",
"XX = np.array(list(map(shorten, X)))\n"
]
},
{
"cell_type": "markdown",
"id": "e9c16d84",
2021-06-07 19:58:49 +02:00
"metadata": {},
"source": [
"**How to fix this error**:\n",
"```python\n",
"<timed exec>:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
"```\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "28262137",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 13102.000000\n",
"mean 52.510991\n",
"std 35.307125\n",
"min 4.000000\n",
"50% 48.000000\n",
"95% 95.000000\n",
"96% 101.000000\n",
"97% 108.000000\n",
"98% 124.000000\n",
"99% 157.000000\n",
"max 1512.000000\n",
2021-06-07 19:58:49 +02:00
"dtype: float64"
]
},
"execution_count": 9,
2021-06-07 19:58:49 +02:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAeMUlEQVR4nO3de5yUdd3/8ddnT5wPAgsuJxcRRBTxsCrkAVRMJINSMywr0yRNy7Lbwiwru+vO6tfdXXnf5d2ddjANDykpyp1KeaeiLMgZweUgpwWWRUFYlt2Z+fz+mGthXBZ2YGfnmsP7+XjwcOa6rt39eMG8GL5z7ay5OyIikv0Kwh5ARERSQ0EXEckRCrqISI5Q0EVEcoSCLiKSI4rC+sJ9+vTx8vLysL68iEhWmj9//nZ3L21pX2hBLy8vp7KyMqwvLyKSlczs7UPt05KLiEiOUNBFRHKEgi4ikiNaDbqZ/dbMtpnZ0kPsNzP7uZlVmdliMzsj9WOKiEhrknmG/iAw8TD7LwOGBb+mAf/V9rFERORItRp0d38J2HGYQ6YAv/e4uUBPMytL1YAiIpKcVKyhDwA2JNzfGGw7iJlNM7NKM6usqalJwZcWEZEmab0O3d3vB+4HqKio0Pv2ikjWiURj7Is0/YrS0HS7sdn9SPTA9miMfY3R/R938Yi+jB7UM+WzpSLom4BBCfcHBttERFIqGvN4KBvjYWxIDGfT9uiBuB44JnaYj4vHtuFwHxeEuSEaIxpr+3PRvt06ZGzQZwK3mtkjwDnATnevTsHnFZEs4u7UN8bYubeRXfWN7Nr/38j++3saogdi2tjCs9yEfS0FNpKCmJYUFdChqIAORYXBfwvi24oL6VBYQJcORfTqcmD//uOLmx3f9PHFBZQUHrjd2seVFBZgZik44wdrNehm9jAwHuhjZhuBbwPFAO7+K2AWMAmoAuqAz7bLpCLSrtydfZEgyC3EeFd9pNXtjdHDB7e40OhYVBhE8ODYdSoupGen4v1hjB/z/gAfHMrChEjHtx/0ccHXKyksoKCgfWKaCVoNurtf08p+B25J2UQikjKRaIxN7+5l7fY9rNu+h+pd9QfF+L2ESDdEY4f9fB2LC+jesZjunYrp3rGIY7qUcFzvLnTvVJSwvXj//R6dDhzbrWMxJUX6Xsb2FNqbc4lIajRFe11tHeu272Ht9j28XbuHdbV1bNhR975lipLCArp3KqZHpyK6dyqmZ6diBvfqTPeORQfFuCnEids7FBWG+H8qrVHQRbLEu3UNrNq6m7e2vUfVtt2s276Ht2vr2PBO3fuWOjqXFFLeuwsnlXXjslOOpbxPF4b06cJxvTtT2rVDu63fSvgUdJEMs7OukVXb3uOtrbtZtfU93tr2Hqu27qbmvX37j+lUXEh5ny6ceGw3Lj3lWIb07kJ5ny6U9+5MaTdFO18p6CIhqnlvH3PX1LJg/Tv7A74tIdydSwoZ1rcr44aXMqxvV4b368awfl3p36NTTr+4J0dHQRdJo5r39vHa2lrmrqll7podVG3bDcSfcQ/r15Xzh5UyvF9XhvXryrC+3RjQU+GW5CnoIu2oMRrjldW1PL98K3PX1PJWEPAuJYWcNaQXHztzIGOO783J/btTVKgrQKRtFHSRFGuK+KzF1cxevoV36xrpXFLIWeW9uOKMgYwd2ptTFHBpBwq6SAo0RmO8urqWWUuqeW5ZPOJdOxQx4aS+TBpVxgXDS+lYrEv+pH0p6CJHKRKN8eqaIOJLt/BOXSNdSgqZMLIfH1LEJQQKusgRiERjzF2zg2eWVDN72RZ27GnYH/FJo8oYp4hLiBR0kVZEojFeX7uDp4Nn4jv2NNC5pJCLT4o/Ex9/oiIumUFBF2lBNOa8tvbAcsr23fGIXzSiL5efWsa44X3pVKKIS2ZR0EUC0Zjz+todPLNkM88t3cr23fvoVFzIRSf15fJRZYw/URGXzKagS95bu30PD7y8lllLtrwv4h8aVcaFirhkEQVd8tba7Xv4xYtv8eQbmyguLGDCSfEXNi8cUUrnEj00JPvoT63knXXb9/CLF6t4cuEmiguNG84bwrQLhlLarUPYo4m0iYIueSMx5EUFxnUfKOfz446nb7eOYY8mkhIKuuS8t2vjIf/LG/GQf2ZsOTeNV8gl9yjokrPW19bxyzlv8fiCTRQWGJ8eexw3jxtK3+4KueQmBV1yzp59Eb739HIem7+RgiDkN40bSj+FXHKcgi45Zc++CJ99YB7z17/Dp8Ycx83jFXLJHwq65Iy6hgiffTAe8599/DQ+PLp/2COJpJXekFlyQl1DhOsemEfluh38u2IueUpBl6xX1xDh+gfjMf/Z1NOZrJhLnlLQJas1xfz1tfFn5oq55DMFXbLW3oYoNzxYuT/mU04bEPZIIqFS0CUr7W2Icv2D83htba1iLhJQ0CXr7G2IcsPv4jH/6dWKuUgTXbYoWWVvQ5TP/X4er66p5adXj+YjpyvmIk30DF2yRlPMX1ldy//72Gg+evrAsEcSySgKumSF+sYoN/6+kldW1/KTq0ZzxRmKuUhzSQXdzCaa2UozqzKz6S3sH2xmc8zsDTNbbGaTUj+q5Kv6xiif+10lL6/ezo+vGs2VZyrmIi1pNehmVgjcB1wGjASuMbORzQ77JjDD3U8HpgL/mepBJT81PTNvivlVirnIISXzDP1soMrd17h7A/AIMKXZMQ50D273ADanbkTJV00x/2fVdn505amKuUgrkgn6AGBDwv2NwbZE3wGuNbONwCzgiy19IjObZmaVZlZZU1NzFONKvqhvjDLtD/P5Z9V27r3yVD5WMSjskUQyXqpeFL0GeNDdBwKTgD+Y2UGf293vd/cKd68oLS1N0ZeWXNMU85dW1XDvFadytWIukpRkgr4JSHxEDQy2JboBmAHg7q8CHYE+qRhQ8kt9Y5TPN8X8ylFcfZZiLpKsZII+DxhmZkPMrIT4i54zmx2zHrgYwMxOIh50ranIEWmK+T9W1fDDK0bx8bMGhz2SSFZpNejuHgFuBWYDK4hfzbLMzO4xs8nBYV8FbjSzRcDDwHXu7u01tOSefZEoN/8xHvN/u2IUU89WzEWOVFLf+u/us4i/2Jm47e6E28uBc1M7muQLd+eWhxYwZ2U85tco5iJHRd8pKqHbsaeB51ds46ZxQxVzkTZQ0CV00Vh8dW5AT/0wZ5G2UNAlVPsiUaY/sQSAE/p2C3kakeymoEto4i+ELuDFN7fxg4+OYuzQ3mGPJJLVFHQJxb5IlC8EMf/+R0/hE+do7VykrRR0SbuGSIxbHlrAC29u418/cgqfPOe4sEcSyQkKuqRVQyTGFx6az/MrtvG9j5zCtWMUc5FUUdAlrW575I14zKeczKcUc5GUUtAlbXbvi/Ds0i3ccN4QPjW2POxxRHKOgi5pE4nGACjroevNRdqDgi5p0RCJ8S+PLgZgeD9dby7SHhR0aXcNkRi3/GkBz6/Yyj1TTuaC4XovfJH2oKBLu2qK+d+Wx2P+aa2di7QbBV3a1VdmLFTMRdJEQZd2U98Y5ZnF1Vz3gXLFXCQNFHRpN5HgXRT7dddVLSLpoKBLu2iMxvjqjIUAjCjTVS0i6aCgS8o1RmN88U9vMHvZVr47+WQuPLFv2COJ5AUFXVKqKebPLdvCtz88ks98oDzskUTyhoIuKXXHo4t4btkW7r58JJ89d0jY44jkFQVdUiYac55cuJlPnjOY689TzEXSTUGXlInE4u/VoqtaRMKhoEtKNEZjfPmRhQCcVNY93GFE8pSCLilxx6OLeHbpFr51+UguGdkv7HFE8pKCLm3mHl87/8Q5g7lBa+cioVHQpc0ao/HvCO3brUPIk4jkNwVd2qQxGuO2R94AYKTWzkVCpaBLm3z9scX7184/ePKxYY8jktcUdGmTvyzcxDVnD9LauUgGUNDlqDVGY7hDaVetnYtkAgVdjkridecj+/cIdxgRAZIMuplNNLOVZlZlZtMPcczVZrbczJaZ2Z9SO6ZkmumPL+GZJdV880MnMfEUrZ2LZIKi1g4ws0LgPuASYCMwz8xmuvvyhGOGAXcC57r7O2am90vNcU8t3MT
2021-06-07 19:58:49 +02:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"\n",
2021-06-07 19:58:49 +02:00
"\n",
"X_len = np.asarray(list(map(len, XX)))\n",
"l = []\n",
"sq_xlen = pd.Series(X_len)\n",
"ptiles = [x*0.01 for x in range(100)]\n",
"for i in ptiles:\n",
" l.append(sq_xlen.quantile(i))\n",
"plt.plot(l, ptiles)\n",
"sq_xlen.describe(percentiles=[x*0.01 for x in range(95,100)])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "64fce587",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"def plot_data(data):\n",
" fig, axs = plt.subplots(4, 3, figsize=(3*3, 3*4))\n",
" t = data['Millis']\n",
" axs[0][0].plot(t, data['Acc1 X'])\n",
" axs[0][1].plot(t, data['Acc1 Y'])\n",
" axs[0][2].plot(t, data['Acc1 Z'])\n",
" axs[1][0].plot(t, data['Acc2 X'])\n",
" axs[1][1].plot(t, data['Acc2 Y'])\n",
" axs[1][2].plot(t, data['Acc2 Z'])\n",
" axs[2][0].plot(t, data['Gyro X'])\n",
" axs[2][1].plot(t, data['Gyro Y'])\n",
" axs[2][2].plot(t, data['Gyro Z'])\n",
" axs[3][0].plot(t, data['Mag X'])\n",
" axs[3][1].plot(t, data['Mag Y'])\n",
" axs[3][2].plot(t, data['Mag Z'])\n",
"\n",
" for a in axs:\n",
" for b in a:\n",
" b.plot(t, data['Force'])\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bd86589d",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((12973,), (62, 15))"
2021-06-07 19:58:49 +02:00
]
},
"execution_count": 11,
2021-06-07 19:58:49 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"threshold_p = 0.99\n",
"threshold = int(sq_xlen.quantile(threshold_p))\n",
"len_mask = np.where(X_len <= threshold)\n",
"\n",
"X_filter = XX[len_mask]\n",
"y_filter = y[len_mask]\n",
"\n",
"X_filter.shape, X_filter[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ce528a76",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"a = [x.drop(labels='Millis', axis=1) for x in X_filter]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "59bf9140",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"X_filter = pad_sequences(X_filter, dtype=float, padding='post')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "66338e0b",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"def plot_data(data):\n",
" fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
" axs[0][0].plot(data[0])\n",
" axs[0][1].plot(data[1])\n",
" axs[0][2].plot(data[2])\n",
" axs[1][0].plot(data[3])\n",
" axs[1][1].plot(data[4])\n",
" axs[1][2].plot(data[5])\n",
" axs[2][0].plot(data[6])\n",
" axs[2][1].plot(data[7])\n",
" axs[2][2].plot(data[8])\n",
" axs[3][0].plot(data[9])\n",
" axs[3][1].plot(data[10])\n",
" axs[3][2].plot(data[11])\n",
" axs[4][0].plot(data[12])\n",
" axs[4][1].plot(data[13])\n",
"\n",
"# for a in axs:\n",
"# for b in a:\n",
"# b.plot(t, data['Force'])\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8df7f1f4",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10378, 157, 15)\n",
"(2595, 157, 15)\n",
"(10378, 26)\n",
2021-06-07 19:58:49 +02:00
"(2595, 26)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj4AAANZCAYAAAAPtDT6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9eZhcZZn3/7mrqqu3LJ2dkIUEiEASViPggAxogIAzBsZlYBZAGZlReF9ffUcN48/B0YmDzDAqo8O8KIzgKBEVJxlFMGwuyBa2QAIhTRJIQva1l3R3Lffvj/OcqlPVVd3V3bV23Z/r6qurnrPU08lT53zPvYqqYhiGYRiGUQ+EKj0BwzAMwzCMcmHCxzAMwzCMusGEj2EYhmEYdYMJH8MwDMMw6gYTPoZhGIZh1A0mfAzDMAzDqBtGtfARkbtEZLeIvFLAvueJyPMiEheRDwXGTxORJ0VknYisFZE/Le2sDcMwDMMoFaNa+ADfA5YUuO9bwDXAD7PGu4GrVHWBO9c3RKStSPMzDMMwDKOMRCo9gVKiqr8RkTnBMRE5Dvg2MAVP1HxcVV9T1S1uezLrHK8HXr8tIrvdsQdLOnnDMAzDMIrOqBY+ebgD+BtV3SgiZwH/Dry3kANF5EwgCrxRwvkZhmEYhlEi6kr4iMgY4A+AH4uIP9xY4LHTge8DV6tqcrD9DcMwDMOoPupK+ODFNB1U1dOGcpCIjAN+AXxBVZ8qxcQMwzAMwyg9oz24OQNVPQxsFpEPA4jHqQMdIyJR4GfAPar6kzJM0zAMwzCMEiGjuTu7iNwLnA9MBnYBNwGPArcD04EGYIWqfllE3oUncCYAPcBOVV0gIn8B/CewLnDqa1T1xXL9HYZhGIZhFIdRLXwMwzAMwzCC1JWryzAMwzCM+mbUBjdPnjxZ58yZU+lpGKOE5557bq+qTinnZ9oaNoqJrWGj1inWGh61wmfOnDmsWbOm0tMwRgki8ma5P9PWsFFMbA0btU6x1rC5ugzDMAzDqBtM+BiGYRiGUTeY8DGMMiIiS0Rkg4i0i8iySs/HMAyj3jDhYxhlQkTCeA1yLwHmA1eKyPzKzsowDKO+GLXBzbnY29nLuV97tNLTKDrvmjOR7197VqWnYQzOmUC7qm4CEJEVwFJgfUVnVQX8at1OPnPfS8ST1gZvJIxtauDZLyyu9DSqivfe+jh/dPJ0PnPRCZWeilEl1JXwaW4Ic/W751R6GkXld+17eXn7oUpPwyiMGcDWwPttQIZiFZHrgOsAZs+eXb6ZVZgNOzvo7I3z8ffMJZRuIGwMkcaIGfGz2bSni9sebTfhY6Qom/ARkbuAPwJ2q+rCHNvPB1YCm93Q/ar6ZbdtCfBNIAx8V1VvHs4cWhsj3HjpScM5tGr50qp1bH1+W6WnYRQJVb0DuANg0aJFdVNWPeEqyN94yUmEQiZ8DMMoHeV8PPgesGSQfX6rqqe5H1/0WFzEAIhA3dwda5/twKzA+5lurO5JJr1VbKLHMIxSUzbho6q/AfYP49BUXISq9gF+XIQBCIK1W6sZngXmichcEYkCVwCrKjynqiChSthEj2EYZaDaHMLvFpGXROSXIrLAjeWKi5hR/qlVJyGBpCmfmkBV48ANwEPAq8B9qrqusrMaOb9+fQ8rXxyZ4Sqp3lo2DMMoNdUU3Pw8cIyqdorIpcB/A/OGcoJ6DAwNhcziU0uo6gPAA5WeRzH5/pNb2HbgCEtPG/7zSDKpFtRsFB21i6ORg6qx+KjqYVXtdK8fABpEZDJDiItQ1TtUdZGqLpoypay9+CqGmMXHqDCJpI54DSaS5uoyik8iaddGoz9VI3xE5CgR75FPRM7Em9s+LC5iQCzGx6g0CYX4CG8wCVXCZvExisxI16UxOilnOvu9wPnAZBHZBtwENACo6n8AHwI+ISJx4AhwhXp2yriI+HERYeCu0RAXUSwsxseoNKqaysoaLsmkWkaXUXTM4mPkomzCR1WvHGT7t4Bv5dk26uIiikVIxNLZjYqSSGqqDs+wz2FZXUYJMIuPkYuqcXUZw8MsPkalSSSVkXaaSCSx4Gaj6IzUEmmMTkz41DrixfhY9oJRKZKqI3YpqKqls4+QrVu3csEFFzB//nwWLFjAN7/5TQD279/PhRdeCLBQRFaLyAQA8bhNRNpFZK2InOGfS0SuFpGN7ufqwPg7ReRld8xtflxmtWIWHyMXJnxqHP9mYbrHqBRFcXVZVteIiUQi3Hrrraxfv56nnnqKb3/726xfv56bb76Z973vfQCvAI8Ay9whl+CVDJmHVwbkdgARmYgXg3kWXgHZm3yx5Pb5eOC4warxVxSL8TFyYcKnxvHdA/b1NipFUkfuUkio1fEZKdOnT+eMMzyjzdixYznppJPYvn07K1eu5OqrU0abu4HL3OulwD3q8RTQJiLTgYuB1aq6X1UPAKuBJW7bOFV9yiWe3BM4V1USH6kP1hiVmPCpcfxbhcX5GJUiqSO3+CTN4lNUtmzZwgsvvMBZZ53Frl27mD59ur9pJzDNvc5XFX+g8W05xjMQketEZI2IrNmzZ09R/p7hYhYfIxcmfGocPwXYhI9RKRJJJZEYqcUHEz5ForOzkw9+8IN84xvfYNy4cRnbnKWmpBeLaiokG4zxsThIw8eET40jFuNjVJhixPh4LSuKNKE6JhaL8cEPfpA///M/50/+5E8AmDZtGjt27ADAuat2u93zVcUfaHxmjvGqJWjxsUBnw8eET42TivGx77RRIYqR1WXBzSNHVbn22ms56aST+MxnPpMa/8AHPsDdd9/tv70aWOlerwKuctldZwOHVHUHXrHYi0Rkggtqvgh4yG07LCJnu2yuqwLnqkriAUtkLGHxPoaHCZ8ax2J8SsdnP/tZTjzxRE455RSA40Skzd8mIje6lN4NInJxYHyJG2sXkWWB8bki8rQb/5FrvzIqSOrI11/SgptHzBNPPMH3v/99Hn30UU477TROO+00HnjgAZYtW8bq1asBFgKLgZvdIQ8Am4B24DvAJwFUdT/wFbx2Qc8CX3ZjuH2+6455A/hlef664REU5H1xEz6GRzV1ZzeGgX+zMOFTfC688EL+6Z/+iUgkgoj0ADcCnxeR+Xg94xYARwMPi8g73GHfBi7EC/x8VkRWqep64GvA11V1hYj8B3AtLn241kkmR27xMeEzcs4999y8cSyPPPIIIvKKqi72x1y8z/W59lfVu4C7coyvwRNQNUEwq8uEj+FjFp8aJxXjU9lpjEouuugiIpHUs0EX6fiGpcAKVe1V1c14T79nup92Vd2kqn3ACmCpcwu8F/iJOz6YUlzzJFRJjrCIprm6jKHSvruTT/zXc/TEEnn3ybD4mKvLcJjwqXH8wqlq3+lSM5m0WX+oacCTgIOqGs8a70c1pQIXin9zGYnRJ6FYk1JjSHxp1Tp++cpOnt68P+8+5uoyclHO7ux3AX8E7FbVfqZSEflz4PN4YSsdwCdU9SW3bYsbSwBxVV1UrnlXO/69wlxdw2Px4sXs3Lmz3/jy5ctZunRp6jWeUe0HpZ6Pqt4B3AGwaNGimvhP9YsXjsRqk0wqYdM9xhBoiYYB6OqN590nKHxiIyy5YIweyhnj8z287uv35Nm+GfhDVT0gIpfgXfzPCmy/QFX3lnaKtYfF+IyMhx9+eMDt3/ve9/j5z38OsFnTvpx86b7kGd+HVxU34qw+VZ8GPBT8e8tI4nzM1WUMlTGN3u3r8JFY3n3iZvExclA2V5eq/gbIa5NU1d+78ugAT5FZL8LIQ8hifErGgw8+yC233MKqVasAglfNVcAVItIoInPxehY9g5cBM89lcEXxAqBXOcH0GPAhd3wwpbjm8Wv4jKSWj7WsMAC+/Vg777318YL2bWn0LD57Onrz7pMZ45M/FsioL6o1xudaMtMkFfiViDwnItflO6gW4yNGjFl8SsYNN9xAR0eH39l6vsvGQlXXAfcB64EHgetVNeGsOTfg1UF5FbjP7QueG/czItKOF/NzZ3n/mtIRdHUNFzXhU3f8w/+s47M/filj7J8f2sCmPV0FHe+vt90DCJ9Mi4/3evvBIzz/1oF8hxh
2021-06-07 19:58:49 +02:00
"text/plain": [
"<Figure size 648x1080 with 15 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
"import tensorflow as tf\n",
"\n",
"lb = LabelBinarizer()\n",
"\n",
"yt_filter = lb.fit_transform(y_filter)\n",
"\n",
2021-06-07 19:58:49 +02:00
"X_train, X_test, y_train, y_test = train_test_split(X_filter, yt_filter, test_size=0.2, random_state=177013)\n",
"\n",
"print(X_train.shape)\n",
"print(X_test.shape)\n",
"print(y_train.shape)\n",
"print(y_test.shape)\n",
"\n",
"plot_data(X_filter[0].T)"
]
},
{
"cell_type": "markdown",
"id": "289b59cc",
"metadata": {},
"source": [
"fig, axs = plt.subplots(13,2,figsize=(20, 60), sharey=True)\n",
"data_count = int(len(X_train)/10)\n",
"for i,j in zip(X_train[:data_count], lb.inverse_transform(y_train)[:data_count]):\n",
" num = ord(j) - 64\n",
" f = i.T[13]\n",
" r = int((num-1)/2)%13\n",
" c = (num-1)%2\n",
" axs[r][c].title.set_text(f'{j}')\n",
" axs[r][c].plot(f)\n",
"plt.savefig('./all_forces.png')"
]
},
2021-06-07 19:58:49 +02:00
{
"cell_type": "code",
"execution_count": 16,
"id": "13d90f08",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"# FIRST CELL: set these variables to limit GPU usage.\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '2' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c52f868c",
"metadata": {},
"outputs": [],
"source": [
"accs = []"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "93d225c1",
"metadata": {},
"outputs": [],
2021-06-07 19:58:49 +02:00
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
"from tqdm import tqdm\n",
2021-06-07 19:58:49 +02:00
"\n",
"\n",
"def build_model():\n",
" model = Sequential()\n",
2021-06-07 19:58:49 +02:00
"\n",
" model.add(BatchNormalization(input_shape=X_filter[0].shape))\n",
" \n",
" model.add(Flatten())\n",
2021-06-07 19:58:49 +02:00
"\n",
" for i in range(DENSE_COUNT):\n",
" model.add(Dense(DENSE_NEURONS, activation='relu'))\n",
" \n",
" for i in range(DENSE2_COUNT):\n",
" model.add(Dense(DENSE2_NEURONS, activation='relu'))\n",
" \n",
" model.add(Dense(26, activation='softmax'))\n",
2021-06-07 19:58:49 +02:00
"\n",
" model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" loss=\"categorical_crossentropy\", \n",
" metrics=[\"acc\"],\n",
" )\n",
2021-06-07 19:58:49 +02:00
"\n",
" return model\n",
"# model.summary()\n"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29d17e4c",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [
{
"name": "stderr",
2021-06-07 19:58:49 +02:00
"output_type": "stream",
"text": [
" 0%| | 0/20 [00:00<?, ?it/s]"
2021-06-07 19:58:49 +02:00
]
}
],
"source": [
"for i in tqdm(range(AVG_FROM)):\n",
" model = build_model()\n",
" \n",
" model.fit(X_train, y_train, \n",
" epochs=EPOCH,\n",
" batch_size=128,\n",
" shuffle=True,\n",
" validation_data=(X_test, y_test),\n",
" verbose=0,\n",
" )\n",
" # Evaluate the model on the test data using `evaluate`\n",
"# print(\"Evaluate on test data\")\n",
" results = model.evaluate(X_test, y_test, batch_size=128)\n",
"# print(\"test loss, test acc:\", results)\n",
" accs.append((model,results[1]))\n",
" \n",
"model.save('./model')"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5cf54d38",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
2021-06-07 19:58:49 +02:00
"source": [
"np.mean(np.delete(accs,0,1).astype('float64'))"
2021-06-07 19:58:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc2bcc85",
2021-06-07 19:58:49 +02:00
"metadata": {},
"outputs": [],
"source": [
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56db820d",
"metadata": {},
"outputs": [],
"source": []
2021-06-07 19:58:49 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}