iui-group-l-name-zensiert/1-first-project/Abgabe.ipynb

766 lines
64 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
2021-08-06 20:20:52 +02:00
"id": "dab4afe8",
"metadata": {},
"outputs": [],
"source": [
2021-08-06 20:20:52 +02:00
"import os\n",
"\n",
"# Glob pattern matching the raw per-letter CSV files. dl_from_blob() splits every\n",
"# matched path on '/' and reads fixed positions (3 and 5), so changing the\n",
"# directory depth of this pattern requires adjusting those indices as well.\n",
"glob_path = '/opt/iui-datarelease2-sose2021/*/split_letters_csv/*'\n",
"\n",
2021-08-06 20:20:52 +02:00
"# Cache file for the parsed dataset (written and read with pickle by load_data()).\n",
"pickle_file = 'data.pickle'\n",
"\n",
"# Location handed to ModelCheckpoint and model.load_weights in train().\n",
"checkpoint_path = \"training_1/cp.ckpt\"\n",
"checkpoint_dir = os.path.dirname(checkpoint_path)\n",
"\n",
"# Divisor for the hidden layer sizes: hidden layer i gets int(flattened_input_size / dense_steps**i) neurons,\n",
"# e.g. dense_steps = 3 with 2700 inputs: layer1 = 900, layer2 = 300, layer3 = 100, layer4 = 33...\n",
"dense_steps = 2\n",
"# Maximum number of hidden Dropout+Dense pairs (build_model stops earlier once a layer would not exceed the class count)\n",
"layer_count = 3\n",
"# Base dropout rate: used as-is on the input, hidden layer i uses drop_count * i\n",
"drop_count = 0.1"
]
},
{
"cell_type": "code",
"execution_count": 2,
2021-08-06 20:20:52 +02:00
"id": "fb114dda",
"metadata": {},
"outputs": [],
"source": [
"from glob import glob\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"\n",
"def dl_from_blob(filename) -> list:\n",
"    \"\"\"Read every CSV file matched by a glob pattern into a list of records.\n",
"\n",
"    Args:\n",
"        filename: Glob pattern of the ';'-separated CSV files to load.\n",
"\n",
"    Returns:\n",
"        One dict per matched file with the keys 'file' (the path), 'data'\n",
"        (the parsed DataFrame), 'user' (element 3 of path.split('/')) and\n",
"        'label' (first alphabetic character of element 5 of path.split('/')).\n",
"\n",
"    Raises:\n",
"        IndexError: If a path has fewer than six '/'-separated components or\n",
"            component 5 contains no alphabetic character.\n",
"    \"\"\"\n",
"    all_data = []\n",
"\n",
"    for path in tqdm(glob(filename)):\n",
"        parts = path.split('/')\n",
"        # The separator is passed by keyword: positional use is deprecated and\n",
"        # rejected by pandas 2.x.\n",
"        df = pd.read_csv(path, sep=';')\n",
"        all_data.append({\n",
"            'file': path,\n",
"            'data': df,\n",
"            # Fixed positions, tied to the directory depth of the glob pattern.\n",
"            'user': parts[3],\n",
"            'label': ''.join(filter(lambda x: x.isalpha(), parts[5]))[0],\n",
"        })\n",
"    return all_data"
]
},
{
"cell_type": "code",
"execution_count": 3,
2021-08-06 20:20:52 +02:00
"id": "6b23cd45",
"metadata": {},
"outputs": [],
"source": [
"def save_pickle(f, structure):\n",
"    \"\"\"Serialize `structure` with pickle into the file at path `f`.\n",
"\n",
"    Args:\n",
"        f: Path of the file to create or overwrite.\n",
"        structure: Any picklable object.\n",
"    \"\"\"\n",
"    # Local import: this cell is executed before the cell that imports pickle.\n",
"    import pickle\n",
"\n",
"    # The context manager closes the file even when pickle.dump raises.\n",
"    with open(f, 'wb') as _p:\n",
"        pickle.dump(structure, _p)"
]
},
{
"cell_type": "code",
"execution_count": 4,
2021-08-06 20:20:52 +02:00
"id": "e3cdaf0c",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"def load_pickles(f) -> list:\n",
"    \"\"\"Load and return the object pickled in the file at path `f`.\n",
"\n",
"    Args:\n",
"        f: Path of the pickle file to read.\n",
"\n",
"    Returns:\n",
"        The unpickled object (the list of records written by save_pickle).\n",
"    \"\"\"\n",
"    # NOTE: unpickling can execute arbitrary code; only load files this notebook wrote itself.\n",
"    # Open the path that was passed in (not the global pickle_file) and let the\n",
"    # context manager close it even when unpickling fails.\n",
"    with open(f, 'rb') as _p:\n",
"        return pickle.load(_p)"
]
},
{
"cell_type": "code",
"execution_count": 5,
2021-08-06 20:20:52 +02:00
"id": "c1b8c553",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading data...\n",
"data.pickle found...\n"
]
}
],
"source": [
"import os\n",
"def load_data() -> list:\n",
"    \"\"\"Return the dataset, reading the pickle cache if it exists and building it otherwise.\n",
"\n",
"    Without a cache file the CSVs matched by glob_path are parsed and the result\n",
"    is written to pickle_file before it is returned.\n",
"    \"\"\"\n",
"    cache_exists = os.path.isfile(pickle_file)\n",
"    if not cache_exists:\n",
"        print(f'Didn\\'t find {pickle_file}...')\n",
"        records = dl_from_blob(glob_path)\n",
"        print(f'Creating {pickle_file}...')\n",
"        save_pickle(pickle_file, records)\n",
"    else:\n",
"        print(f'{pickle_file} found...')\n",
"        records = load_pickles(pickle_file)\n",
"    return records\n",
"\n",
"# Populate the module-level `data` list (from the pickle cache when it exists).\n",
"print(\"Loading data...\")\n",
"data = load_data()\n",
"# Debug aid; plot_pd is defined in a later cell, so run that cell before uncommenting.\n",
"# plot_pd(data[0]['data'], False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
2021-08-06 20:20:52 +02:00
"id": "7d8249b6",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_pd(data, force=True):\n",
"    \"\"\"Plot the channels of one recording on a 5x3 grid of axes.\n",
"\n",
"    Args:\n",
"        data: Object indexable by the column names listed below (e.g. a DataFrame).\n",
"        force: If True, overlay the 'Force' column on all 15 axes; otherwise\n",
"            draw it once, on the axis in row 4, column 1.\n",
"    \"\"\"\n",
"    fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
"    channels = (\n",
"        'Acc1 X', 'Acc1 Y', 'Acc1 Z',\n",
"        'Acc2 X', 'Acc2 Y', 'Acc2 Z',\n",
"        'Gyro X', 'Gyro Y', 'Gyro Z',\n",
"        'Mag X', 'Mag Y', 'Mag Z',\n",
"        'Time',\n",
"    )\n",
"    # Row-major walk over the grid: the 13 channels fill rows 0-3 and the first axis of row 4.\n",
"    for panel, column in zip(axs.flat, channels):\n",
"        panel.plot(data[column])\n",
"\n",
"    if not force:\n",
"        axs[4][1].plot(data['Force'])\n",
"        return\n",
"    for panel in axs.flat:\n",
"        panel.plot(data['Force'])\n",
"\n",
"def plot_np(data, force=True):\n",
"    \"\"\"Plot the rows of one array-shaped recording on a 5x3 grid of axes.\n",
"\n",
"    Args:\n",
"        data: Object indexable by the integers 0..13 (presumably the transposed,\n",
"            padded array produced by elong - confirm).\n",
"        force: If True, overlay row 12 on all 15 axes; otherwise draw it once,\n",
"            on the axis in row 4, column 1. Row 12 takes the place that 'Force'\n",
"            has in plot_pd.\n",
"    \"\"\"\n",
"    fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
"    # Rows 0-11 go to the first twelve axes, row 13 to the first axis of the last grid row.\n",
"    row_ids = list(range(12)) + [13]\n",
"    for panel, row in zip(axs.flat, row_ids):\n",
"        panel.plot(data[row])\n",
"\n",
"    if not force:\n",
"        axs[4][1].plot(data[12])\n",
"        return\n",
"    for panel in axs.flat:\n",
"        panel.plot(data[12])\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
2021-08-06 20:20:52 +02:00
"id": "6e7b9412",
"metadata": {},
"outputs": [],
"source": [
"def mill_drop(entry):\n",
"    \"\"\"Return a copy of `entry` whose 'data' frame has no 'Millis' column.\n",
"\n",
"    The input dict and its DataFrame are left untouched, so preprocessing can\n",
"    be re-run on the same raw entries.\n",
"\n",
"    Args:\n",
"        entry: Dict with a 'data' DataFrame that contains a 'Millis' column.\n",
"\n",
"    Returns:\n",
"        A shallow copy of `entry` with 'data' replaced by the reduced frame\n",
"        (index reset to 0..n-1).\n",
"\n",
"    Raises:\n",
"        KeyError: If the frame has no 'Millis' column.\n",
"    \"\"\"\n",
"    # Shallow copy: only the 'data' slot is replaced, all other values are shared.\n",
"    drop_entry = dict(entry)\n",
"    drop_entry['data'] = entry['data'].drop(columns='Millis').reset_index(drop=True)\n",
"    return drop_entry"
]
},
{
"cell_type": "code",
"execution_count": 8,
2021-08-06 20:20:52 +02:00
"id": "52de5447",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cut_force(drop_entry, thresh=70, rel_thresh=0.05, leeway=10):\n",
"    \"\"\"Crop a recording to the span in which 'Force' exceeds a threshold.\n",
"\n",
"    The absolute threshold is tried first; if no sample exceeds it, the\n",
"    threshold falls back to `rel_thresh` times the maximum force.\n",
"    `drop_entry` is modified in place and returned.\n",
"\n",
"    Args:\n",
"        drop_entry: Dict with a 'data' DataFrame that has a 'Force' column.\n",
"        thresh: Absolute force threshold.\n",
"        rel_thresh: Fraction of the maximum force used as fallback threshold.\n",
"        leeway: Number of samples kept before the first and after the last\n",
"            sample above the threshold.\n",
"\n",
"    Returns:\n",
"        `drop_entry`, with 'data' cropped and its index reset to 0..n-1.\n",
"\n",
"    Raises:\n",
"        ValueError: If no sample exceeds either threshold.\n",
"    \"\"\"\n",
"    shorten_data = drop_entry['data']\n",
"    sf_entry = shorten_data['Force']\n",
"\n",
"    temps_over_T = np.where(sf_entry > thresh)[0]\n",
"    if temps_over_T.size == 0:\n",
"        # Explicit fallback instead of catching the error raised by min() on an empty array.\n",
"        temps_over_T = np.where(sf_entry > sf_entry.max() * rel_thresh)[0]\n",
"    if temps_over_T.size == 0:\n",
"        raise ValueError('Force never exceeds the threshold, cannot crop the recording')\n",
"\n",
"    start = max(temps_over_T.min() - leeway, 0)\n",
"    # The slice end is exclusive and capped at len - 1, so the very last sample\n",
"    # of the recording is never part of the result.\n",
"    stop = min(len(sf_entry) - 1, temps_over_T.max() + leeway)\n",
"\n",
"    drop_entry['data'] = shorten_data.iloc[start:stop].reset_index(drop=True)\n",
"    return drop_entry"
]
},
{
"cell_type": "code",
"execution_count": 9,
2021-08-06 20:20:52 +02:00
"id": "e962f5cf",
"metadata": {},
"outputs": [],
"source": [
"def norm_force(shorten_entry, flist):\n",
"    \"\"\"Standardise the 'Force' column with the statistics of the entry's user.\n",
"\n",
"    Args:\n",
"        shorten_entry: Dict with the keys 'user' and 'data' (DataFrame with a\n",
"            'Force' column). It is modified in place.\n",
"        flist: Mapping from user to the Series whose mean() and std() are used.\n",
"\n",
"    Returns:\n",
"        `shorten_entry`, with 'Force' replaced by (Force - mean) / std and the\n",
"        frame's index reset to 0..n-1.\n",
"    \"\"\"\n",
"    user_force = flist[shorten_entry['user']]\n",
"    frame = shorten_entry['data']\n",
"\n",
"    centred = frame['Force'] - user_force.mean()\n",
"    frame['Force'] = centred / user_force.std()\n",
"\n",
"    shorten_entry['data'] = frame.reset_index(drop=True)\n",
"    return shorten_entry"
]
},
{
"cell_type": "code",
"execution_count": 10,
2021-08-06 20:20:52 +02:00
"id": "56056f60",
"metadata": {},
"outputs": [],
"source": [
"def time_trans(fnorm_entry):\n",
"    \"\"\"Shift the 'Time' column so that the recording starts at zero.\n",
"\n",
"    Args:\n",
"        fnorm_entry: Dict with a 'data' DataFrame that has a 'Time' column.\n",
"            It is modified in place.\n",
"\n",
"    Returns:\n",
"        `fnorm_entry`, with the first 'Time' value subtracted from every row\n",
"        and the frame's index reset to 0..n-1.\n",
"\n",
"    Raises:\n",
"        IndexError: If the frame has no rows.\n",
"    \"\"\"\n",
"    frame = fnorm_entry['data']\n",
"    # Positional access: the first row is used whatever the index labels are\n",
"    # (a label lookup of 0 fails on frames whose index does not start at 0).\n",
"    frame['Time'] = frame['Time'] - frame['Time'].iloc[0]\n",
"\n",
"    fnorm_entry['data'] = frame.reset_index(drop=True)\n",
"    return fnorm_entry"
]
},
{
"cell_type": "code",
"execution_count": 11,
2021-08-06 20:20:52 +02:00
"id": "2a353441",
"metadata": {},
"outputs": [],
"source": [
"def norm(time_entry):\n",
"    \"\"\"Scale the twelve sensor columns by fixed divisors.\n",
"\n",
"    'Acc1' and 'Gyro' columns are divided by 32768, 'Acc2' and 'Mag' columns\n",
"    by 8192 (presumably the raw full-scale values of the sensors - confirm).\n",
"\n",
"    Args:\n",
"        time_entry: Dict with a 'data' DataFrame holding the columns listed\n",
"            below. It is modified in place.\n",
"\n",
"    Returns:\n",
"        `time_entry`, with the scaled columns and the index reset to 0..n-1.\n",
"    \"\"\"\n",
"    frame = time_entry['data']\n",
"    scaling = (\n",
"        (32768, ('Acc1 X', 'Acc1 Y', 'Acc1 Z')),\n",
"        (8192, ('Acc2 X', 'Acc2 Y', 'Acc2 Z')),\n",
"        (32768, ('Gyro X', 'Gyro Y', 'Gyro Z')),\n",
"        (8192, ('Mag X', 'Mag Y', 'Mag Z')),\n",
"    )\n",
"    for divisor, columns in scaling:\n",
"        for column in columns:\n",
"            frame[column] = frame[column] / divisor\n",
"\n",
"    time_entry['data'] = frame.reset_index(drop=True)\n",
"    return time_entry"
]
},
{
"cell_type": "code",
"execution_count": 12,
2021-08-06 20:20:52 +02:00
"id": "ee11dd88",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Preprocessing...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-08-06 20:20:52 +02:00
"100%|██████████| 26179/26179 [01:30<00:00, 290.54it/s]\n"
]
}
],
"source": [
"def preproc(d):\n",
"    \"\"\"Preprocess every entry of `d` and return the results as a list.\n",
"\n",
"    Args:\n",
"        d: Iterable of entry dicts with the keys 'user' and 'data' (DataFrame\n",
"            with a 'Force' column). It is iterated twice.\n",
"\n",
"    Returns:\n",
"        List with the result of preproc_entry for each entry, in input order.\n",
"    \"\"\"\n",
"    # Collect the raw force series per user first and concatenate once per user;\n",
"    # Series.append is quadratic and no longer exists in pandas 2.x.\n",
"    force_parts = {}\n",
"    for e in d:\n",
"        force_parts.setdefault(e['user'], []).append(e['data']['Force'])\n",
"    flist = {user: pd.concat(parts) for user, parts in force_parts.items()}\n",
"\n",
"    # The per-user statistics must be complete before any entry is transformed,\n",
"    # hence the second pass. Iterate the argument, not the global `data`.\n",
"    return [preproc_entry(e, flist) for e in tqdm(d)]\n",
" \n",
"def preproc_entry(entry, flist):\n",
"    \"\"\"Run one recording through all preprocessing steps, in order.\n",
"\n",
"    Steps: mill_drop, cut_force, norm_force, time_trans, norm.\n",
"\n",
"    Args:\n",
"        entry: Dict with the keys 'user' and 'data' (raw DataFrame).\n",
"        flist: Mapping from user to that user's force Series, passed on to\n",
"            norm_force.\n",
"\n",
"    Returns:\n",
"        The entry dict returned by the last step (norm).\n",
"    \"\"\"\n",
"    drop_entry = mill_drop(entry)\n",
"    shorten_entry = cut_force(drop_entry)\n",
"    fnorm_entry = norm_force(shorten_entry, flist)\n",
"    # Feed each step the result of the previous one rather than an earlier\n",
"    # variable, so the chain stays correct even if a step returns a new dict.\n",
"    time_entry = time_trans(fnorm_entry)\n",
"    norm_entry = norm(time_entry)\n",
"    return norm_entry\n",
"\n",
"# Run the preprocessing pipeline over all loaded recordings.\n",
"print(\"Preprocessing...\")\n",
"pdata = preproc(data)\n",
"# plot_pd(pdata[0]['data'], False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
2021-08-06 20:20:52 +02:00
"id": "d32af185",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Truncating...\n"
]
}
],
"source": [
"def throw(pdata, quantile=None):\n",
"    \"\"\"Discard the recordings that are longer than a length quantile.\n",
"\n",
"    Args:\n",
"        pdata: Sequence of entry dicts; len(entry['data']) is the recording length.\n",
"        quantile: Quantile of the length distribution used as cut-off.\n",
"            Defaults to the module-level `threshold_p`.\n",
"\n",
"    Returns:\n",
"        NumPy array holding the entries whose length is <= the cut-off\n",
"        (the quantile value truncated to int), in input order.\n",
"    \"\"\"\n",
"    if quantile is None:\n",
"        # Backward-compatible default: the value configured in the notebook.\n",
"        quantile = threshold_p\n",
"    llist = pd.Series([len(x['data']) for x in pdata])\n",
"    threshold = int(llist.quantile(quantile))\n",
"    longdex = np.where(llist <= threshold)[0]\n",
"    return np.array(pdata)[longdex]\n",
"\n",
"# Length (number of rows) of every preprocessed recording.\n",
"llist = pd.Series([len(x['data']) for x in pdata])\n",
"# Length quantile up to which recordings are kept; read by throw().\n",
"threshold_p = 0.75\n",
"# Cut-off length in samples; elong() also uses it as the padding length.\n",
"threshold = int(llist.quantile(threshold_p))\n",
"\n",
"print(\"Truncating...\")\n",
"tpdata = throw(pdata)\n",
"# plot_pd(tpdata[0]['data'], False)"
]
},
{
"cell_type": "code",
"execution_count": 14,
2021-08-06 20:20:52 +02:00
"id": "317f32d3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-08-06 20:20:52 +02:00
" 9%|▉ | 1822/19640 [00:00<00:00, 18211.98it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Padding...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-08-06 20:20:52 +02:00
"100%|██████████| 19640/19640 [00:01<00:00, 18646.81it/s]\n"
]
}
],
"source": [
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"# ltpdata = []\n",
"def elong(tpdata):\n",
"    \"\"\"Replace each entry's DataFrame by a transposed array padded to `threshold`.\n",
"\n",
"    The frame is converted to NumPy and transposed, so every row of the result\n",
"    is one column of the frame; pad_sequences then pads each row at the end\n",
"    ('post') to the module-level `threshold`. Entries are modified in place.\n",
"\n",
"    Args:\n",
"        tpdata: Iterable of entry dicts whose 'data' value is a DataFrame.\n",
"\n",
"    Returns:\n",
"        The same `tpdata` object that was passed in.\n",
"    \"\"\"\n",
"    for entry in tqdm(tpdata):\n",
"        channels_first = entry['data'].to_numpy().T\n",
"        entry['data'] = pad_sequences(\n",
"            channels_first,\n",
"            maxlen=threshold,\n",
"            dtype=float,\n",
"            padding='post',\n",
"        )\n",
"    return tpdata\n",
"\n",
"# elong() rewrites the entries of tpdata in place and returns the same container,\n",
"# so ltpdata and tpdata refer to the same (now padded) entries.\n",
"print(\"Padding...\")\n",
"ltpdata = elong(tpdata)\n",
"# plot_np(ltpdata[0]['data'], False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
2021-08-06 20:20:52 +02:00
"id": "39e3d5e1",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
2021-08-06 20:20:52 +02:00
"from tensorflow.keras.regularizers import l2\n",
"from tensorflow.keras.models import Sequential\n",
2021-08-06 20:20:52 +02:00
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
"from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
"from tensorflow.keras.optimizers import Adam\n",
"\n",
2021-08-06 20:20:52 +02:00
"def build_model(shape, classes):\n",
"    \"\"\"Assemble and compile the dense classifier.\n",
"\n",
"    Layout: Flatten, Dropout, BatchNormalization, then up to `layer_count`\n",
"    pairs of Dropout + Dense(relu, L2-regularised), then a softmax output.\n",
"    Hidden layer i has int(shape[0]*shape[1] / dense_steps**i) neurons and a\n",
"    dropout rate of drop_count*i; no further hidden layers are added once a\n",
"    layer would have `classes` neurons or fewer.\n",
"\n",
"    Args:\n",
"        shape: Two-element shape of a single input sample.\n",
"        classes: Number of output classes.\n",
"\n",
"    Returns:\n",
"        The compiled Sequential model (Adam, categorical cross-entropy, accuracy metric).\n",
"    \"\"\"\n",
"    input_size = shape[0] * shape[1]\n",
"\n",
"    stack = [\n",
"        Flatten(input_shape=shape, name='flatten'),\n",
"        Dropout(drop_count, name=f'dropout_{drop_count*100}'),\n",
"        BatchNormalization(name='batchNorm'),\n",
"    ]\n",
"\n",
"    for depth in range(1, layer_count + 1):\n",
"        width = int(input_size / pow(dense_steps, depth))\n",
"        if width <= classes:\n",
"            # A hidden layer no wider than the output adds nothing: stop here.\n",
"            break\n",
"        stack.append(Dropout(drop_count*depth, name=f'HiddenDropout_{drop_count*depth*100:.0f}'))\n",
"        stack.append(Dense(width, activation='relu',\n",
"                           kernel_regularizer=l2(0.001), name=f'Hidden_{depth}'))\n",
"\n",
"    stack.append(Dense(classes, activation='softmax', name='Output'))\n",
"\n",
"    model = Sequential()\n",
"    for layer in stack:\n",
"        model.add(layer)\n",
"\n",
"    model.compile(\n",
"        optimizer=Adam(),\n",
"        loss=\"categorical_crossentropy\",\n",
"        metrics=[\"acc\"],\n",
"    )\n",
"    return model"
]
},
{
"cell_type": "code",
2021-08-06 20:20:52 +02:00
"execution_count": 16,
"id": "13ba96af",
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): checkpoint_file is not referenced by any cell visible here; train() saves to and restores from checkpoint_path - confirm whether this is still needed.\n",
"checkpoint_file = './goat.weights'\n",
"\n",
"def train(X_train, y_train, X_test, y_test, epochs=30, batch_size=256):\n",
"    \"\"\"Build and fit the classifier, restore the checkpoint and evaluate it.\n",
"\n",
"    Args:\n",
"        X_train: Training samples; X_train[0].shape is the model input shape.\n",
"        y_train: Training targets, one row per sample; the row length sets the\n",
"            number of output classes.\n",
"        X_test: Samples used as validation data while fitting and for the\n",
"            final evaluation.\n",
"        y_test: Targets belonging to X_test.\n",
"        epochs: Number of training epochs.\n",
"        batch_size: Mini-batch size.\n",
"\n",
"    Returns:\n",
"        Tuple (model, history): the model after load_weights(checkpoint_path)\n",
"        and the History object returned by model.fit.\n",
"    \"\"\"\n",
"    # Take the class count from the targets instead of hard-coding 52, so the\n",
"    # output layer always matches the width of y_train.\n",
"    model = build_model(X_train[0].shape, len(y_train[0]))\n",
"\n",
"    model.summary()\n",
"\n",
"    # Callback that saves the model whenever the monitored value improves.\n",
"    # NOTE(review): monitor='loss' tracks the training loss, not val_loss - confirm this is intended.\n",
"    model_checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='loss',\n",
"                                       save_best_only=True)\n",
"\n",
"    history = model.fit(X_train, y_train,\n",
"                        epochs=epochs,\n",
"                        batch_size=batch_size,\n",
"                        shuffle=True,\n",
"                        validation_data=(X_test, y_test),\n",
"                        verbose=2,\n",
"                        callbacks=[model_checkpoint]\n",
"                        )\n",
"\n",
"    model.load_weights(checkpoint_path)\n",
"    print(\"Evaluate on test data\")\n",
"    # Actually run the evaluation that the message announces; Keras prints the scores.\n",
"    model.evaluate(X_test, y_test, verbose=2)\n",
"    return model, history"
]
},
{
"cell_type": "code",
2021-08-06 20:20:52 +02:00
"execution_count": 17,
"id": "519579cf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# NOTE(review): tensorflow is already imported by an earlier cell when these are set; presumably they must be in place before TensorFlow first initialises the GPU - confirm they still take effect here.\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
]
},
{
"cell_type": "code",
2021-08-06 20:20:52 +02:00
"execution_count": 18,
"id": "3bf061f4",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
"\n",
"# Stack the padded recordings into one array and collect the matching labels.\n",
"X = np.array([x['data'] for x in ltpdata])\n",
"y = np.array([x['label'] for x in ltpdata])\n",
"\n",
"# Turn the label characters into indicator vectors (one column per distinct label).\n",
"lb = LabelBinarizer()\n",
"y_tran = lb.fit_transform(y)\n",
"\n",
"# Hold out 20% of the samples; the fixed random_state keeps the split reproducible.\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y_tran, test_size=0.2, random_state=177013)\n",
"\n",
2021-08-06 20:20:52 +02:00
"# Reshape to the arrays' own first three dimensions: no change for a 3-D array,\n",
"# but it fails early (IndexError) if the samples did not stack into three dimensions.\n",
"X_train=X_train.reshape(X_train.shape[0],X_train.shape[1],X_train.shape[2])\n",
"X_test=X_test.reshape(X_test.shape[0],X_test.shape[1],X_test.shape[2])\n",
"\n",
"# NOTE(review): train() does not take these two values as arguments and no visible cell reads them - confirm whether they are still needed.\n",
"train_shape = X_train[0].shape\n",
"classes = y_train[0].shape[0]"
]
},
{
"cell_type": "code",
2021-08-06 20:20:52 +02:00
"execution_count": 19,
"id": "a8dffb69",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-08-06 20:20:52 +02:00
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
2021-08-06 20:20:52 +02:00
"flatten (Flatten) (None, 1050) 0 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"dropout_10.0 (Dropout) (None, 1050) 0 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"batchNorm (BatchNormalizatio (None, 1050) 4200 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"HiddenDropout_10 (Dropout) (None, 1050) 0 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"Hidden_1 (Dense) (None, 525) 551775 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"HiddenDropout_20 (Dropout) (None, 525) 0 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"Hidden_2 (Dense) (None, 262) 137812 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"HiddenDropout_30 (Dropout) (None, 262) 0 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"Hidden_3 (Dense) (None, 131) 34453 \n",
"_________________________________________________________________\n",
2021-08-06 20:20:52 +02:00
"Output (Dense) (None, 52) 6864 \n",
"=================================================================\n",
2021-08-06 20:20:52 +02:00
"Total params: 735,104\n",
"Trainable params: 733,004\n",
"Non-trainable params: 2,100\n",
"_________________________________________________________________\n",
"Epoch 1/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 2s - loss: 4.8415 - acc: 0.0671 - val_loss: 4.7052 - val_acc: 0.0723\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 2/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 4.0441 - acc: 0.1667 - val_loss: 4.2215 - val_acc: 0.1181\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 3/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 3.4989 - acc: 0.2526 - val_loss: 3.8433 - val_acc: 0.1752\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 4/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 3.1289 - acc: 0.3177 - val_loss: 3.5399 - val_acc: 0.2210\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 5/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.8521 - acc: 0.3696 - val_loss: 3.3127 - val_acc: 0.2620\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 6/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.6560 - acc: 0.4068 - val_loss: 3.0547 - val_acc: 0.3236\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 7/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.5031 - acc: 0.4395 - val_loss: 2.7836 - val_acc: 0.3933\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 8/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.3725 - acc: 0.4638 - val_loss: 2.5959 - val_acc: 0.4104\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 9/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.2528 - acc: 0.4946 - val_loss: 2.3562 - val_acc: 0.4705\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 10/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.1670 - acc: 0.5070 - val_loss: 2.1431 - val_acc: 0.5323\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 11/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.0921 - acc: 0.5232 - val_loss: 2.0376 - val_acc: 0.5624\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 12/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 2.0189 - acc: 0.5411 - val_loss: 1.9824 - val_acc: 0.5644\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 13/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.9840 - acc: 0.5495 - val_loss: 1.8983 - val_acc: 0.5802\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 14/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.9388 - acc: 0.5557 - val_loss: 1.8335 - val_acc: 0.6003\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 15/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.8664 - acc: 0.5737 - val_loss: 1.8275 - val_acc: 0.5909\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 16/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.8345 - acc: 0.5861 - val_loss: 1.7849 - val_acc: 0.5942\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 17/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.8048 - acc: 0.5876 - val_loss: 1.7853 - val_acc: 0.5950\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 18/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.7855 - acc: 0.5971 - val_loss: 1.7390 - val_acc: 0.6227\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 19/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.7337 - acc: 0.6060 - val_loss: 1.7086 - val_acc: 0.6225\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 20/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.7177 - acc: 0.6057 - val_loss: 1.6970 - val_acc: 0.6303\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 21/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.6984 - acc: 0.6166 - val_loss: 1.7008 - val_acc: 0.6151\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 22/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.6698 - acc: 0.6223 - val_loss: 1.6824 - val_acc: 0.6171\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 23/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.6533 - acc: 0.6265 - val_loss: 1.6574 - val_acc: 0.6329\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 24/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.6341 - acc: 0.6373 - val_loss: 1.6485 - val_acc: 0.6283\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 25/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.6232 - acc: 0.6366 - val_loss: 1.6606 - val_acc: 0.6334\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 26/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.6098 - acc: 0.6337 - val_loss: 1.6851 - val_acc: 0.6225\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 27/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.5843 - acc: 0.6459 - val_loss: 1.6324 - val_acc: 0.6377\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 28/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.5674 - acc: 0.6538 - val_loss: 1.6377 - val_acc: 0.6286\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Epoch 29/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.5684 - acc: 0.6464 - val_loss: 1.6092 - val_acc: 0.6459\n",
"Epoch 30/30\n",
2021-08-06 20:20:52 +02:00
"62/62 - 0s - loss: 1.5520 - acc: 0.6559 - val_loss: 1.6159 - val_acc: 0.6433\n",
"INFO:tensorflow:Assets written to: training_1/cp.ckpt/assets\n",
"Evaluate on test data\n",
2021-08-06 20:20:52 +02:00
"CPU times: user 40.5 s, sys: 4.35 s, total: 44.9 s\n",
"Wall time: 37.2 s\n"
]
}
],
"source": [
2021-08-06 20:20:52 +02:00
"%%time\n",
"# Train only when no `model` exists yet; otherwise just restore the checkpointed weights.\n",
"# (The cell magic above has to stay on the first line of the cell.)\n",
"if 'model' not in locals():\n",
" tf.keras.backend.clear_session()\n",
" model, history = train(np.array(X_train), np.array(y_train), np.array(X_test), np.array(y_test))\n",
"else:\n",
" print(\"Loaded weights...\")\n",
" model.load_weights(checkpoint_path)"
]
},
{
"cell_type": "code",
2021-08-06 20:20:52 +02:00
"execution_count": 20,
"id": "9b57245b",
"metadata": {},
"outputs": [
{
"data": {
2021-08-06 20:20:52 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAEICAYAAACgQWTXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA1BklEQVR4nO3deXwV1dnA8d+TfQ8kLIGEXXbCEgIoiqJiRaQgKihgFbXuS7XF1loXqrWvC1XfvlWrrXVXpKiIAqKoiDv7DpEkBEkIBLLv63n/mCFcYpab5Ga5N8/388knc2fOzH3mTng498w5Z8QYg1JKKc/g1dYBKKWUch1N6kop5UE0qSullAfRpK6UUh5Ek7pSSnkQTepKKeVBNKl7EBFZLSLXuLpsWxKRFBGZ0gLHNSJymr38TxF5wJmyTXif+SLySVPjVKqxRPupty0RKXB4GQSUApX265uMMW+2flTth4ikAL82xqx18XENMNAYk+iqsiLSFzgA+BpjKlwSqFKN5NPWAXR0xpiQE8v1JTAR8dFEodoL/Xtsv7T5pZ0SkckikioifxCRI8DLItJZRD4SkWMikm0vxzjss05Efm0vLxCRr0VksV32gIhc1MSy/URkvYjki8haEXlWRN6oI25nYnxERL6xj/eJiHRx2P4rETkoIpki8qd6Pp8JInJERLwd1s0SkR328ngR+U5EckQkXUT+ISJ+dRzrFRH5i8Pre+x9DovIdTXKXiwiW0UkT0QOicgih83r7d85IlIgImec+Gwd9p8oIhtFJNf+PdHZz6aRn3OEiLxsn0O2iCx32DZTRLbZ55AkIlPt9ac0dYnIohPXWUT62s1Q14vIT8Dn9vr/2tch1/4bGe6wf6CI/M2+nrn231igiKwUkTtqnM8OEZlV27mqxtGk3r5FARFAH+BGrOv1sv26N1AM/KOe/ScACUAX4AngJRGRJpR9C9gARAKLgF/V857OxDgPuBboBvgBCwFEZBjwvH38nvb7xVALY8wPQCFwXo3jvmUvVwJ32+dzBnA+cGs9cWPHMNWO5wJgIFCzPb8QuBroBFwM3CIil9jbzrZ/dzLGhBhjvqtx7AhgJfB3+9yeAlaKSGSNc/jZZ1OLhj7n17Ga84bbx3rajmE88Bpwj30OZwMpdbxHbc4BhgIX2q9XY31O3YAtgGNz4WJgLDAR6+/490AV8Cpw1YlCIjIKiMb6bFRzGWP0p538YP3jmmIvTwbKgIB6yo8Gsh1er8NqvgFYACQ6bAsCDBDVmLJYCaMCCHLY/gbwhpPnVFuM9zu8vhX42F5+EFjisC3Y/gym1HHsvwD/sZdDsRJunzrK3gW87/DaAKfZy68Af7GX/wM85lBukGPZWo77DPC0vdzXLuvjsH0B8LW9/CtgQ439vwMWNPTZNOZzBnpgJc/OtZR74US89f392a8XnbjODufWv54YOtllwrH+0ykGRtVSLgDIxrpPAVbyf64l/k11xB+tqbdvx4wxJSdeiEiQiLxgf53Nw/q638mxCaKGIycWjDFF9mJII8v2BLIc1gEcqitgJ2M84rBc5BBTT8djG2MKgcy63gurVn6piPgDlwJbjDEH7TgG2U0SR+w4/opVa2/IKTEAB2uc3wQR+cJu9sgFbnbyuCeOfbDGuoNYtdQT6vpsTtHA59wL65pl17JrLyDJyXhrU/3ZiIi3iDxmN+HkcbLG38X+Cajtvey/6XeAq0TEC5iL9c1CuYAm9fatZtek3wGDgQnGmDBOft2vq0nFFdKBCBEJcljXq57yzYkx3fHY9ntG1lXYGLMHKylexKlNL2A14+zDqg2GAfc1JQasbyqO3gJWAL2MMeHAPx2O21BXssNYzSWOegNpTsRVU32f8yGsa9aplv0OAQPqOGYh1re0E6JqKeN4jvOAmVhNVOFYtfkTMRwHSup5r1eB+VjNYkWmRlOVajpN6u4lFOsrbY7dPvtQS7+hXfPdBCwSET8ROQP4ZQvFuA
yYLiJn2Tc1H6bhv9G3gN9gJbX/1ogjDygQkSHALU7GsBRYICLD7P9UasYfilULLrHbp+c5bDuG1ezRv45jrwIGicg8EfERkSuAYcBHTsZWM45aP2djTDpWW/dz9g1VXxE5kfRfAq4VkfNFxEtEou3PB2AbcKVdPh643IkYSrG+TQVhfRs6EUMVVlPWUyLS067Vn2F/q8JO4lXA39BauktpUncvzwCBWLWg74GPW+l952PdbMzEasd+B+sfc22eoYkxGmN2A7dhJep0rHbX1AZ2exvr5t3nxpjjDusXYiXcfOBfdszOxLDaPofPgUT7t6NbgYdFJB/rHsBSh32LgEeBb8TqdXN6jWNnAtOxatmZWDcOp9eI21nPUP/n/CugHOvbSgbWPQWMMRuwbsQ+DeQCX3Ly28MDWDXrbODPnPrNpzavYX1TSgP22HE4WgjsBDYCWcDjnJpzXgNise7RKBfRwUeq0UTkHWCfMabFvykozyUiVwM3GmPOautYPInW1FWDRGSciAywv65PxWpHXd7GYSk3Zjdt3Qq82NaxeJoGk7qI/EdEMkRkVx3bRUT+LiKJ9gCCONeHqdpYFFZ3uwKsPta3GGO2tmlEym2JyIVY9x+O0nATj2qkBptf7BssBcBrxpgRtWyfBtwBTMMawPK/xpgJLRCrUkqpBjRYUzfGrMe6yVGXmVgJ3xhjvsfqK9vDVQEqpZRynism9Irm1MEaqfa69JoFReRGrOHuBAcHjx0yZEjNIkoppeqxefPm48aYrnVtb9VZGo0xL2LfGImPjzebNm1qzbdXSim3JyI1RyWfwhW9X9I4dQReDE0bIaeUUqqZXJHUVwBX271gTgdy7RFtSimlWlmDzS8i8jbWjIFdRCQVaziyL4Ax5p9YQ5+nYY2+K8IaraaUUqoNNJjUjTFzG9husIZ2K6XcRHl5OampqZSUlDRcWLWJgIAAYmJi8PX1bdR++jg7pTqg1NRUQkND6du3L3U/N0W1FWMMmZmZpKam0q9fv0btq9MEKNUBlZSUEBkZqQm9nRIRIiMjm/RNSpO6Uh2UJvT2ranXR5O6Ukp5EE3qSqlWl5mZyejRoxk9ejRRUVFER0dXvy4rK6t3302bNnHnnXc2+B4TJ050VbhuRW+UKqVaXWRkJNu2bQNg0aJFhISEsHDhwurtFRUV+PjUnp7i4+OJj49v8D2+/fZbl8TqbrSmrpRqFxYsWMDNN9/MhAkT+P3vf8+GDRs444wzGDNmDBMnTiQhIQGAdevWMX36dMD6D+G6665j8uTJ9O/fn7///e/VxwsJCakuP3nyZC6//HKGDBnC/PnzOTE77apVqxgyZAhjx47lzjvvrD6uo5SUFCZNmkRcXBxxcXGn/Gfx+OOPExsby6hRo7j33nsBSExMZMqUKYwaNYq4uDiSkprznO/G05q6Uh3cnz/czZ7DeS495rCeYTz0y+GN3i81NZVvv/0Wb29v8vLy+Oqrr/Dx8WHt2rXcd999vPvuuz/bZ9++fXzxxRfk5+czePBgbrnllp/17d66dSu7d++mZ8+enHnmmXzzzTfEx8dz0003sX79evr168fcubUPyenWrRuffvopAQEB7N+/n7lz57Jp0yZWr17NBx98wA8//EBQUBBZWdZktvPnz+fee+9l1qxZlJSUUFVV1ejPoTk0qSul2o3Zs2fj7e0NQG5uLtdccw379+9HRCgvL691n4svvhh/f3/8/f3p1q0bR48eJSYm5pQy48ePr143evRoUlJSCAkJoX///tX9wOfOncuLL/78QUzl5eXcfvvtbNu2DW9vb3788UcA1q5dy7XXXktQUBAAERER5Ofnk5aWxqxZswBrAFFr06SuVAfXlBp1SwkODq5efuCBBzj33HN5//33SUlJYfLkybXu4+/vX73s7e1NRUVFk8rU5emnn6Z79+5s376dqqqqNknUjaFt6kqpdik3N5fo6GgAXnnlFZcff/DgwSQnJ5OSkgLAO++8U2ccPX
r0wMvLi9dff53KykoALrjgAl5++WWKiooAyMrKIjQ0lJiYGJYvXw5AaWlp9fbWokldKdUu/f73v+ePf/wjY8aMaVT
"text/plain": [
2021-08-06 20:20:52 +02:00
"<Figure size 432x288 with 1 Axes>"
]
},
2021-08-06 20:20:52 +02:00
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
2021-08-06 20:20:52 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAEICAYAAACgQWTXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8f0lEQVR4nO3dd3wUdfrA8c+zyaaHkAYEQu9KiwRQUQTBs4O9F372cnq2s56Kelju1FPPiucdeucJFkTEdkpRUESKgPQmJZSQBNJ78vz+2CWGkLpsym6et6997ezMd77zzI48mf3Od74jqooxxhj/4GjuAIwxxniPJXVjjPEjltSNMcaPWFI3xhg/YkndGGP8iCV1Y4zxI5bUDQAi8oWIXO3tss1JRLaJyLhGqFdFpJd7+nURebg+ZT3YzuUi8j9P46yl3tEikuLtek3LENjcARjPiUhupY9hQBFQ5v58o6q+W9+6VPX0xijr71T1Jm/UIyLdgF8Bp6qWuut+F6j3MTQGLKn7NFWNODgtItuA61T1m6rlRCTwYKIwxvg3a37xQwd/XovIfSKyF/iXiESLyGwRSRORA+7pxErrzBeR69zTE0VkoYg86y77q4ic7mHZ7iLynYjkiMg3IvKKiPynhrjrE+MTIvK9u77/iUhcpeVXish2EckQkYdq+X5GiMheEQmoNO9cEVnlnh4uIotEJFNE9ojIyyISVENdU0Xkz5U+/9G9zm4RuaZK2TNF5GcRyRaRnSIyqdLi79zvmSKSKyLHHfxuK61/vIgsEZEs9/vx9f1uaiMi/d3rZ4rIGhEZX2nZGSKy1l3nLhG5xz0/zn18MkVkv4gsEBHLJy2AHQT/1QGIAboCN+A61v9yf+4CFAAv17L+CGADEAf8BXhLRMSDsv8FfgJigUnAlbVssz4xXgb8H9AOCAIOJpmjgNfc9Xd0by+RaqjqYiAPOLlKvf91T5cBd7r35zhgLHBLLXHjjuE0dzynAL2Bqu35ecBVQFvgTOBmETnHvWyU+72tqkao6qIqdccAnwEvuffteeAzEYmtsg+HfTd1xOwEPgX+517vNuBdEenrLvIWrqa8SGAAMNc9/24gBYgH2gMPAjbmSAtgSd1/lQOPqmqRqhaoaoaqfqSq+aqaA0wGTqpl/e2q+qaqlgFvAwm4/vHWu6yIdAGGAY+oarGqLgRm1bTBesb4L1XdqKoFwPvAEPf8C4DZqvqdqhYBD7u/g5q8B1wKICKRwBnueajqMlX9UVVLVXUb8EY1cVTnInd8q1U1D9cfscr7N19Vf1HVclVd5d5efeoF1x+BTar6b3dc7wHrgbMrlanpu6nNsUAE8LT7GM0FZuP+boAS4CgRaaOqB1R1eaX5CUBXVS1R1QVqA0m1CJbU/VeaqhYe/CAiYSLyhrt5IhvXz/22lZsgqth7cEJV892TEQ0s2xHYX2kewM6aAq5njHsrTedXiqlj5brdSTWjpm3hOis/T0SCgfOA5aq63R1HH3fTwl53HE/iOmuvyyExANur7N8IEZnnbl7KAm6qZ70H695eZd52oFOlzzV9N3XGrKqV/wBWrvd8XH/wtovItyJynHv+X4HNwP9EZKuI3F+/3TCNzZK6/6p61nQ30BcYoapt+O3nfk1NKt6wB4gRkbBK8zrXUv5IYtxTuW73NmNrKqyqa3Elr9M5tOkFXM0464He7jge9CQGXE1Ilf0X1y+VzqoaBbxeqd66znJ342qWqqwLsKsecdVVb+cq7eEV9arqElWdgKtpZiauXwCoao6q3q2qPYDxwF0iMvYIYzFeYEm99YjE1Uad6W6ffbSxN+g+810KTBKRIPdZ3tm1rHIkMX4InCUiJ7gvaj5O3f9//xf4A64/Hh9UiSMbyBWRfsDN9YzhfWCiiBzl/qNSNf5IXL9cCkVkOK4/Jgel4Wou6lFD3Z8DfUTkMhEJFJGLgaNwNZUcicW4zurvFRGniIzGdYymuY/Z5SISpaoluL
6TcgAROUtEermvnWThug5RW3OXaSKW1FuPF4BQIB34EfiyibZ7Oa6LjRnAn4HpuPrTV+cFPIxRVdcAt+JK1HuAA7gu5NXmYJv2XFVNrzT/HlwJNwd40x1zfWL4wr0Pc3E1TcytUuQW4HERyQEewX3W6143H9c1hO/dPUqOrVJ3BnAWrl8zGcC9wFlV4m4wVS3GlcRPx/W9vwpcparr3UWuBLa5m6FuwnU8wXUh+BsgF1gEvKqq844kFuMdYtc2TFMSkenAelVt9F8KxrRGdqZuGpWIDBORniLicHf5m4CrbdYY0wjsjlLT2DoAM3BdtEwBblbVn5s3JGP8V72aX8R1C3oOroshpaqaXGW5AC/i6vqUD0ys1J/VGGNME2nImfqYWi7KnI7rwklvXHcXvuZ+N8YY04S81fwyAXjHfUfZjyLSVkQSVHVPTSvExcVpt27dvLR5Y4xpHZYtW5auqvE1La9vUldcd44p8IaqTqmyvBOH3kmX4p53SFIXkRtwjUNCly5dWLp0aT03b4wxBkBEqt5ZfIj69n45QVWPwdXMcquIjKprheqo6hRVTVbV5Pj4Gv/QGGOM8VC9krqqHrxleB/wMTC8SpFdHHp7dCJHfvuyMcaYBqozqYtIuHsUO0QkHPgdsLpKsVnAVeJyLJBVW3u6McaYxlGfNvX2wMfu4bEDgf+q6pcichOAqr6Oa1yKM3DdGp2Pa0xnY0wLU1JSQkpKCoWFhXUXNs0qJCSExMREnE5ng9arM6mr6lZgcDXzX680rbjG3TDGtGApKSlERkbSrVs3an7miWluqkpGRgYpKSl07969QevaMAHGtCKFhYXExsZaQm/hRITY2FiPflFZUjemlbGE7hs8PU4+l9Q3HdjEc0ufI78kv+7CxhjTyvhcUt+du5upa6ayNmNtc4dijGmgjIwMhgwZwpAhQ+jQoQOdOnWq+FxcXFzrukuXLuX222+vcxvHH3+8V2KdP38+Z511llfqako+N0rjwPiBAKxKX0Vyh+Q6ShtjWpLY2FhWrFgBwKRJk4iIiOCee+6pWF5aWkpgYPVpKTk5meTkuv/N//DDD16J1Vf53Jl6TEgMnSM7syptVXOHYozxgokTJ3LTTTcxYsQI7r33Xn766SeOO+44kpKSOP7449mwYQNw6JnzpEmTuOaaaxg9ejQ9evTgpZdeqqgvIiKiovzo0aO54IIL6NevH5dffjkHR6X9/PPP6devH0OHDuX222+v84x8//79nHPOOQwaNIhjjz2WVatc+efbb7+t+KWRlJRETk4Oe/bsYdSoUQwZMoQBAwawYMECr39ntfG5M3WAQfGDWLxnMapqF32M8dAzPz3D+v3r6y7YAP1i+nHf8PsavF5KSgo//PADAQEBZGdns2DBAgIDA/nmm2948MEH+eijjw5bZ/369cybN4+cnBz69u3LzTfffFif7p9//pk1a9bQsWNHRo4cyffff09ycjI33ngj3333Hd27d+fSSy+tM75HH32UpKQkZs6cydy5c7nqqqtYsWIFzz77LK+88gojR44kNzeXkJAQpkyZwqmnnspDDz1EWVkZ+flNe/3P587UAQbFDSK9IJ29eXubOxRjjBdceOGFBAQEAJCVlcWFF17IgAEDuPPOO1mzZk2165x55pkEBwcTFxdHu3btSE1NPazM8OHDSUxMxOFwMGTIELZt28b69evp0aNHRf/v+iT1hQsXcuWVVwJw8sknk5GRQXZ2NiNHjuSuu+7ipZdeIjMzk8DAQIYNG8a//vUvJk2axC+//EJkZKSnX4tHfPJMfXC8616olekrSYhIaOZojPFNnpxRN5bw8PCK6YcffpgxY8bw8ccfs23bNkaPHl3tOsHBwRXTAQEBlJaWelTmSNx///2ceeaZfP7554wcOZKvvvqKUaNG8d133/HZZ58xceJE7rrrLq666iqvbrc2Pnmm3ie6D8EBwdaubowfysrKolOnTgBMnTrV6/X37duXrVu3sm3bNg
CmT59e5zonnngi7777LuBqq4+Li6NNmzZs2bKFgQMHct999zFs2DDWr1/P9u3bad++Pddffz3XXXcdy5c37UPgfC6
"text/plain": [
2021-08-06 20:20:52 +02:00
"<Figure size 432x288 with 1 Axes>"
]
},
2021-08-06 20:20:52 +02:00
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-08-06 20:20:52 +02:00
"def plot_keras_history(history, name='', acc='acc'):\n",
"    \"\"\"Draw the accuracy and loss curves stored in a Keras History object.\n",
"\n",
"    Args:\n",
"        history: Object whose `.history` dict holds the series `acc`,\n",
"            'val_' + `acc`, 'loss' and 'val_loss'.\n",
"        name: Text appended to both plot titles.\n",
"        acc: Key of the accuracy metric in `history.history`.\n",
"    \"\"\"\n",
"    import matplotlib.pyplot as plt\n",
"\n",
"    log = history.history\n",
"    # Read all four series up front, so a missing key fails before anything is drawn.\n",
"    training_acc, validation_acc = log[acc], log['val_' + acc]\n",
"    loss, val_loss = log['loss'], log['val_loss']\n",
"\n",
"    epochs = range(len(training_acc))\n",
"\n",
"    # First figure: accuracy, with the y-axis pinned to [0, 1].\n",
"    plt.ylim(0, 1)\n",
"    for series, colour, label in (\n",
"        (training_acc, 'tab:blue', 'Training acc'),\n",
"        (validation_acc, 'tab:orange', 'Validation acc'),\n",
"    ):\n",
"        plt.plot(epochs, series, colour, label=label)\n",
"    plt.title('Training and validation accuracy ' + name)\n",
"    plt.legend()\n",
"\n",
"    # Second figure: loss.\n",
"    plt.figure()\n",
"    for series, colour, label in (\n",
"        (loss, 'tab:green', 'Training loss'),\n",
"        (val_loss, 'tab:red', 'Validation loss'),\n",
"    ):\n",
"        plt.plot(epochs, series, colour, label=label)\n",
"    plt.title('Training and validation loss ' + name)\n",
"    plt.legend()\n",
"\n",
"    plt.show()\n",
"    plt.close()\n",
"# Plot the curves of the `history` object produced by train().\n",
"plot_keras_history(history)"
]
},
{
"cell_type": "code",
"execution_count": null,
2021-08-06 20:20:52 +02:00
"id": "4960be86",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-07-14 10:15:52 +02:00
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}