548 lines
27 KiB
Plaintext
548 lines
27 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "8c784b5a",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import os\n",
|
||
|
"import pickle\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import numpy as np\n",
|
||
|
"import tensorflow as tf\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
||
|
"from sklearn.model_selection import train_test_split\n",
|
||
|
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
|
||
|
"from tensorflow.keras.models import Sequential\n",
|
||
|
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization\n",
|
||
|
"\n",
|
||
|
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n",
|
||
|
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"delim = ';'\n",
|
||
|
"user_count = 100\n",
|
||
|
"base_path = '/opt/iui-datarelease1-sose2021/'\n",
|
||
|
"\n",
|
||
|
"Xpickle_file = './X.pickle'\n",
|
||
|
"\n",
|
||
|
"ypickle_file = './y.pickle'"
|
||
|
]
|
||
|
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 2,
 "id": "7b486d61",
 "metadata": {},
 "outputs": [],
 "source": [
  "def load_pickles():\n",
  "    \"\"\"Load the cached samples and labels from the two pickle files.\n",
  "\n",
  "    Reads the module-level paths Xpickle_file and ypickle_file.\n",
  "\n",
  "    Returns:\n",
  "        A tuple (X, y): X as an object-dtype array, y as a string array.\n",
  "    \"\"\"\n",
  "    # NOTE: pickle.load can run arbitrary code; only open cache files you trust.\n",
  "    # The with-blocks close each file even when unpickling raises.\n",
  "    with open(Xpickle_file, 'rb') as x_file:\n",
  "        X = pickle.load(x_file)\n",
  "\n",
  "    with open(ypickle_file, 'rb') as y_file:\n",
  "        y = pickle.load(y_file)\n",
  "\n",
  "    # dtype=object is what NumPy resolves dtype=pd.DataFrame to; spell it out.\n",
  "    return (np.asarray(X, dtype=object), np.asarray(y, dtype=str))"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 3,
 "id": "5ea384ea",
 "metadata": {},
 "outputs": [],
 "source": [
  "def shorten(npList, thresh=100, leeway=5):\n",
  "    \"\"\"Trim a recording to the span where 'Force' exceeds a threshold.\n",
  "\n",
  "    Args:\n",
  "        npList: Sliceable table with a 'Force' column.\n",
  "        thresh: Force value a row must exceed to count as active.\n",
  "        leeway: Number of extra rows kept before the first active row.\n",
  "\n",
  "    Returns:\n",
  "        The slice of npList around the active rows, or an empty slice\n",
  "        when no row exceeds thresh.\n",
  "    \"\"\"\n",
  "    temp = npList['Force']\n",
  "\n",
  "    temps_over_T = np.where(temp > thresh)[0]\n",
  "    if temps_over_T.size == 0:\n",
  "        # Nothing above the threshold: return an empty slice instead of\n",
  "        # failing with IndexError on temps_over_T[0].\n",
  "        return npList[0:0]\n",
  "\n",
  "    start = max(temps_over_T[0] - leeway, 0)\n",
  "    # The stop index is exclusive, so leeway - 1 rows follow the last active row.\n",
  "    stop = temps_over_T[-1] + leeway\n",
  "    return npList[start:stop]"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 4,
 "id": "09aad3f2",
 "metadata": {},
 "outputs": [],
 "source": [
  "def load_data():\n",
  "    \"\"\"Load every letter recording together with its label.\n",
  "\n",
  "    Uses the pickle cache when both cache files exist; otherwise reads the\n",
  "    per-user CSV files found under base_path.\n",
  "\n",
  "    Returns:\n",
  "        A tuple (data, labels): an object array of the parsed CSV frames and\n",
  "        a string array with the first alphabetic character of each file name.\n",
  "    \"\"\"\n",
  "    if os.path.isfile(Xpickle_file) and os.path.isfile(ypickle_file):\n",
  "        return load_pickles()\n",
  "    data = []\n",
  "    label = []\n",
  "    for user in range(0, user_count):\n",
  "        user_path = base_path + str(user) + '/split_letters_csv/'\n",
  "        # sorted(): os.listdir order is arbitrary, which would make the later\n",
  "        # seeded train/test split differ between runs and machines.\n",
  "        for file in sorted(os.listdir(user_path)):\n",
  "            file_name = user_path + file\n",
  "            letter = ''.join(filter(lambda x: x.isalpha(), file))[0]\n",
  "            # Pass sep by keyword; positional use is deprecated in newer pandas.\n",
  "            data.append(pd.read_csv(file_name, sep=delim))\n",
  "            label.append(letter)\n",
  "    # Return the same (data, labels) pair as the cached branch; the caller\n",
  "    # unpacks exactly two values.\n",
  "    return (np.asarray(data, dtype=object), np.asarray(label, dtype=str))"
 ]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "37d66d26",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 2.76 s, sys: 205 ms, total: 2.97 s\n",
|
||
|
"Wall time: 2.97 s\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"x, y = load_data()\n"
|
||
|
]
|
||
|
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 6,
 "id": "3178395b",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "CPU times: user 3.22 s, sys: 2.07 ms, total: 3.22 s\n",
    "Wall time: 3.22 s\n"
   ]
  },
  {
   "name": "stderr",
   "output_type": "stream",
   "text": [
    "<timed exec>:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n"
   ]
  }
 ],
 "source": [
  "%%time\n",
  "# The trimmed recordings differ in length, so the array must be created with\n",
  "# dtype=object; without it NumPy warns (see stored stderr) or raises.\n",
  "f_data = np.array(list(map(shorten, x)), dtype=object)"
 ]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "dcbb85b7",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"count 13102.000000\n",
|
||
|
"mean 61.169058\n",
|
||
|
"std 30.698514\n",
|
||
|
"min 10.000000\n",
|
||
|
"50% 57.000000\n",
|
||
|
"95% 102.000000\n",
|
||
|
"96% 107.000000\n",
|
||
|
"97% 113.000000\n",
|
||
|
"98% 127.000000\n",
|
||
|
"99% 156.000000\n",
|
||
|
"max 1522.000000\n",
|
||
|
"dtype: float64"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD4CAYAAAD4k815AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAf/ElEQVR4nO3deXTU9b3/8ec7O0kIIQmRJWCCgBUQBSNgq9VWWtFaqT+thap1Q2xvvV3scu1tr23t7T1Ve7rdUhXRWrciWrXUUnEp1dsqS0BBCIthkyCQkIRAErLMzOf3xww4xoQMMMl3ltfjnJyZ75Lkdb7JvPLNZ76LOecQEZHEkeJ1ABERiS4Vu4hIglGxi4gkGBW7iEiCUbGLiCSYNK++cVFRkSstLfXq24uIxKVVq1btc84NOto6nhV7aWkpFRUVXn17EZG4ZGY7elpHQzEiIglGxS4ikmBU7CIiCabHYjezh8ysxszWdbPczOw3ZlZlZmvNbFL0Y4qISKQi2WN/GJh+lOUXA6NDH3OAe088loiIHK8ei9059xpQf5RVZgCPuKBlQL6ZDYlWQBEROTbRGGMfBuwMm64OzfsQM5tjZhVmVlFbWxuFby0iIp316XHszrl5wDyA8vJyXS9YROKSP+Bo8/lp6wjQ2vmxw0+bL0Cb7/3nnR8v/EgxZwzP77V80Sj2XcDwsOmS0DwRkV7l8wdo9QVo6/C//9gRoM334ce2ztNHKd73Hz/4uYcLvMN/Yvulxf0zY77YFwG3mtkCYArQ6JzbHYWvKyIJwDlHU5uP/S0dNB56/+Pw9MHWjtAe7uHS/XDBtnVTvP7A8ResGWSlpZKVnkJmF4+5mWkU5qSSmZ5CVlp3jylkpr//OZlpKWSld/2YefgxLQUzi+IW/rAei93M/ghcABSZWTXwQyAdwDl3H7AYuASoAlqAG3orrIh4p7XDz4FDHezvVMz7W9q7nB/+cbQCTkuxbkswKz2F/H7pZOVlHr04w8r2g5/f/WNaivV6wXqlx2J3zs3qYbkDvhq1RCLSZwIBx679h6iqbWJLTRN7GluPFHTj4eI+1E7joQ5aOwLdfh0zyMtKJz87nQH9gh/DC7IZ0C+N/H4ZwXmhZfmh54fnZ6X3/h5ssvHsImAi0nfafQG21zVTVRMs8KraJqpqmtha28yhDv+R9bLSUxiYnXGknE8uzOaM7AHBQg6bn3+kpIPz+melkZKico4VKnaRBBIIOLbXNbP+vQNs2H2Ad0JFvqO+5QPDIcPy+3FKcS5TygoZVZx75KMgJ8PD9BItKnaRONXuC7B570Eq3zvA+vcaj5R5c3twDzwtxSgtymHMSf255PQhR8p75KAcsjP00k9k+umKxImG5nZe3rCXFdvqWf/eAd6pOXjksLucjFROG5LHlWeVMG7oAMYOzWPMSf3JSNN1/pKRil0khu1pbOXFyj0sWb+HZVvr8QccBTkZjBuax8fHjGTc0DzGDc2jtDBHY9xyhIpdJIa0+fy8Xd3I8m31vLxhL2++ux+AUwbl8OXzRzJ93BDGD8vTUSRyVCp2EQ81tflYvaOBldvrWbGtnrd27qfNFzyscPywPL5z0alcNO4kRhX39zipxBMVu0gf8vkDrNzewCsb9rJie3Cs3B9wpKYY44bmcc3Ukzm7tICzSwdSmJvpdVyJUyp2kV7W5vPzr6p9vLBuDy9vqKG+uZ2MtBQmDs/n3y44hbNLC5h08kByM/VylOjQb5JIL3DO8fqWOhas3MnSjTU0tfnon5nGJ08rZvq4wZx/6iAdcii9Rr9ZIlEUCDhe2VjDb5dWsWbnfgpyMvjsGUO4aNxgPnpKkQ4/lD6hYheJAp8/wF/f3s3vlm5h096DDC/ox08vH88Vk0rISk/1Op4kGRW7yAlo9wV4ZnU19766hR11LYwqzuUXV53BZWcMJS1Ve+fiDRW7yHF4t66Fxet284fXt7O7sZXxw/K49+pJXDRusE4UEs+p2E
Ui4Jxj894mXlgXPAu0cvcBACaXFvCzKybw8dFFOmlIYoaKXeQo2nx+5v69ir+s3c22fc2YwVkjBvKDz5zGReMGM7wg2+uIIh+iYhfpRmuHn1seXcWrm2s5b3QRs88r41NjT6K4f5bX0USOSsUu0oUddc38x5/WsnxbPT/7f6czc/IIryOJREzFLhJm056D3PuPKhateY+01BR+edWZfG7iMK9jiRwTFbsIsLvxED/883perNxLdkYqs88byexzyyjO07CLxB8VuyS9HXXNfPGB5exvaedrF47mho+WMlC3iJM4pmKXpFZV08TV85fR5guwYM45nF4ywOtIIidMxS5Jq/K9A1z74HLMjCfnnMOpg3XNc0kMKnZJOs45XtlQw7eeWkN2RiqPz57CyEG5XscSiRoVuyQNf8Dxwro9zF1aReXuA4wsyuEPN07WSUaScFTskhRe2bCXny7ewNbaZkYW5XDPlRP43MRhpOtCXZKAVOyS8J5eVc13n17DqOJc5n5xEtPHDyZVF+qSBKZil4T22LId/OC5dZw7qoh5XzpLdy2SpKDfcklYD7y2lZ8u3sC004r57Rcn6YYXkjRU7JKQDpf6Z04fwq9mnqmxdEkqKnZJSI8u28Hk0gJ+PfNM3clIkk5Ev/FmNt3MNplZlZnd3sXyEWa21MzeNLO1ZnZJ9KOKRMY5hz/gKBnYT6UuSanH33ozSwXmAhcDY4FZZja202o/ABY65yYCM4HfRTuoSCScc/zshY3s2n+IUSfppCNJTpHszkwGqpxzW51z7cACYEandRyQF3o+AHgvehFFIhMIOH60aD33v7qVa6aO4MsfP8XrSCKeiGSMfRiwM2y6GpjSaZ0fAS+a2b8DOcC0rr6Qmc0B5gCMGKEbF0j0+AOO7z2zloUV1dx8Xhn/eclpugepJK1oDUDOAh52zpUAlwCPmtmHvrZzbp5zrtw5Vz5o0KAofWtJdh3+AN988i0WVlTz9QtHq9Ql6UWyx74LGB42XRKaF+4mYDqAc+4NM8sCioCaaIQU6U6bz8+tT7zJS5V7uf3ij/Dl8zX8IhLJHvtKYLSZlZlZBsE3Rxd1Wudd4EIAMzsNyAJqoxlUpLND7X5ufmQVL1Xu5c4Z41TqIiE97rE753xmdiuwBEgFHnLOrTezO4EK59wi4FvAA2b2TYJvpF7vnHO9GVySW1ObjxsfXsnK7fXcfcUErjp7eM+fJJIkIjpByTm3GFjcad4dYc8rgY9FN5pI19p9Aa6Zv5y3dzXyqy+cyYwzdbNpkXA6e0PizsY9B3hr537uuHSsSl2kCyp2iTsd/uAo3/CCfh4nEYlNKnaJK3VNbfzXc+vISE2htDDH6zgiMUkXAZO4sfdAK1fPX87O+hYeuK5c9ykV6YaKXeJCdUMLV89fzr6DbfzhxslMHVnodSSRmKVil5hXc6CVq+57g6Y2H4/NnsLEEQO9jiQS01TsEvNee2cf7zW2smDOVJW6SAT05qnEPJ8/AMCwfB0FIxIJFbvEtE17DvLzFzdTlJtBUW6m13FE4oKKXWLWul2NzJz3BikGC+ZMpV+GbkYtEgkVu8SkVTsamPXAMrIz0lh4yzmMKu7vdSSRuKE3TyXmrNm5n2sfXE5x/0wev3mqxtZFjpGKXWLOX9/eTYc/wMJbzqE4L8vrOCJxR0MxEnM6/AHSU1NU6iLHScUuMWXpxhqeWP4upw3J63llEemSil1ixgvrdjPn0QpGn5TLA18q9zqOSNzSGLvEhGffrObbT61lQskAHr5hMgP6pXsdSSRuqdjFc4vWvMdtC9cwtayQ+deVk5OpX0uRE6FXkHjumdXVDB+Yze9vOJusdJ2EJHKiNMYunuvwBxiYk6FSF4kSFbt46qF/buNfVXWM1VEwIlGjYhfP/OH17dz5fCUXjx/Mjy8b53UckYShMXbxzFOrdjKhZAD/O2siaanaxxCJFr2axDM+v2NQbqZKXSTK9IqSPuec4+4XNrJxz0HGDdXYuk
i0aShG+pRzjjufr+T3/9rOrMkj+Ma0MV5HEkk4KnbpM/6A4/vPvs2ClTu58WNl/Nelp2FmXscSSTgqdukzP3guWOq
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x_len = np.asarray(list(map(len, f_data)))\n",
|
||
|
"l = []\n",
|
||
|
"sq_xlen = pd.Series(x_len)\n",
|
||
|
"ptiles = [x*0.01 for x in range(100)]\n",
|
||
|
"for i in ptiles:\n",
|
||
|
" l.append(sq_xlen.quantile(i))\n",
|
||
|
"plt.plot(l, ptiles)\n",
|
||
|
"sq_xlen.describe(percentiles=[x*0.01 for x in range(95,100)])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "1878d067",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"thresh_p = 0.99\n",
|
||
|
"thresh = int(sq_xlen.quantile(thresh_p))\n",
|
||
|
"len_mask = np.where(x_len <= thresh)\n",
|
||
|
"\n",
|
||
|
"x_filter = f_data[len_mask]\n",
|
||
|
"y_filter = y[len_mask]\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 9,
 "id": "3a01c1ad",
 "metadata": {},
 "outputs": [],
 "source": [
  "lb = LabelBinarizer()\n",
  "# Drop the 'Millis' column before padding. Previously the Millis-free frames\n",
  "# were built into `a` but the unmodified x_filter was padded instead.\n",
  "# Reading from f_data[len_mask] keeps this cell re-runnable, because x_filter\n",
  "# is overwritten below with the padded array.\n",
  "a = [frame.drop(labels='Millis', axis=1) for frame in f_data[len_mask]]\n",
  "x_filter = pad_sequences(a, dtype=float, padding='post')\n",
  "yt_filter = lb.fit_transform(y_filter)"
 ]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "634a024c",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: user 34.7 ms, sys: 5.84 ms, total: 40.6 ms\n",
|
||
|
"Wall time: 39.2 ms\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"x_train, x_test, y_train, y_test = train_test_split(x_filter, yt_filter, test_size=0.2, random_state=177013)\n"
|
||
|
]
|
||
|
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 11,
 "id": "0109b9b6",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Model: \"sequential\"\n",
    "_________________________________________________________________\n",
    "Layer (type) Output Shape Param # \n",
    "=================================================================\n",
    "flatten (Flatten) (None, 2340) 0 \n",
    "_________________________________________________________________\n",
    "batch_normalization (BatchNo (None, 2340) 9360 \n",
    "_________________________________________________________________\n",
    "dense (Dense) (None, 2200) 5150200 \n",
    "_________________________________________________________________\n",
    "dense_1 (Dense) (None, 1100) 2421100 \n",
    "_________________________________________________________________\n",
    "dense_2 (Dense) (None, 550) 605550 \n",
    "_________________________________________________________________\n",
    "dense_3 (Dense) (None, 225) 123975 \n",
    "_________________________________________________________________\n",
    "dense_4 (Dense) (None, 26) 5876 \n",
    "=================================================================\n",
    "Total params: 8,316,061\n",
    "Trainable params: 8,311,381\n",
    "Non-trainable params: 4,680\n",
    "_________________________________________________________________\n"
   ]
  }
 ],
 "source": [
  "# Fully connected classifier over the flattened, padded samples.\n",
  "model = Sequential([\n",
  "    Flatten(input_shape=x_filter[0].shape),\n",
  "    BatchNormalization(),\n",
  "    Dense(2200, activation='relu'),\n",
  "    Dense(1100, activation='relu'),\n",
  "    Dense(550, activation='relu'),\n",
  "    Dense(225, activation='relu'),\n",
  "    # One softmax output per class.\n",
  "    Dense(26, activation='softmax'),\n",
  "])\n",
  "\n",
  "model.compile(\n",
  "    loss=\"categorical_crossentropy\",\n",
  "    metrics=[\"acc\"],\n",
  "    optimizer=tf.keras.optimizers.Adam(0.001),\n",
  ")\n",
  "\n",
  "model.summary()"
 ]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "204ed561",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 1/32\n",
|
||
|
"82/82 [==============================] - 1s 3ms/step - loss: 2.8553 - acc: 0.1745\n",
|
||
|
"Epoch 2/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 1.7793 - acc: 0.4480\n",
|
||
|
"Epoch 3/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 1.2391 - acc: 0.6070\n",
|
||
|
"Epoch 4/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.9623 - acc: 0.7021\n",
|
||
|
"Epoch 5/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.8489 - acc: 0.7336\n",
|
||
|
"Epoch 6/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.5827 - acc: 0.8169\n",
|
||
|
"Epoch 7/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.5208 - acc: 0.8313\n",
|
||
|
"Epoch 8/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.5864 - acc: 0.8147\n",
|
||
|
"Epoch 9/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.4101 - acc: 0.8710\n",
|
||
|
"Epoch 10/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.2856 - acc: 0.9087\n",
|
||
|
"Epoch 11/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.2789 - acc: 0.9126\n",
|
||
|
"Epoch 12/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.3118 - acc: 0.9027\n",
|
||
|
"Epoch 13/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.3337 - acc: 0.9054\n",
|
||
|
"Epoch 14/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.3052 - acc: 0.9049\n",
|
||
|
"Epoch 15/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.2052 - acc: 0.9403\n",
|
||
|
"Epoch 16/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.4292 - acc: 0.8907\n",
|
||
|
"Epoch 17/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1545 - acc: 0.9542\n",
|
||
|
"Epoch 18/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1401 - acc: 0.9575\n",
|
||
|
"Epoch 19/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1907 - acc: 0.9483\n",
|
||
|
"Epoch 20/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.2635 - acc: 0.9303\n",
|
||
|
"Epoch 21/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1116 - acc: 0.9671\n",
|
||
|
"Epoch 22/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.2453 - acc: 0.9317\n",
|
||
|
"Epoch 23/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1090 - acc: 0.9681\n",
|
||
|
"Epoch 24/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1578 - acc: 0.9541\n",
|
||
|
"Epoch 25/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1609 - acc: 0.9570\n",
|
||
|
"Epoch 26/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.0801 - acc: 0.9775\n",
|
||
|
"Epoch 27/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.1597 - acc: 0.9615\n",
|
||
|
"Epoch 28/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.0695 - acc: 0.9807\n",
|
||
|
"Epoch 29/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.0622 - acc: 0.9853\n",
|
||
|
"Epoch 30/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.0655 - acc: 0.9841\n",
|
||
|
"Epoch 31/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.0383 - acc: 0.9910\n",
|
||
|
"Epoch 32/32\n",
|
||
|
"82/82 [==============================] - 0s 3ms/step - loss: 0.0716 - acc: 0.9792\n",
|
||
|
"CPU times: user 14 s, sys: 3.02 s, total: 17 s\n",
|
||
|
"Wall time: 8.95 s\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<tensorflow.python.keras.callbacks.History at 0x7f6dec4bf130>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%%time\n",
|
||
|
"model.fit(x_train, y_train, \n",
|
||
|
" epochs=32,\n",
|
||
|
" batch_size=128,\n",
|
||
|
" shuffle=True,\n",
|
||
|
" verbose=1\n",
|
||
|
" )"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "10a0d074",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Evaluate on test data\n",
|
||
|
"82/82 [==============================] - 0s 2ms/step - loss: 1.7331 - acc: 0.7341\n",
|
||
|
"test loss, test acc: [1.7330855131149292, 0.7341040372848511]\n",
|
||
|
"Generate predictions for 3 samples\n",
|
||
|
"predictions shape: (3, 26)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(array(['N', 'U', 'I'], dtype='<U1'), array(['N', 'U', 'I'], dtype='<U1'))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Evaluate the model on the test data using `evaluate`\n",
|
||
|
"print(\"Evaluate on test data\")\n",
|
||
|
"results = model.evaluate(x_test, y_test, batch_size=32)\n",
|
||
|
"print(\"test loss, test acc:\", results)\n",
|
||
|
"\n",
|
||
|
"# Generate predictions (probabilities -- the output of the last layer)\n",
|
||
|
"# on new data using `predict`\n",
|
||
|
"print(\"Generate predictions for 3 samples\")\n",
|
||
|
"predictions = model.predict(x_test[:3])\n",
|
||
|
"print(\"predictions shape:\", predictions.shape)\n",
|
||
|
"\n",
|
||
|
"lb.inverse_transform(y_test[:3]), lb.inverse_transform(predictions)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "63f89de5",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "1a5d0352",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "4ebc323f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "445a8e54",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "32206e00",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "fe00f947",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "d0fbe763",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "d18608c3",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"id": "6a0e538b",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"exit()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "521b2be6",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.5"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|