2021-06-23 10:12:11 +02:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 1,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "f9b2f6c2",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# Raw per-letter CSV recordings, laid out as ROOT/USER/split_letters_csv/FILE.\n",
"glob_path = (\n",
"    '/opt/iui-datarelease2-sose2021'\n",
"    '/*/split_letters_csv/*'\n",
")\n",
"\n",
"# Local cache of the parsed recordings.\n",
"pickle_file = 'data.pickle'"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 2,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "eb49db4b",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from glob import glob\n",
|
|
|
|
"import pandas as pd\n",
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
"\n",
|
|
|
|
"def dl_from_blob(filename) -> list:\n",
"    \"\"\"Read every CSV matched by the glob pattern `filename`.\n",
"\n",
"    Each match becomes a dict with keys 'file' (the path), 'data' (the\n",
"    semicolon-separated CSV as a DataFrame), 'user' (the third-from-last\n",
"    path component) and 'label' (first alphabetic character of the file name).\n",
"    \"\"\"\n",
"    all_data = []\n",
"    for path in tqdm(glob(filename)):\n",
"        parts = path.split('/')\n",
"        # Pass the separator by keyword: pandas >= 2.0 makes every read_csv\n",
"        # argument after the path keyword-only.\n",
"        df = pd.read_csv(path, sep=';')\n",
"        # Index from the end so the result does not depend on how deep the\n",
"        # data root is: ROOT/USER/split_letters_csv/FILE.\n",
"        user = parts[-3]\n",
"        letters = [c for c in parts[-1] if c.isalpha()]\n",
"        label = letters[0]\n",
"        all_data.append({\n",
"            'file': path,\n",
"            'data': df,\n",
"            'user': user,\n",
"            'label': label,\n",
"        })\n",
"    return all_data"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 3,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "daefd4a8",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def save_pickle(f, structure):\n",
"    \"\"\"Serialise `structure` with pickle into the file at path `f`.\"\"\"\n",
"    # Local import: this cell runs before the cell that imports pickle.\n",
"    import pickle\n",
"    # `with` closes the file even if pickle.dump raises midway.\n",
"    with open(f, 'wb') as fh:\n",
"        pickle.dump(structure, fh)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 4,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "3b04c1ee",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import pickle\n",
|
|
|
|
"\n",
|
|
|
|
"def load_pickles(f) -> list:\n",
"    \"\"\"Load and return the object pickled in the file at path `f`.\n",
"\n",
"    Only use on trusted files: unpickling can execute arbitrary code.\n",
"    \"\"\"\n",
"    # Read the path the caller passed rather than the global pickle_file.\n",
"    with open(f, 'rb') as fh:\n",
"        return pickle.load(fh)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 5,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "5cf901e4",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-07-05 15:01:40 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Loading data...\n",
|
|
|
|
"data.pickle found...\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"import os\n",
|
|
|
|
"def load_data() -> list:\n",
"    \"\"\"Return the raw recordings, from the pickle cache when it exists.\n",
"\n",
"    On a cache miss the CSVs matched by glob_path are parsed and the result\n",
"    is written to pickle_file for the next run.\n",
"    \"\"\"\n",
"    if not os.path.isfile(pickle_file):\n",
"        print(f'Didn\\'t find {pickle_file}...')\n",
"        fresh = dl_from_blob(glob_path)\n",
"        print(f'Creating {pickle_file}...')\n",
"        save_pickle(pickle_file, fresh)\n",
"        return fresh\n",
"\n",
"    print(f'{pickle_file} found...')\n",
"    return load_pickles(pickle_file)\n",
"\n",
"\n",
"print(\"Loading data...\")\n",
"data = load_data()"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 6,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "a68cb0bb",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
"\n",
|
|
|
|
"# Panel layout shared by plot_pd and plot_np: (row, col, channel name, index\n",
"# of that channel in the channel-first arrays plotted by plot_np).\n",
"_PLOT_LAYOUT = [\n",
"    (0, 0, 'Acc1 X', 0), (0, 1, 'Acc1 Y', 1), (0, 2, 'Acc1 Z', 2),\n",
"    (1, 0, 'Acc2 X', 3), (1, 1, 'Acc2 Y', 4), (1, 2, 'Acc2 Z', 5),\n",
"    (2, 0, 'Gyro X', 6), (2, 1, 'Gyro Y', 7), (2, 2, 'Gyro Z', 8),\n",
"    (3, 0, 'Mag X', 9), (3, 1, 'Mag Y', 10), (3, 2, 'Mag Z', 11),\n",
"    (4, 0, 'Time', 13),\n",
"]\n",
"_FORCE_INDEX = 12\n",
"\n",
"\n",
"def _plot_grid(channels, force_series, force):\n",
"    \"\"\"Draw `channels` (name -> 1-D series) on a 5x3 grid, plus the force trace.\"\"\"\n",
"    fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
"    for row, col, name, _ in _PLOT_LAYOUT:\n",
"        axs[row][col].plot(channels[name])\n",
"        axs[row][col].set_title(name)\n",
"    if force:\n",
"        # Overlay the force trace on every panel, including the unused ones.\n",
"        for ax in axs.flat:\n",
"            ax.plot(force_series)\n",
"    else:\n",
"        axs[4][1].plot(force_series)\n",
"        axs[4][1].set_title('Force')\n",
"    return fig\n",
"\n",
"\n",
"def plot_pd(data, force=True):\n",
"    \"\"\"Plot one recording held as a DataFrame with named sensor columns.\n",
"\n",
"    force=True overlays 'Force' on every panel; False gives it its own panel.\n",
"    \"\"\"\n",
"    channels = {name: data[name] for _, _, name, _ in _PLOT_LAYOUT}\n",
"    _plot_grid(channels, data['Force'], force)\n",
"\n",
"\n",
"def plot_np(data, force=True):\n",
"    \"\"\"Plot one recording held as a channel-first array (as produced by elong).\n",
"\n",
"    Rows 0-11 are the sensor channels, row 12 is Force and row 13 is Time.\n",
"    force=True overlays Force on every panel; False gives it its own panel.\n",
"    \"\"\"\n",
"    channels = {name: data[idx] for _, _, name, idx in _PLOT_LAYOUT}\n",
"    _plot_grid(channels, data[_FORCE_INDEX], force)\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 7,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "ae002a37",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def mill_drop(entry):\n",
"    \"\"\"Return a copy of `entry` whose 'data' lacks the 'Millis' column.\n",
"\n",
"    The input dict is not modified, so the raw recordings in `data` stay\n",
"    intact and the preprocessing cell can be re-run.\n",
"    \"\"\"\n",
"    data_wo_mill = entry['data'].drop(columns='Millis')\n",
"    return {**entry, 'data': data_wo_mill.reset_index(drop=True)}"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 8,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "0d0b3544",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import numpy as np\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"\n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"def cut_force(drop_entry, leeway=10, thresh=70, rel_thresh=0.05):\n",
"    \"\"\"Trim a recording to the span where the 'Force' signal is active.\n",
"\n",
"    Samples whose Force exceeds `thresh` mark the active span; if none do,\n",
"    the threshold falls back to `rel_thresh` times the recording's maximum\n",
"    Force. `leeway` extra samples are kept on both sides of the span. If no\n",
"    sample passes either threshold the recording is returned uncut.\n",
"\n",
"    Returns a new entry dict with a 0..n-1 index; the input is not modified.\n",
"    \"\"\"\n",
"    data = drop_entry['data']\n",
"    force = data['Force'].to_numpy()\n",
"\n",
"    hits = np.flatnonzero(force > thresh)\n",
"    if hits.size == 0 and force.size > 0:\n",
"        hits = np.flatnonzero(force > force.max() * rel_thresh)\n",
"\n",
"    if hits.size > 0:\n",
"        start = max(hits[0] - leeway, 0)\n",
"        # Slice ends are exclusive: +1 keeps the full trailing leeway, and\n",
"        # capping at len (not len-1) keeps the final sample.\n",
"        stop = min(hits[-1] + leeway + 1, len(force))\n",
"        data = data.iloc[start:stop]\n",
"\n",
"    return {**drop_entry, 'data': data.reset_index(drop=True)}"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 9,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "d371d6e9",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def norm_force(shorten_entry, flist):\n",
"    \"\"\"Z-score the entry's 'Force' column with its user's force statistics.\n",
"\n",
"    flist maps user -> Series of that user's force samples. The entry is\n",
"    modified in place and returned.\n",
"    \"\"\"\n",
"    user_force = flist[shorten_entry['user']]\n",
"    frame = shorten_entry['data']\n",
"    frame['Force'] = (frame['Force'] - user_force.mean()) / user_force.std()\n",
"    shorten_entry['data'] = frame.reset_index(drop=True)\n",
"    return shorten_entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 10,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "1c14b2a1",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def time_trans(fnorm_entry):\n",
"    \"\"\"Shift 'Time' so the recording's first sample is at time 0.\n",
"\n",
"    Modifies and returns `fnorm_entry`; its 'data' is re-indexed 0..n-1.\n",
"    \"\"\"\n",
"    frame = fnorm_entry['data']\n",
"    if len(frame) > 0:\n",
"        # .iloc: first row by position, independent of the index labels.\n",
"        frame['Time'] = frame['Time'] - frame['Time'].iloc[0]\n",
"    fnorm_entry['data'] = frame.reset_index(drop=True)\n",
"    return fnorm_entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 11,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "189de319",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def norm(time_entry):\n",
"    \"\"\"Scale the raw sensor channels by fixed per-channel divisors.\n",
"\n",
"    Acc1 and Gyro channels are divided by 32768, Acc2 and Mag by 8192.\n",
"    Modifies and returns `time_entry`.\n",
"    \"\"\"\n",
"    divisors = {\n",
"        **dict.fromkeys(('Acc1 X', 'Acc1 Y', 'Acc1 Z'), 32768),\n",
"        **dict.fromkeys(('Acc2 X', 'Acc2 Y', 'Acc2 Z'), 8192),\n",
"        **dict.fromkeys(('Gyro X', 'Gyro Y', 'Gyro Z'), 32768),\n",
"        **dict.fromkeys(('Mag X', 'Mag Y', 'Mag Z'), 8192),\n",
"    }\n",
"    frame = time_entry['data']\n",
"    for column, divisor in divisors.items():\n",
"        frame[column] = frame[column] / divisor\n",
"    time_entry['data'] = frame.reset_index(drop=True)\n",
"    return time_entry"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 12,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "a796b9b2",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-07-19 01:21:53 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Preprocessing...\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"100%|██████████| 26179/26179 [01:30<00:00, 290.22it/s]\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"def preproc(d):\n",
"    \"\"\"Run preproc_entry over every recording in `d` and return the results.\n",
"\n",
"    Force is z-scored per user, so each user's raw force samples are\n",
"    gathered first. Iterates the argument `d`, not the global `data`.\n",
"    \"\"\"\n",
"    # Gather per-user force series and concatenate once: Series.append was\n",
"    # removed in pandas 2.0, and appending in a loop is quadratic.\n",
"    per_user = {}\n",
"    for e in d:\n",
"        per_user.setdefault(e['user'], []).append(e['data']['Force'])\n",
"    flist = {user: pd.concat(parts) for user, parts in per_user.items()}\n",
"\n",
"    return [preproc_entry(e, flist) for e in tqdm(d)]\n",
"\n",
"\n",
"def preproc_entry(entry, flist):\n",
"    \"\"\"Apply the preprocessing steps to one recording, in order.\"\"\"\n",
"    drop_entry = mill_drop(entry)\n",
"    shorten_entry = cut_force(drop_entry)\n",
"    fnorm_entry = norm_force(shorten_entry, flist)\n",
"    # Chain from the force-normalised entry; passing shorten_entry only\n",
"    # worked because norm_force happens to modify it in place.\n",
"    time_entry = time_trans(fnorm_entry)\n",
"    return norm(time_entry)\n",
"\n",
"\n",
"print(\"Preprocessing...\")\n",
"pdata = preproc(data)"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 13,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "d3e56332",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-07-19 01:21:53 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Truncating...\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"def throw(pdata, quantile=None):\n",
"    \"\"\"Discard the longest recordings.\n",
"\n",
"    Keeps the entries whose sample count is <= the `quantile` length\n",
"    quantile of `pdata` (default: the global threshold_p, so the cut agrees\n",
"    with the `threshold` used for padding). Returns a numpy object array.\n",
"    \"\"\"\n",
"    if quantile is None:\n",
"        quantile = threshold_p\n",
"    lengths = pd.Series([len(x['data']) for x in pdata])\n",
"    cutoff = int(lengths.quantile(quantile))\n",
"    keep = np.flatnonzero(lengths.to_numpy() <= cutoff)\n",
"    return np.array(pdata)[keep]\n",
"\n",
"\n",
"# Recordings longer than the threshold_p length quantile are discarded; the\n",
"# survivors are later zero-padded to `threshold` samples.\n",
"threshold_p = 0.75\n",
"llist = pd.Series([len(x['data']) for x in pdata])\n",
"threshold = int(llist.quantile(threshold_p))\n",
"\n",
"print(\"Truncating...\")\n",
"tpdata = throw(pdata)"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 14,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "dabc3af0",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-07-19 01:21:53 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
" 19%|█▉ | 3723/19640 [00:00<00:00, 18633.70it/s]"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Padding...\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"100%|██████████| 19640/19640 [00:01<00:00, 18655.32it/s]\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2021-06-23 10:12:11 +02:00
|
|
|
"source": [
|
|
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
|
|
"# ltpdata = []\n",
|
|
|
|
"def elong(tpdata, maxlen=None):\n",
"    \"\"\"Zero-pad every recording at the end to a fixed number of samples.\n",
"\n",
"    Each entry's DataFrame is transposed to a (channels, samples) array and\n",
"    post-padded to `maxlen` samples (default: the global `threshold`).\n",
"    Returns a list of new entry dicts; the input entries are not modified,\n",
"    so the cell can be re-run.\n",
"    \"\"\"\n",
"    if maxlen is None:\n",
"        maxlen = threshold\n",
"    padded = []\n",
"    for x in tqdm(tpdata):\n",
"        y = x['data'].to_numpy().T\n",
"        padded.append({**x, 'data': pad_sequences(y, dtype=float, padding='post', maxlen=maxlen)})\n",
"    return padded\n",
"\n",
"\n",
"print(\"Padding...\")\n",
"ltpdata = elong(tpdata)"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 15,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "17fece5a",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import tensorflow as tf\n",
|
|
|
|
"from tensorflow.keras.models import Sequential\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout, Conv2D, MaxPooling2D\n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
"def build_model(input_shape=None, n_classes=None):\n",
"    \"\"\"Build and compile the fully-connected letter classifier.\n",
"\n",
"    input_shape: shape of one sample (default: the global train_shape).\n",
"    n_classes: width of the softmax output (default: the global classes).\n",
"    Layers: Flatten, BatchNorm, Dropout(0.1), then four ReLU Dense layers of\n",
"    width ncount//1..ncount//4 (ncount = product of the first two input\n",
"    dims), each followed by Dropout(0.1), then a softmax Dense layer.\n",
"    \"\"\"\n",
"    if input_shape is None:\n",
"        input_shape = train_shape\n",
"    if n_classes is None:\n",
"        n_classes = classes\n",
"\n",
"    ncount = input_shape[0] * input_shape[1]\n",
"\n",
"    model = Sequential()\n",
"    model.add(Flatten(input_shape=input_shape))\n",
"    model.add(BatchNormalization())\n",
"    model.add(Dropout(0.1))\n",
"    for i in range(1, 5):\n",
"        model.add(Dense(ncount // i, activation='relu'))\n",
"        model.add(Dropout(0.1))\n",
"    model.add(Dense(n_classes, activation='softmax'))\n",
"\n",
"    model.compile(\n",
"        optimizer=tf.keras.optimizers.Adam(0.001),\n",
"        loss=\"categorical_crossentropy\",\n",
"        metrics=[\"acc\"],\n",
"    )\n",
"    return model"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 24,
|
|
|
|
"id": "1ef39498",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2021-07-05 15:01:40 +02:00
|
|
|
"checkpoint_file = './goat.weights'\n",
"\n",
"\n",
"def train(X_train, y_train, X_test, y_test, epochs=30, batch_size=256):\n",
"    \"\"\"Fit a fresh model, restore its best-validation weights and evaluate.\n",
"\n",
"    The weights with the highest val_acc are checkpointed to checkpoint_file\n",
"    and loaded back after fitting, so the returned model is the best epoch\n",
"    rather than the last one.\n",
"    NOTE: (X_test, y_test) also serves as the validation set that picks the\n",
"    checkpoint, so the reported test accuracy is optimistic.\n",
"\n",
"    Returns (model, history).\n",
"    \"\"\"\n",
"    model = build_model()\n",
"    model.summary()\n",
"\n",
"    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
"        filepath=checkpoint_file,\n",
"        save_weights_only=True,\n",
"        monitor='val_acc',\n",
"        mode='max',\n",
"        save_best_only=True,\n",
"    )\n",
"\n",
"    history = model.fit(X_train, y_train,\n",
"                        epochs=epochs,\n",
"                        batch_size=batch_size,\n",
"                        shuffle=True,\n",
"                        validation_data=(X_test, y_test),\n",
"                        verbose=1,\n",
"                        callbacks=[model_checkpoint_callback])\n",
"\n",
"    # The callback only saves the best weights; load them so they are used.\n",
"    model.load_weights(checkpoint_file)\n",
"\n",
"    print(\"Evaluate on test data\")\n",
"    results = model.evaluate(X_test, y_test, batch_size=128, verbose=0)\n",
"    print(\"test loss, test acc:\", results)\n",
"    return model, history"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 25,
|
|
|
|
"id": "160ec98a",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# NOTE(review): these are presumably read when TensorFlow initialises the GPU,\n",
"# so this cell should run before any TF op executes - confirm.\n",
"# TF_FORCE_GPU_ALLOW_GROWTH is required; pick the GPU index ('0', '1' or '2')\n",
"# after checking \"gpustat\" in a terminal.\n",
"os.environ.update({\n",
"    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',\n",
"    'CUDA_VISIBLE_DEVICES': '0',\n",
"})"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 26,
|
|
|
|
"id": "a4799ab9",
|
2021-06-23 10:12:11 +02:00
|
|
|
"metadata": {},
|
2021-07-05 15:01:40 +02:00
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
|
|
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
|
|
|
|
"\n",
|
|
|
|
"# Stack the padded (channels, samples) arrays and one-hot encode the letters.\n",
"X = np.array([entry['data'] for entry in ltpdata])\n",
"y = np.array([entry['label'] for entry in ltpdata])\n",
"\n",
"lb = LabelBinarizer()\n",
"y_tran = lb.fit_transform(y)\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y_tran, test_size=0.2, random_state=177013\n",
")\n",
"\n",
"# Append a trailing unit axis: (n, channels, samples) -> (n, channels, samples, 1).\n",
"X_train = X_train[..., np.newaxis]\n",
"X_test = X_test[..., np.newaxis]\n",
"\n",
"train_shape = X_train.shape[1:]\n",
"classes = y_train.shape[1]"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 27,
|
|
|
|
"id": "e73dcbbb",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {
|
|
|
|
"tags": []
|
|
|
|
},
|
2021-06-23 10:12:11 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Training...\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"Model: \"sequential_1\"\n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Layer (type) Output Shape Param # \n",
|
|
|
|
"=================================================================\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"flatten_1 (Flatten) (None, 1050) 0 \n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"batch_normalization_1 (Batch (None, 1050) 4200 \n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dropout_5 (Dropout) (None, 1050) 0 \n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dense_5 (Dense) (None, 1050) 1103550 \n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dropout_6 (Dropout) (None, 1050) 0 \n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dense_6 (Dense) (None, 525) 551775 \n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dropout_7 (Dropout) (None, 525) 0 \n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dense_7 (Dense) (None, 350) 184100 \n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dropout_8 (Dropout) (None, 350) 0 \n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dense_8 (Dense) (None, 262) 91962 \n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dropout_9 (Dropout) (None, 262) 0 \n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"_________________________________________________________________\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"dense_9 (Dense) (None, 52) 13676 \n",
|
2021-06-23 10:12:11 +02:00
|
|
|
"=================================================================\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"Total params: 1,949,263\n",
|
|
|
|
"Trainable params: 1,947,163\n",
|
|
|
|
"Non-trainable params: 2,100\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Epoch 1/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 1s 6ms/step - loss: 3.3481 - acc: 0.1160 - val_loss: 3.5396 - val_acc: 0.0687\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 2/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 2.4941 - acc: 0.2746 - val_loss: 3.1564 - val_acc: 0.1263\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 3/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 5ms/step - loss: 1.9611 - acc: 0.3980 - val_loss: 3.0374 - val_acc: 0.1533\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 4/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 1.6416 - acc: 0.4826 - val_loss: 2.7437 - val_acc: 0.2085\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 5/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 1.4033 - acc: 0.5439 - val_loss: 2.4287 - val_acc: 0.2632\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 6/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 1.2683 - acc: 0.5797 - val_loss: 2.1105 - val_acc: 0.3564\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 7/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 5ms/step - loss: 1.1270 - acc: 0.6207 - val_loss: 1.8558 - val_acc: 0.4155\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 8/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 5ms/step - loss: 1.0280 - acc: 0.6520 - val_loss: 1.6051 - val_acc: 0.4776\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 9/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 5ms/step - loss: 0.9315 - acc: 0.6812 - val_loss: 1.3901 - val_acc: 0.5489\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 10/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.8726 - acc: 0.6988 - val_loss: 1.2578 - val_acc: 0.5939\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 11/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.7879 - acc: 0.7230 - val_loss: 1.1692 - val_acc: 0.6191\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 12/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.7392 - acc: 0.7379 - val_loss: 1.1623 - val_acc: 0.6283\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 13/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.6912 - acc: 0.7543 - val_loss: 1.1486 - val_acc: 0.6359\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 14/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.6471 - acc: 0.7709 - val_loss: 1.1279 - val_acc: 0.6586\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 15/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 5ms/step - loss: 0.5918 - acc: 0.7853 - val_loss: 1.1477 - val_acc: 0.6469\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 16/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.5488 - acc: 0.8007 - val_loss: 1.2157 - val_acc: 0.6477\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 17/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.5421 - acc: 0.8056 - val_loss: 1.1407 - val_acc: 0.6647\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 18/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.5035 - acc: 0.8180 - val_loss: 1.1731 - val_acc: 0.6617\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 19/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.4780 - acc: 0.8278 - val_loss: 1.2031 - val_acc: 0.6550\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 20/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.4620 - acc: 0.8346 - val_loss: 1.1839 - val_acc: 0.6642\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 21/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.4153 - acc: 0.8489 - val_loss: 1.2167 - val_acc: 0.6606\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 22/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.4120 - acc: 0.8494 - val_loss: 1.1883 - val_acc: 0.6678\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 23/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.3817 - acc: 0.8624 - val_loss: 1.2221 - val_acc: 0.6673\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 24/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.3635 - acc: 0.8696 - val_loss: 1.2405 - val_acc: 0.6843\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 25/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.3626 - acc: 0.8721 - val_loss: 1.2756 - val_acc: 0.6634\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 26/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.3432 - acc: 0.8789 - val_loss: 1.2590 - val_acc: 0.6708\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 27/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 5ms/step - loss: 0.3165 - acc: 0.8909 - val_loss: 1.3211 - val_acc: 0.6662\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 28/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.2937 - acc: 0.8960 - val_loss: 1.3015 - val_acc: 0.6746\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 29/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.3091 - acc: 0.8910 - val_loss: 1.3578 - val_acc: 0.6637\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Epoch 30/30\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"62/62 [==============================] - 0s 4ms/step - loss: 0.3003 - acc: 0.8931 - val_loss: 1.3836 - val_acc: 0.6673\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"Evaluate on test data\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"test loss, test acc: [1.3836346864700317, 0.6675152778625488]\n"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
2021-07-05 15:01:40 +02:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"print(\"Training...\")\n",
"model, history = train(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 104,
|
|
|
|
"id": "ce37826d",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
2021-06-23 10:12:11 +02:00
|
|
|
{
|
2021-07-19 01:21:53 +02:00
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"(52,)\n"
|
|
|
|
]
|
|
|
|
},
|
2021-07-05 15:01:40 +02:00
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"array(['Q'], dtype='<U1')"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
2021-06-23 10:12:11 +02:00
|
|
|
},
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 104,
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"def predict(model, entry):\n",
"    \"\"\"Return the one-hot vector of the class `model` predicts for entry[0].\n",
"\n",
"    `entry` is a batch (leading axis = samples); only the first sample's\n",
"    prediction is encoded. The vector length follows the model's output\n",
"    width instead of a hard-coded 52.\n",
"    \"\"\"\n",
"    probs = model.predict(entry)  # one forward pass\n",
"    one_hot = np.zeros(probs.shape[-1], dtype=int)\n",
"    one_hot[np.argmax(probs[0])] = 1\n",
"    return one_hot\n",
|
2021-07-05 15:01:40 +02:00
|
|
|
"\n",
|
2021-07-19 01:21:53 +02:00
|
|
|
"# Predict the first held-out sample. The previous call used a global `x`\n",
"# that no cell of this notebook assigns, so it failed on a fresh kernel.\n",
"p = predict(model, X_test[:1])\n",
"# inverse_transform expects a 2-D (n_samples, n_classes) indicator matrix.\n",
"lb.inverse_transform(p.reshape(1, -1))"
|
2021-06-23 10:12:11 +02:00
|
|
|
]
|
2021-07-05 15:01:40 +02:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 103,
|
|
|
|
"id": "ea020844",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"'Q'"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
|
|
|
},
|
2021-07-19 01:21:53 +02:00
|
|
|
"execution_count": 103,
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2021-07-19 01:21:53 +02:00
|
|
|
"# True label of the first held-out sample.\n",
"lb.inverse_transform(y_test[:1])[0]"
|
2021-07-05 15:01:40 +02:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": null,
|
2021-07-19 01:21:53 +02:00
|
|
|
"id": "be9f7690",
|
2021-07-05 15:01:40 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": []
|
2021-06-23 10:12:11 +02:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
2021-07-14 10:15:52 +02:00
|
|
|
"version": "3.8.10"
|
2021-06-23 10:12:11 +02:00
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 5
|
|
|
|
}
|