1322 lines
289 KiB
Plaintext
1322 lines
289 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"id": "617fbc6e",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"glob_path = '/opt/iui-datarelease2-sose2021/*/split_letters_csv/*'\n",
|
|||
|
"\n",
|
|||
|
"pickle_file = 'data.pickle'"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"id": "4ad43419",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from glob import glob\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from tqdm import tqdm\n",
|
|||
|
"\n",
|
|||
|
"def dl_from_blob(filename) -> list:\n",
|
|||
|
" all_data = []\n",
|
|||
|
" \n",
|
|||
|
" for path in tqdm(glob(filename)):\n",
|
|||
|
" path = path\n",
|
|||
|
" df = pd.read_csv(path, ';')\n",
|
|||
|
" u = path.split('/')[3]\n",
|
|||
|
" l = ''.join(filter(lambda x: x.isalpha(), path.split('/')[5]))[0] \n",
|
|||
|
" d = {\n",
|
|||
|
" 'file': path,\n",
|
|||
|
" 'data': df,\n",
|
|||
|
" 'user': u,\n",
|
|||
|
" 'label': l\n",
|
|||
|
" }\n",
|
|||
|
" all_data.append(d)\n",
|
|||
|
" return all_data"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"id": "a30c00f3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def save_pickle(f, structure):\n",
|
|||
|
" _p = open(f, 'wb')\n",
|
|||
|
" pickle.dump(structure, _p)\n",
|
|||
|
" _p.close()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"id": "7b942135",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import pickle\n",
|
|||
|
"\n",
|
|||
|
"def load_pickles(f) -> list:\n",
|
|||
|
" _p = open(pickle_file, 'rb')\n",
|
|||
|
" _d = pickle.load(_p)\n",
|
|||
|
" _p.close()\n",
|
|||
|
" \n",
|
|||
|
" return _d"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"id": "365d7f88",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Loading data...\n",
|
|||
|
"data.pickle found...\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"def load_data() -> list:\n",
|
|||
|
" if os.path.isfile(pickle_file):\n",
|
|||
|
" print(f'{pickle_file} found...')\n",
|
|||
|
" return load_pickles(pickle_file)\n",
|
|||
|
" print(f'Didn\\'t find {pickle_file}...')\n",
|
|||
|
" all_data = dl_from_blob(glob_path)\n",
|
|||
|
" print(f'Creating {pickle_file}...')\n",
|
|||
|
" save_pickle(pickle_file, all_data)\n",
|
|||
|
" return all_data\n",
|
|||
|
"\n",
|
|||
|
"print(\"Loading data...\")\n",
|
|||
|
"data = load_data()\n",
|
|||
|
"# plot_pd(data[0]['data'], False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"id": "986ef149",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"def plot_pd(data, force=True):\n",
|
|||
|
" fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
|
|||
|
" axs[0][0].plot(data['Acc1 X'])\n",
|
|||
|
" axs[0][1].plot(data['Acc1 Y'])\n",
|
|||
|
" axs[0][2].plot(data['Acc1 Z'])\n",
|
|||
|
" axs[1][0].plot(data['Acc2 X'])\n",
|
|||
|
" axs[1][1].plot(data['Acc2 Y'])\n",
|
|||
|
" axs[1][2].plot(data['Acc2 Z'])\n",
|
|||
|
" axs[2][0].plot(data['Gyro X'])\n",
|
|||
|
" axs[2][1].plot(data['Gyro Y'])\n",
|
|||
|
" axs[2][2].plot(data['Gyro Z'])\n",
|
|||
|
" axs[3][0].plot(data['Mag X'])\n",
|
|||
|
" axs[3][1].plot(data['Mag Y'])\n",
|
|||
|
" axs[3][2].plot(data['Mag Z'])\n",
|
|||
|
" axs[4][0].plot(data['Time'])\n",
|
|||
|
"\n",
|
|||
|
" if force:\n",
|
|||
|
" for a in axs:\n",
|
|||
|
" for b in a:\n",
|
|||
|
" b.plot(data['Force'])\n",
|
|||
|
" else:\n",
|
|||
|
" axs[4][1].plot(data['Force'])\n",
|
|||
|
"\n",
|
|||
|
"def plot_np(data, force=True):\n",
|
|||
|
" fig, axs = plt.subplots(5, 3, figsize=(3*3, 3*5))\n",
|
|||
|
" axs[0][0].plot(data[0])\n",
|
|||
|
" axs[0][1].plot(data[1])\n",
|
|||
|
" axs[0][2].plot(data[2])\n",
|
|||
|
" axs[1][0].plot(data[3])\n",
|
|||
|
" axs[1][1].plot(data[4])\n",
|
|||
|
" axs[1][2].plot(data[5])\n",
|
|||
|
" axs[2][0].plot(data[6])\n",
|
|||
|
" axs[2][1].plot(data[7])\n",
|
|||
|
" axs[2][2].plot(data[8])\n",
|
|||
|
" axs[3][0].plot(data[9])\n",
|
|||
|
" axs[3][1].plot(data[10])\n",
|
|||
|
" axs[3][2].plot(data[11])\n",
|
|||
|
" axs[4][0].plot(data[13])\n",
|
|||
|
"\n",
|
|||
|
" if force:\n",
|
|||
|
" for a in axs:\n",
|
|||
|
" for b in a:\n",
|
|||
|
" b.plot(data[12])\n",
|
|||
|
" else:\n",
|
|||
|
" axs[4][1].plot(data[12])\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"id": "56812df6",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def mill_drop(entry):\n",
|
|||
|
" #drop millis on single\n",
|
|||
|
" data_wo_mill = entry['data'].drop(labels='Millis', axis=1, inplace=False)\n",
|
|||
|
" drop_entry = entry\n",
|
|||
|
" drop_entry['data'] = data_wo_mill.reset_index(drop=True)\n",
|
|||
|
" \n",
|
|||
|
" return drop_entry"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"id": "9a3298c0",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"def cut_force(drop_entry):\n",
|
|||
|
" # force trans\n",
|
|||
|
" shorten_entry = drop_entry\n",
|
|||
|
" shorten_data = shorten_entry['data']\n",
|
|||
|
" sf_entry = shorten_data['Force']\n",
|
|||
|
" leeway = 10\n",
|
|||
|
" \n",
|
|||
|
" try:\n",
|
|||
|
" thresh = 70\n",
|
|||
|
" temps_over_T = np.where(sf_entry > thresh)[0]\n",
|
|||
|
" shorten_data = shorten_data[max(temps_over_T.min()-leeway,0):min(len(sf_entry)-1,temps_over_T.max()+leeway)]\n",
|
|||
|
" except:\n",
|
|||
|
" thresold = 0.05\n",
|
|||
|
" thresh = sf_entry.max()*thresold\n",
|
|||
|
" temps_over_T = np.where(sf_entry > thresh)[0]\n",
|
|||
|
" shorten_data = shorten_data[max(temps_over_T.min()-leeway,0):min(len(sf_entry)-1,temps_over_T.max()+leeway)]\n",
|
|||
|
" \n",
|
|||
|
" shorten_entry['data'] = shorten_data.reset_index(drop=True)\n",
|
|||
|
" return shorten_entry"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"id": "6ad1c481",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def norm_force(shorten_entry, flist):\n",
|
|||
|
" fnorm_entry = shorten_entry\n",
|
|||
|
" u = fnorm_entry['user']\n",
|
|||
|
" d = fnorm_entry['data']\n",
|
|||
|
" \n",
|
|||
|
" \n",
|
|||
|
" d['Force'] = ((d['Force'] - flist[u].mean())/flist[u].std())\n",
|
|||
|
" \n",
|
|||
|
" fnorm_entry['data'] = fnorm_entry['data'].reset_index(drop=True)\n",
|
|||
|
" return fnorm_entry"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"id": "8a53c07c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def time_trans(fnorm_entry):\n",
|
|||
|
" #timetrans\n",
|
|||
|
" time_entry = fnorm_entry\n",
|
|||
|
" \n",
|
|||
|
" time_entry['data']['Time'] = fnorm_entry['data']['Time']-fnorm_entry['data']['Time'][0]\n",
|
|||
|
" \n",
|
|||
|
" time_entry['data'] = time_entry['data'].reset_index(drop=True)\n",
|
|||
|
"\n",
|
|||
|
" return time_entry"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"id": "571664a9",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def norm(time_entry):\n",
|
|||
|
" # normalize\n",
|
|||
|
" norm_entry = time_entry\n",
|
|||
|
" \n",
|
|||
|
" norm_entry['data']['Acc1 X'] = norm_entry['data']['Acc1 X'] / 32768\n",
|
|||
|
" norm_entry['data']['Acc1 Y'] = norm_entry['data']['Acc1 Y'] / 32768\n",
|
|||
|
" norm_entry['data']['Acc1 Z'] = norm_entry['data']['Acc1 Z'] / 32768\n",
|
|||
|
" norm_entry['data']['Acc2 X'] = norm_entry['data']['Acc2 X'] / 8192\n",
|
|||
|
" norm_entry['data']['Acc2 Y'] = norm_entry['data']['Acc2 Y'] / 8192\n",
|
|||
|
" norm_entry['data']['Acc2 Z'] = norm_entry['data']['Acc2 Z'] / 8192\n",
|
|||
|
" norm_entry['data']['Gyro X'] = norm_entry['data']['Gyro X'] / 32768\n",
|
|||
|
" norm_entry['data']['Gyro Y'] = norm_entry['data']['Gyro Y'] / 32768\n",
|
|||
|
" norm_entry['data']['Gyro Z'] = norm_entry['data']['Gyro Z'] / 32768\n",
|
|||
|
" norm_entry['data']['Mag X'] = norm_entry['data']['Mag X'] / 8192\n",
|
|||
|
" norm_entry['data']['Mag Y'] = norm_entry['data']['Mag Y'] / 8192\n",
|
|||
|
" norm_entry['data']['Mag Z'] = norm_entry['data']['Mag Z'] / 8192\n",
|
|||
|
"# norm_entry['data']['Mag Z'] = norm_entry['data']['Mag Z'] / 4096\n",
|
|||
|
" \n",
|
|||
|
" norm_entry['data'] = norm_entry['data'].reset_index(drop=True)\n",
|
|||
|
" \n",
|
|||
|
" return norm_entry"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"id": "b70412b3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def preproc(d):\n",
|
|||
|
" flist = {} \n",
|
|||
|
" d_res = []\n",
|
|||
|
" for e in data:\n",
|
|||
|
" if e['user'] not in flist:\n",
|
|||
|
" flist[e['user']] = e['data']['Force']\n",
|
|||
|
" else:\n",
|
|||
|
" flist[e['user']] = flist[e['user']].append(e['data']['Force'])\n",
|
|||
|
" \n",
|
|||
|
" for e in tqdm(data):\n",
|
|||
|
" d_res.append(preproc_entry(e, flist))\n",
|
|||
|
" return d_res\n",
|
|||
|
" \n",
|
|||
|
"def preproc_entry(entry, flist):\n",
|
|||
|
" drop_entry = mill_drop(entry)\n",
|
|||
|
"# plot_pd(drop_entry['data'])\n",
|
|||
|
"# \n",
|
|||
|
" shorten_entry = cut_force(drop_entry)\n",
|
|||
|
"# plot_pd(shorten_entry['data'])\n",
|
|||
|
"# \n",
|
|||
|
" fnorm_entry = norm_force(shorten_entry, flist)\n",
|
|||
|
"# plot_pd(fnorm_entry['data'])\n",
|
|||
|
"# \n",
|
|||
|
" time_entry = time_trans(shorten_entry)\n",
|
|||
|
"# plot_pd(time_entry['data'])\n",
|
|||
|
"# \n",
|
|||
|
" norm_entry = norm(time_entry)\n",
|
|||
|
"# plot_pd(norm_entry['data'], False)\n",
|
|||
|
" return norm_entry\n",
|
|||
|
"\n",
|
|||
|
"print(\"Preprocessing...\")\n",
|
|||
|
"pdata = preproc(data)\n",
|
|||
|
"# plot_pd(pdata[0]['data'], False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"id": "32b81b5b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def throw(pdata):\n",
|
|||
|
" llist = pd.Series([len(x['data']) for x in pdata])\n",
|
|||
|
" threshold = int(llist.quantile(threshold_p))\n",
|
|||
|
" longdex = np.where(llist <= threshold)[0]\n",
|
|||
|
" return np.array(pdata)[longdex]\n",
|
|||
|
"\n",
|
|||
|
"llist = pd.Series([len(x['data']) for x in pdata])\n",
|
|||
|
"threshold_p = 0.75\n",
|
|||
|
"threshold = int(llist.quantile(threshold_p))\n",
|
|||
|
"\n",
|
|||
|
"print(\"Truncating...\")\n",
|
|||
|
"tpdata = throw(pdata)\n",
|
|||
|
"# plot_pd(tpdata[0]['data'], False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"id": "dda26384",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|||
|
"# ltpdata = []\n",
|
|||
|
"def elong(tpdata):\n",
|
|||
|
" for x in tqdm(tpdata):\n",
|
|||
|
" y = x['data'].to_numpy().T\n",
|
|||
|
" x['data'] = pad_sequences(y, dtype=float, padding='post', maxlen=threshold)\n",
|
|||
|
" return tpdata\n",
|
|||
|
"\n",
|
|||
|
"print(\"Padding...\")\n",
|
|||
|
"ltpdata = elong(tpdata)\n",
|
|||
|
"# plot_np(ltpdata[0]['data'], False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"id": "6d717d0d",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import tensorflow as tf\n",
|
|||
|
"from tensorflow.keras.models import Sequential\n",
|
|||
|
"from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout, Conv2D, MaxPooling2D\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"def build_model():\n",
|
|||
|
" model = Sequential()\n",
|
|||
|
" ncount = train_shape[0]*train_shape[1]\n",
|
|||
|
" \n",
|
|||
|
"# model.add(Conv2D(64, (5, 5), input_shape=train_shape, activation='relu', padding='same'))\n",
|
|||
|
"# model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
|||
|
" \n",
|
|||
|
"# model.add(Conv2D(64, (5, 5), activation='relu', padding='same'))\n",
|
|||
|
"# model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
|||
|
" \n",
|
|||
|
" model.add(Flatten(input_shape=train_shape))\n",
|
|||
|
" \n",
|
|||
|
" model.add(BatchNormalization())\n",
|
|||
|
" \n",
|
|||
|
" model.add(Dropout(0.1))\n",
|
|||
|
" \n",
|
|||
|
" for i in range(2,5):\n",
|
|||
|
" model.add(Dense(int(ncount/i), activation='relu'))\n",
|
|||
|
" model.add(Dropout(0.1))\n",
|
|||
|
" \n",
|
|||
|
" model.add(Dense(classes, activation='softmax'))\n",
|
|||
|
"\n",
|
|||
|
" model.compile(\n",
|
|||
|
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
|
|||
|
" loss=\"categorical_crossentropy\", \n",
|
|||
|
" metrics=[\"acc\"],\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" return model"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"id": "815c345d",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"checkpoint_file = './goat.weights'\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"def train(X_train, y_train, X_test, y_test):\n",
|
|||
|
" model = build_model()\n",
|
|||
|
" \n",
|
|||
|
" model.summary()\n",
|
|||
|
" \n",
|
|||
|
" model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
|
|||
|
" filepath = checkpoint_file,\n",
|
|||
|
" save_weights_only=True,\n",
|
|||
|
" monitor='val_acc',\n",
|
|||
|
" mode='max',\n",
|
|||
|
" save_best_only=True\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" model.fit(X_train, y_train, \n",
|
|||
|
" epochs=30,\n",
|
|||
|
" batch_size=256,\n",
|
|||
|
" shuffle=True,\n",
|
|||
|
" validation_data=(X_test, y_test),\n",
|
|||
|
" verbose=1,\n",
|
|||
|
" callbacks=[model_checkpoint_callback]\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" print(\"Evaluate on test data\")\n",
|
|||
|
" results = model.evaluate(X_test, y_test, batch_size=128, verbose=0)\n",
|
|||
|
" print(\"test loss, test acc:\", results)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"id": "8b1d12b1",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # this is required\n",
|
|||
|
"os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set to '0' for GPU0, '1' for GPU1 or '2' for GPU2. Check \"gpustat\" in a terminal."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"id": "8ba0b7cb",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.preprocessing import LabelEncoder, LabelBinarizer\n",
|
|||
|
"\n",
|
|||
|
"X = np.array([x['data'] for x in ltpdata])\n",
|
|||
|
"y = np.array([x['label'] for x in ltpdata])\n",
|
|||
|
"\n",
|
|||
|
"lb = LabelBinarizer()\n",
|
|||
|
"y_tran = lb.fit_transform(y)\n",
|
|||
|
"\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(X, y_tran, test_size=0.2, random_state=177013)\n",
|
|||
|
"\n",
|
|||
|
"X_train=X_train.reshape(X_train.shape[0],X_train.shape[1],X_train.shape[2],1)\n",
|
|||
|
"X_test=X_test.reshape(X_test.shape[0],X_test.shape[1],X_test.shape[2],1)\n",
|
|||
|
"\n",
|
|||
|
"train_shape = X_train[0].shape\n",
|
|||
|
"classes = y_train[0].shape[0]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 24,
|
|||
|
"id": "e0d634da",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Training...\n",
|
|||
|
"Model: \"sequential\"\n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"Layer (type) Output Shape Param # \n",
|
|||
|
"=================================================================\n",
|
|||
|
"conv2d (Conv2D) (None, 14, 75, 64) 1664 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"max_pooling2d (MaxPooling2D) (None, 7, 37, 64) 0 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"conv2d_1 (Conv2D) (None, 7, 37, 64) 102464 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"max_pooling2d_1 (MaxPooling2 (None, 3, 18, 64) 0 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"flatten (Flatten) (None, 3456) 0 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"batch_normalization (BatchNo (None, 3456) 13824 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dropout (Dropout) (None, 3456) 0 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dense (Dense) (None, 525) 1814925 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dropout_1 (Dropout) (None, 525) 0 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dense_1 (Dense) (None, 350) 184100 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dropout_2 (Dropout) (None, 350) 0 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dense_2 (Dense) (None, 262) 91962 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dropout_3 (Dropout) (None, 262) 0 \n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"dense_3 (Dense) (None, 52) 13676 \n",
|
|||
|
"=================================================================\n",
|
|||
|
"Total params: 2,222,615\n",
|
|||
|
"Trainable params: 2,215,703\n",
|
|||
|
"Non-trainable params: 6,912\n",
|
|||
|
"_________________________________________________________________\n",
|
|||
|
"Epoch 1/30\n",
|
|||
|
"62/62 [==============================] - 4s 14ms/step - loss: 3.4102 - acc: 0.1025 - val_loss: 3.8896 - val_acc: 0.0438\n",
|
|||
|
"Epoch 2/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 2.7475 - acc: 0.2131 - val_loss: 3.7641 - val_acc: 0.0461\n",
|
|||
|
"Epoch 3/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 2.2471 - acc: 0.3301 - val_loss: 3.5046 - val_acc: 0.1347\n",
|
|||
|
"Epoch 4/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 1.8689 - acc: 0.4238 - val_loss: 3.2023 - val_acc: 0.2352\n",
|
|||
|
"Epoch 5/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 1.6068 - acc: 0.4923 - val_loss: 3.8644 - val_acc: 0.0558\n",
|
|||
|
"Epoch 6/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 1.3984 - acc: 0.5521 - val_loss: 2.1733 - val_acc: 0.4010\n",
|
|||
|
"Epoch 7/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 1.2403 - acc: 0.5896 - val_loss: 1.9064 - val_acc: 0.4376\n",
|
|||
|
"Epoch 8/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 1.1112 - acc: 0.6230 - val_loss: 1.8146 - val_acc: 0.4743\n",
|
|||
|
"Epoch 9/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 1.0028 - acc: 0.6547 - val_loss: 2.0000 - val_acc: 0.4236\n",
|
|||
|
"Epoch 10/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.9017 - acc: 0.6884 - val_loss: 2.1602 - val_acc: 0.5038\n",
|
|||
|
"Epoch 11/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.8505 - acc: 0.7011 - val_loss: 2.1521 - val_acc: 0.5624\n",
|
|||
|
"Epoch 12/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.7718 - acc: 0.7303 - val_loss: 2.2699 - val_acc: 0.5736\n",
|
|||
|
"Epoch 13/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.7100 - acc: 0.7485 - val_loss: 1.8627 - val_acc: 0.5550\n",
|
|||
|
"Epoch 14/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.6737 - acc: 0.7576 - val_loss: 1.9876 - val_acc: 0.5636\n",
|
|||
|
"Epoch 15/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.6239 - acc: 0.7723 - val_loss: 2.1203 - val_acc: 0.5540\n",
|
|||
|
"Epoch 16/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.5808 - acc: 0.7914 - val_loss: 8.2953 - val_acc: 0.3977\n",
|
|||
|
"Epoch 17/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.5578 - acc: 0.7981 - val_loss: 2.3404 - val_acc: 0.5339\n",
|
|||
|
"Epoch 18/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.5090 - acc: 0.8137 - val_loss: 1.7944 - val_acc: 0.5937\n",
|
|||
|
"Epoch 19/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.4675 - acc: 0.8289 - val_loss: 2.0554 - val_acc: 0.5866\n",
|
|||
|
"Epoch 20/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.4484 - acc: 0.8343 - val_loss: 1.8284 - val_acc: 0.5832\n",
|
|||
|
"Epoch 21/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.4210 - acc: 0.8490 - val_loss: 2.1521 - val_acc: 0.6219\n",
|
|||
|
"Epoch 22/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.4036 - acc: 0.8523 - val_loss: 1.9749 - val_acc: 0.6477\n",
|
|||
|
"Epoch 23/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.3790 - acc: 0.8581 - val_loss: 2.8492 - val_acc: 0.6143\n",
|
|||
|
"Epoch 24/30\n",
|
|||
|
"62/62 [==============================] - 1s 10ms/step - loss: 0.3494 - acc: 0.8700 - val_loss: 10.8548 - val_acc: 0.5950\n",
|
|||
|
"Epoch 25/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.3566 - acc: 0.8714 - val_loss: 1.9813 - val_acc: 0.6278\n",
|
|||
|
"Epoch 26/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.3400 - acc: 0.8782 - val_loss: 4.9607 - val_acc: 0.3447\n",
|
|||
|
"Epoch 27/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.3165 - acc: 0.8824 - val_loss: 2.1550 - val_acc: 0.6049\n",
|
|||
|
"Epoch 28/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.3147 - acc: 0.8845 - val_loss: 3.1088 - val_acc: 0.4463\n",
|
|||
|
"Epoch 29/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.2944 - acc: 0.8921 - val_loss: 3.7178 - val_acc: 0.5479\n",
|
|||
|
"Epoch 30/30\n",
|
|||
|
"62/62 [==============================] - 1s 9ms/step - loss: 0.2819 - acc: 0.8980 - val_loss: 3.5398 - val_acc: 0.5876\n",
|
|||
|
"Evaluate on test data\n",
|
|||
|
"test loss, test acc: [3.5338056087493896, 0.5878309607505798]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"print(\"Training...\")\n",
|
|||
|
"train(X_train, y_train, X_test, y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 26,
|
|||
|
"id": "9ad292f7",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"((14, 75, 1), 52)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 26,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"train_shape, classes"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 27,
|
|||
|
"id": "6681106e",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(14, 75, 1)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 27,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"X_train[0].shape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"id": "ea5ff2c7",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"d = data[0]['data'].reshape(1,data[0]['data'].shape[0], data[0]['data'].shape[1], X_train[0].shape[2])\n",
|
|||
|
"pd.DataFrame(d.reshape(d.shape[1], d.shape[2]).T)\n",
|
|||
|
"dd = d.reshape(d.shape[1], d.shape[2])\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"id": "9c346a61",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAANOCAYAAADph/0uAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9d3hkZ3n3/7mna0a9rFbbu+11t+V1w5jgNZhm02xsCJgAr0MSUl7SSMhL8sLLLwYSIAUIpppeHIqJDcY2Nhjjsuvu7evtu2qrLo2mP78/TpkzoxlpJE3RjJ7Pde21U45mHo3OnHOf7/2971uUUmg0Go1Go9FUO65KL0Cj0Wg0Go2mGOigRqPRaDQaTU2ggxqNRqPRaDQ1gQ5qNBqNRqPR1AQ6qNFoNBqNRlMTeCq9gPnQ3t6u1q1bV+llaGqAp5566rRSqqPc76v3YU0xqcR+rPdhTTEp1j5clUHNunXr2LlzZ6WXoakBRORoJd5X78OaYlKJ/Vjvw5piUqx9WKefNBqNRqPR1AQ6qNFoNBqNRlMT6KBGo5kFEblORPaJyEER+VCO59eIyEMi8oyIPC8ir63EOjWafBSwD/tF5Pvm80+IyLoKLFOjWTA6qNFoZkBE3MDngNcAW4FbRGRr1mb/APxAKXUhcDPw+fKuUqPJT4H78HuBYaXUJuAzwCfKu0qNpjhUpVFYMz9u//leROBvrzuz0kupJrYBB5VShwBE5HvADcBuxzYKaDRvNwGnyrrCGuU9X9/BG87v4k0Xrqr0UqqdQvbhG4B/Mm/fBfyniIjSwwGrmmRK8Zp/+w3HhsKVXgoADQEvOz68vaTvoYOaJcSThwdxu6TSy6g2VgLHHfdPAJdmbfNPwC9F5E+BEJDzWysitwG3AaxZs6boC601fnvgNKtb6nRQs3AK2YftbZRSCREZBdqA086N9D5cXUxEE+zvm+DyDW2ct6qp0svB7yl9ckgHNUuIeFKRTOkLrxJwC/B1pdS/isjlwDdF5BylVMq5kVLqDuAOgO7ubv2HmIFUShFLpojEU7NvrCkbeh+uLmIJ4/vz2nOX887L11V2MWVCe2qWEPFkilhSH4fmyElgteP+KvMxJ+8FfgCglHoMCADtZVldjRJLGgfjSCJZ4ZXUBIXsw/Y2IuLBSKMOlmV1mpJhfY+87qVzql86v6mGWDJFPKmvfOfIDmCziKwXER+GEfjurG2OAdcAiMhZGEHNQFlXWWNEzSvMSFwHNUWgkH34buBW8/ZbgV9pP031Eze/R74ypH0WCzr9tISIJ1MI2lMzF0x/wQeA+wA38FWl1C4R+SiwUyl1N/CXwJdE5H9jmIbfrU8IC8OSza3gRjN/CtyHv4KRNj0IDGEEPpoqx1JqdFCjqUniCYWIPtfOFaXUvcC9WY99xHF7N3BluddVy0TNtJNWaopDAftwBLix3OvSlBbr4sC3hNJPOqhZQsSTKUQLNZoqIGann7RSo9HMl6hOP2lqmZj202iqBO2p0WgWTkwHNZpaRpuENdWCdTCOaU+NRjNvrGO+Tj9papK4LufWVAlaqdFoFs5SVGqWzm+6xEmmVMY/jWYxY3tqtFKj0cybpVj9VNLftIDJsB8Ukd3mZOMHRWRtKdezlHGmnnQaSrPY0dVPGs3CWYrVTyX7TQucDPsM0K2UOg9jiNonS7WepY4zkNG9PzSLnZgj/aRb/mg080Onn4qLPRlWKRUDrMmwNkqph5RS1vjQxzHad2tKgNNPo5UazWLHCrxTChI6XarRzIvoEjQKl/I3zTUZduUM278X+Hm+J0XkNhHZKSI7BwZ0B/q5otNPmmrCWfWkU1AazfxYimMSFsVvKiK/D3QDn8q3jVLqDqVUt1Kqu6Ojo3yLqxGcJwldJqtZ7EQdgyx1Az6NZn4sRaNwKUu6C5kMi4hsBz4MXK2UipZwPUsardRoqomoVmo0mgWjjcLFZdbJsCJyIfBF4HqlVH8J17LkcXpqYgntUdAsbpxBjTa2azTzI5ZI4RLw6KBm4SilEoA1GXYP8ANrMqyIXG9u9imgHvihiDwrInfneTnNAnGqM3pcgmaxoz01Gs3CiSVTeJdQQAMl7ihcwGTY7aV8f02amE4/aaqITKVGBzUazXyIJVJLyk8Di8QorCk9ccdJIq7lfM0iJ1Op0furRjMfYskUfh3UaGoRp6cmqpUazSLHqc5opUajmR+xRGpJmYRBBzVLhozqJ63UaBY5WqkpDiLSKiL3i8gB8/+WPNv9QkRGROR/yr1GTenQ6SdNzZLpqdHVT5rFTTSRQsS4rY3CC+JDwINKqc3Ag+b9XHwKeGfZVqUpCzqo0dQsmdVP+iShWdzEEika/EYdg1ZqFsQNwJ3m7TuBN+baSCn1IDBepjVpysRSrH5aWr/tEiYz/aSVGs3iJppI0ljnBbRSs0A6lVI95u1eoLOSi1mqPH5okKODk2V/33hSKzWaGsUZyGijsGaxE0umaDKDGt18b2a2b98OcLaIvJj1L3uAsAIWdEWjZ/DNj7/8wXN84eGXyv6+0SVoFC5pnxrN4iGmjcKaKiIaT9EY0EpNITzwwAOIyC6lVHf2cyLSJyJdSqkeEekCFtS5XSl1B3AHQHd3t5Z8CyQSTzIeSZT9fWOJFA2BpXWaX1oh3BLGWU2im+9pFjuxZIqA14XP4yKiS7oXwt3ArebtW4GfVnAtS5ZESjEZq0xQo/vUaGqSDKOwVmrmhIhcJyL7ROSgiOSsHhGRm0Rkt4jsEpHvlHuNtUY0bngB/B4XUW0UXgi3A9eKyAFgu3kfEekWkS9bG4nII8APgWtE5ISIvLoiq61REskU4Wj5g/OlaBReWrrUEkZP6Z4fIuIGPgdcC5wAdojI3Uqp3Y5tNgN/B1yplBoWkWWVWW3tYHRCdRPwunXzvQWglBoErsnx+E7gfY77V5VzXUuNREoRjpdfqdFGYU3NEjN70/jcLvu2piC2AQeVUoeUUjHgexhlsk7+F/A5pdQwgJ44v3Ci8SQ+j4uA16VLujVVTyKlKqPULEGj8NL6bZcw8aSxc/s8Lp1+mhsrgeOO+yfMx5xsAbaIyKMi8riIXJfrhXTlSOFYM2sCHvecjcJHByf57AP7SaZ08L6UUEpx8x2P8YsXe2bfuIwopUhW0FOjlRpNTRJPpPC6Ba9bdPqp+HiAzcArgFuAL4lIc/ZGSqk7lFLdSqnujo6Oki7oyOlJxiPxkr5HKbE9NV7XnIOae17o4bMPHOC3B0+XaHWaxUg8qXj80BDPnRit9FIySJjBdThWIaVGBzWaWiEcS/CLF3sBQ6nxegylRgc1c+IksNpxf5X5mJMTwN1KqbhS6jCwHyPIqRhv/a/f8bmHyt8Xo1hELU+Nxz3n9NPgRAyA7+84VoqlaRYpVpXcYmsBkEimgxqjVVD5iCZ1+klTQ9zzfA/v/9ZTnBqZIpZUeN0uvG6dfpojO4DNIrJeRHzAzRhlsk5+gqHSICLtGOmoQ2VcYwbRRJLTEzGODZW/g2kxUErZV5jzMQoPTkQBuH93H6fN25raJ2IqIYutWWMiZawnmVJlXZtSShuFNbXFRNTI4Y5HEmlPjduV0YhPMzNKqQTwAeA+YA/wA6XULhH5qIhcb252HzAoIruBh4C/NqtOKsJo2Eg79Y9V5wnd2j/98zQKD07GaK/3EU8qfvx0tqimqVWs/WSxKjUAU2VMQSVSCqVYckqNLumuYayrgnDMCGq8btHpp3mglLoXuDfrsY84bivgg+a/ijNsBjV945GSvL5SivFowu74W2ys/dbvceH3uOfcfG9wIsb5q5oZmYrz30+f4H+9fEMplqlZZFj7yeJTatJBzWQsQUvIV5b3tRR5rdRoagbrimUqljSDmnT6KZFM8dGf7aZndKrCq9QUm+Gw4SnpG4uWJIf/6/0DdH/sAU6OlGbfiTmDGu/cm+8NTcZoq/dx8doWDp+eLLuPQVMZLBUkutiUmlR6/y2nWVgHNZqaI63UJIklDE+NodQojg6F+eqjh3l4ny4trjVGzKAmlkgxNlX8MtJ9vePEkimeOjpc9NeG9H47H0+NUorBySitIT/LGwNEEyl
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 648x1080 with 15 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"def pplot(dd):\n",
|
|||
|
" x = dd.shape[0]\n",
|
|||
|
" fix = int(x/3)+1\n",
|
|||
|
" fiy = 3\n",
|
|||
|
" fig, axs = plt.subplots(fix, fiy, figsize=(3*fiy, 3*fix))\n",
|
|||
|
" \n",
|
|||
|
" for i in range(x):\n",
|
|||
|
" axs[int(i/3)][i%3].plot(dd[i])\n",
|
|||
|
" \n",
|
|||
|
"pplot(dd)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 30,
|
|||
|
"id": "c6ba9fb3",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiYAAANOCAYAAAAlHsfFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9eZwcd3nn/3m6p3vue0aaQ3NItnxIsiXZsmxuQwDLQGwWCDFsuELihB9s2ITNBpL8gIXNb8lBdjcLC3GCYyDhPowCxsaAD4wvSZZkW7Iky7rm1Mx09RzdPd3Vx/P7o6p6WjN9VHdVd1f1PO/Xa17qo7rqO6Nvffv5PsfnIWaGIAiCIAiCE/BUewCCIAiCIAgGYpgIgiAIguAYxDARBEEQBMExiGEiCIIgCIJjEMNEEARBEATHUFetC/f09PDo6Gi1Li/UGIcOHZpj5t5KXlPmsGAn1ZjDgMxjwT7smsNVM0xGR0dx8ODBal1eqDGI6HylrylzWLCTasxhQOaxYB92zWEJ5QiCIAiC4BjEMBEEQRAEwTGIYSIIgiAIgmMQw0RYFxDRPiI6SUSniejjeY57OxExEe2p5PgEQRCcDjOjEm1sxDDJ4GPfOYp/evRMtYch2AwReQF8EcCtALYBeBcRbctyXCuAjwJ4qrIjdDa/e88BfP/QeLWHIQhClYmoSWz9i5/iq4+fK+t1xDDJ4OGTM/j1S3PVHoZgP3sBnGbmM8ysAvgWgNuzHPdZAH8NIFrJwTmZRDKFX56YwZNnAtUeiiAIVUYJq0ikGI1+b1mvI4aJTirFCEZUBMNqtYci2M8ggLGM5+P6a2mI6DoAQ8z8k3wnIqI7ieggER2cnZ21f6QOY345DgAIRuS+EIT1jqJ/P3Y3+8t6HTFMdBaW40gxEBDDZN1BRB4Afw/gY4WOZea7mHkPM+/p7a24FlZBnjwTwD2/Pmvb+YyFSO4LQRCM9aBTDJPKYCy8iizAtcgEgKGM55v01wxaAewA8DARnQNwE4D9bkyA/ebTF/D5B0/Zdr5ASO4LQRA0xGNSYYw/eERNIhpPVnk0gs0cALCViDYTkR/AHQD2G28y8wIz9zDzKDOPAngSwG3M7Do5TCWsYimagJpI2XY+AFBCYpgIwnrHWA+6xDCpDJk7Qtkd1hbMnADwEQAPAHgBwHeY+RgRfYaIbqvu6OzF8HDYlROi6OdZitln7AiC4E4CYRV+rwct9eXtZlO1XjlOY7VhMtDRWMXRCHbDzPcBuG/Va5/McezNlRhTOUjnhIRUbGxrsH6+DE9JMGLPOQVBcCfBsIrOZh+IqKzXEY+JjhKOpR9Lop/gRpgZAX0eBzLmsxUuuS8knCMI65pAWEVXc33ZryOGiY4SjqcfS8mw4EaWYgnEk5oqo13hSCWScV9IybAgrGuUcKzsia+AGCZplHAMrQ1aZEs8JoIbyQy72OXdkPtCEASDYCRe9lJhQAyTNIGwis09zfB66BL3tSC4hUAZErgDIRVbN7Ro5wzJfSEI65lASDwmFSUYUdHd7Ednk/+SsI4guIVAKDNPyh4jIhhRsbmnBUSXhnWEylOoESURvZ+IZonoiP7ze9UYp1CbxJMpLEYTZS8VBqQqJ40SUnHFxlZ0NfvEYyK4EsNL0tpQZ0soh5mhhFX0ttajvVHui2qS0YjyDdBaKhwgov3MfHzVod9m5o9UfIBCzROskOorYNJjUust47VqBs1j0tXsFx0ToWLYqQ1ihHIu39Biyxw2kmnlvnAEZhtRCkJZMDSNHBHKWQ8t45fjScQSKXQ116O7uV4WYKEinJhexLZP3o/TMyFbzqeEVTT5vRjoaLRlDgczVB67xTCpNgUbUeq8nYieJaLvEdFQlvcBrL9mlIJ1jOT6SoRyzHhMar5lvOH2lp2hUElOTi8hkWK8eHHJlvMFQjF0t/jR0+zHnA2JqoEMw0TuC1fw7wBGmflaAA8C+GquA53ejFJwHpnrQbkxY5jY1jLeqWR2TOxs9mN+OY5kiqs8KqHWmdMNYjuMCGBF/KiruR6L0QTiSWthoswdkhgmVadQI0owc4CZjcn0zwCur9DYhHWAoWPkFMMkL8W0jHeq+1CJXOqyZhYxKaH8zC7FLvnXKoqRJ9WiLRxWhQIz74uuZj+CkThSYrBXi7yNKAGAiPoznt4GrS+UINhCIKSCCOho9JX9WmYME9taxjvVfaisCuUAov4qlB/DUzJrmxiamjauAeuCaOkW5y1+dDXXI5liLEalZLgamGxE+UdEdIyIjgL4IwDvr85ohVpECatob/Shzlt+lREz5cJpSx2aQXIHgHcbbzLzAoAe4zkRPQzgv7ipZXxmKKcrY1HfWs1BCTWPYZjYEcpJV5a1rBgmVkMvSlhFfZ0HjT4vupq1XVIgrKKjqfyuXGEthRpRMvMnAHyi0uMS1gdKRK1IGAcw4TFZDy3jA2EVPi+hraEu/YeXeLpQbuwM5YRiCaiJFLqb/ejWQzlWDR4jNERE6cZdcl8IwvpECakVKRUGTAqs1XrL+GBYRWeTtgDbtdsUhELY6TFR0hnz9bYZEUpYTeeryH0hCOsbJaxipLupItcSSXoY1Qzawmu4qWUBFspJKsXpqpzZpRiYrSWVGvkk3c1+dDT64CHrczigG+zAitqj3BeCsD5RImraG1tuxDCB1kHVMEz8dR60NtTJAiyUFaMkfaC9AbFECqFYwtL50gncLX54PISuZr8Nya8rDbvEYyII6xdmRjDsoByT9UAwEr/kDy4ql0K5McI3V/W36c+teje08xnzuKvZf0lTv1IIhuPpsFCDz4smv1fuC0FYhywuJ5BIcdqDWm7EMIGmmJlpmHSKYSKUGSPh9er+1kuel8pKKEczJKwKosUSSYRiiXQ1DgC987bcF4Kw3kj3yZFQTmXI1sq52wY3uCDkw/CYXJ32mFisoAmpaPR50ej3AtAMFCtzODOZ1qC7Re4LQViPKGmPbH2BI+1h3RsmwSwdE7ua/SKwJpSVFY9J2yXPS0VZFf+16jFZMUzkvhCE9U5mP7lKsO4Nk2w7wy69w7DVSglByMVsKAa/14PR7mZ4yLrHZC6soifDzdrd4sd8JF5yv5xchomEcgRh/ZEpQloJxDAJGX/wlVh6V7MPatJ6pYQg5GJuSTMkvHoFjXUxtNiacCRQes+nrIZJkz+dZCsIwvpByRJZKCdimEQuTRoEICqXNQgR7SOik0R0mog+nuX9PySi54joCBE9RkTbyjme2VAMva3aPOtpqbceygmpa7x+QOlzWAlnCXG2+BGNp7CsJi2MVBAEt6GEVDT5vWjweStyPTFMsuwMRbOhtiAiL4AvArgVwDYA78pieHyDma9h5l0A/gZax+yyMbcUQ0+LZjz0ttZbauSX2SfHIN1aocTzKmEVHgLaMzqJrjQHFK+JIKwnlAyxxUqw7g0TI6mnoymjLFIMk1pjL4DTzHyGmVUA3wJwe+YBzLyY8bQZQFkTjDI9Jr0t9Ziz4DEJq0nE9D45Bka+yVyJc9hQffV4KP1ap6giC8K6pJKqr4DJXjm1jNHK2ZfRytmutvGCYxgEMJbxfBzAjasPIqIPA/gTAH4Ar8t2IiK6E8CdADA8PFzSYFIphhJW0x6TntZ6zIY0WXoiKvDptRhekdWJqtp7pRk82VQejYVJ7gtBWF+srvorN+veY6JE1nZMNP4DpDRyfcHMX2TmywD8GYC/zHHMXcy8h5n39Pb2lnSdYERFMsVpr0ZPix9qIoWlEpOtjdBK5o6mo8kPstAvJxBW12TgG3krcl9Uj0K5UhnHvZ2ImIj2VHJ8TuXvHjiJ//q9o9UehmsJhFR0SSinciihtQtwk98Lf51HXNa1wwSAoYznm/TXcvEtAG8t12BmdS9Gb2uD/q/2hV9qAmy2knevh9DZVLogmhLOYrBLKKeqmMyVAhG1AvgogKcqO0J7OD2zhPfd/TQiqn1Vkb86PYdHT83Zdr71RjAiHpOKks1FRUSi/lpbHACwlYg2E5EfwB0A9mceQERbM56+GcCL5RrM3JI2r1Y8JvX666UZJrnEj6z0fMoWymlrrEOdh+S+qB4Fc6V
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 648x1080 with 15 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"conv = Conv2D(1, (3,3), input_shape=train_shape, activation='relu', padding='same')\n",
|
|||
|
"c = conv(d)\n",
|
|||
|
"cc = pd.DataFrame(c.numpy().reshape(d.shape[1], d.shape[2]))\n",
|
|||
|
"\n",
|
|||
|
"pplot(cc)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"id": "2631a70f",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiYAAAIICAYAAABaaCUAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAB5/klEQVR4nO3deXzcdbX4/9fJZG22LkmXJE3TfWVpCQULCLJcC3JBr4KAqCjai4KioF7U++Mq9+v9qvd7Eb0iioCooIigWBHEDWQRCl2hSZpQmq5ZmqZNZpJmm5nz+2NmSihJM21n5vP5zJzn45FHZvlkPqfpO8mZ9+e8z1tUFWOMMcYYN8hyOgBjjDHGmBhLTIwxxhjjGpaYGGOMMcY1LDExxhhjjGtYYmKMMcYY17DExBhjjDGuke3UicvKyrSmpsap05s0s27dun2qWp7Kc9oYNonkxBgGG8cmcRI1hh1LTGpqali7dq1TpzdpRkR2pPqcNoZNIjkxhsHGsUmcRI1hu5RjjDHGGNewxMQYY4wxrmGJiTHGGGNcI2MSk/oWP//3yQYO9A46HYoxJgNtafNzwy/W07yv1+lQjDkmobDyrz9fy9Nb9ib1PBmRmOzvHeQTP32FH/19Gyu/+yzPvd7hdEjGmAzzcvN+Hn+1ldzsjPi1a9JQ874enqprZ1/PQFLPk/Y/IaGwcuNDG9jXO8j/XHYSxfk5fPjel7nt9/X0D4WcDs8YkyE27OxicnEeFaX5TodizDGpa/EDsKSyNKnnSfvE5H//9jrPvb6Pr1+ymPefUsXvbziTj75jBve90Mx773yBLW1+p0M0xmSA9TsPsLR6PCLidCjGHJPNe7rJzc5izuSipJ4nrROTvzd18N2/vs6/LKvkilOnA1CQ6+Prly7hJx87lX09g1zyvy9wz3PbCIfV4WiNMemqs2eAHZ0HWVY94ZhfQ0Smi8jTIlIvInUicuMIx4iIfE9EtorIqyKy7LgCN2aYuhY/86cUk+NLbuqQtolJS1cfn3toA/OnFPON957wtncp75o/mT9+7izeOa+M//OHBj5y38u0dfc7FK0xJp1t3NUFwNLjSEyAIHCzqi4CTgeuF5FFhx1zITA3+rEKuOt4TmhMjKpS1+JnSWVJ0s+VlonJYDDMpx9cz1BI+cGHllGQ6xvxuLKiPH78kVr+630nsG7HAVZ+91mefK01xdEaY9Ldhp1d+LKEE47j2ryqtqrq+ujtANAAVB522KXAzzTiJWC8iEw75pMaE7X7QB/dfUMsqkhufQmkaWLyX080sHFXF9/+wInMKj/ytTAR4arTqvnDZ8+keuI4PvXger7w6030DARTFK1JBRG5T0T2isjmUZ63KXCTNOt3HmDhtOJR3yQdLRGpAZYCaw57qhLYNez+bt6evCAiq0RkrYis7eiwVYpmbIcKXytsxuSo/X5TC/f/YzvXnjmTi06I/43CrPIiHv3UCm541xx+s343F333Odbt2J/ESE2K3Q+sPMLzNgVukiIUVjbt6mLp9OO6jHOIiBQBjwKfU9Vjqt5X1btVtVZVa8vLU75voPGgupZusgQWTLXE5Khs3dvDLY++yikzJnDLhQuO+utzfFl84d3z+dW/voOwKpf98EVu/3MTQ6FwEqI1qaSqzwJHyjRtCtwkxet7A/QOhlg2Y/xxv5aI5BBJSh5U1d+McMgeYPqw+1XRx4w5LnUtfmaXFyVs1u9I0iYxOTgY5FMPrCM/x8edVy07rqrhU2sm8sSNZ/HepZV876+vc9kPX2S7dWtMdzYFbpJiw84ugOOeMZFIBf+9QIOq3j7KYauBj0QvTZ4OdKuqFc6Z41bX0p30/iUxaZGYqCpf+c1rbO3o4XtXLmVqAhoYleTncPvlJ/P9q5ayraOHi773HA+9vBNVW1acyWwK3BytDTsPMGFcDjMmjTvelzoD+DBwrohsjH5cJCLXich10WOeALYBW4EfA58+3pMa0xEYoN0/wOIU1JcAZKfkLEn2wJqdPLaxhZsvmMcZc8oS+toXn1jBKTMmcPPDm7jlN6/xty17+eb7T2RiYW5Cz2McZ1PgJinW7+xiafWE426spqrPA0d8EY28c7r+uE5kzGHqWroBWJyCFTmQBjMmm3Z18Z+/r+ec+eVc/645STnHtNICHrj2NL560UKeaexg5R3P8myTTeOnGZsCNwnX3TfE1r09LKse73Qoxhyz2IqcRSmaMfF0YnKgd5BPP7ie8uI8vnP5yWRlJa/Vc1aW8Ml3zuKx68+gtCCHj9z3Ml9bXWf77XiEiPwSeBGYLyK7ReRamwI3ybYpMY3VjHFUXUs30ycWUFqQk5LzxXUpR0RWAt8FfMA9qvrNUY57P/AIcKqqrk1YlCMIh5WbHt5IR2CAX1/3Diak6NLKoooSfv+ZM/nmk1u4/x/b+ccb+7jjg0tTlkmaY6OqV47xvE2Bm4TbsLMLETixKjVT4MYkQ12LnyUpuowDccyYiIgPuJNIn4dFwJUjtEFGRIqBG3l7w5+k+MEzW3m6sYP/758XcdL08ak45SH5OT6+dslifvrx5Rw4OMR773yBl5ut54kx5q027DrAvMnFFOen5p2mMYnm7x9iR+fBlBW+QnyXcpYDW1V1m6oOAg8R6flwuP8EvgUkfcOZF7bu4/Y/N3HpyRVcfVp1sk83qrPnlfPHG88iKwueqmtzLA5jjPuEw8qGnV0stfoS42H10fqSxSlaKgzxJSZj9neItu+erqp/ONILJaIHRFt3P5/95QZmlxfxf//l7ZvzpdqkojwWTith855uR+MwxrhLc2cv3X1Dx7WjsDFOi/1tc9uMyRGJSBZwO3DzWMcebw+IoVCYG36xnr6hEHddfQrjct2x2nlJRSn1LX7CYetxYoyJONRYzWZMjIfVt/gpL85jcvHx9weLVzyJyVj9HYqBJcAzIrKdyHbcq0WkNlFBxnzzyS2s3XGAb73/ROZMPvLmfKm0uKKEwECQnfsPOh2KMcYlNuw8QHFeNrPH2EjUGDeLFL6mdnFHPInJK8BcEZkpIrnAFUR6PgCgqt2qWqaqNapaA7wEXJLoVTlPvtbKvc83c82KGv75pIpEvvRxi7Xpja31NsaYDTu7OLl6fFLbGBiTTP1DIbZ29KSssVrMmImJqgaBG4CngAbgYVWtE5HbROSSZAcIsK2jhy8+8ionTx/PVy5amIpTHpW5U4rI8QmbW6zOxBgDvQNBtrT5WZriFYPGJNKWtgChsLKkMrUzJnEVaajqE0QaUA1/7NZRjj3n+MN6U99giE8/uJ4cn3Dnh5aRm+2+nnB52T7mTSm2AlhjDACv7u4mrLB0hhW+Gu96s/DVZTMmTlJV/v2xzTS2B7jjiqVUji9wOqRRLakopa7Fb5v8GWPYsOsAACdXjXc2EGOOQ12Ln5L8bKompPZvr6sTk4de2cWj63fz2XPncvY8d+/kuriyhP29g7R2J72NizHG5Tbs7GJWWWHKOlIbkwz1Ld0srihNeVsO1yYmm/d08x+r6zhrbhmfPW+u0+GMKTbVZQWwxmQ2VWXDzgOcbMuEjYcNhcI0tAVS2r8kxpWJSffBIT714DomFeby3SuW4vNAVfvCacVkCVZnYkyG232gj309g7Zxn/G0Nzp6GAyGD606TSV3dCgbJhxWbv71Rtq6+/nVv76DiR6ZCh2XG+lXUGcrc4zJaOt3RupLltmMifGwzXuirehtxgR+9Ow2/tKwl69etNBzrZyXVJYe+s80xmSmDTu7KMjxMX9KsdOhGHPM6lq6yc/JYpYDDQJdlZgMhcL8buMe3nPiND66osbpcI7a4ooS2vz97OsZcDoUY4xDNuzq4sSqUrJ9rvr1asxRqWvxs3BaiSOlFK76ycnxZfGbT6/g2+8/0fHN+Y6FFcC6l4isFJFGEdkqIreM8Hy1iDwtIhtE5FURuciJOI239Q+FqG/ptvoS42nhsFLf4nfkMg64LDGBSK1GYZ7rSl/isij6n2gFsO4iIj7gTuBCYBFwpYgsOuywfyfS1XgpkW0XfpDaKE06qGvpZiiktnGf8bSd+w/SMxBkSYobq8W4LjHxstKCHGZMGmc
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 648x648 with 9 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pol = MaxPooling2D(pool_size=(2, 2))\n",
|
|||
|
"p = pol(c)\n",
|
|||
|
"pp = pd.DataFrame(p.numpy().reshape(p.shape[1], p.shape[2]))\n",
|
|||
|
"pplot(pp)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 32,
|
|||
|
"id": "706a384c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <th>5</th>\n",
|
|||
|
" <th>6</th>\n",
|
|||
|
" <th>7</th>\n",
|
|||
|
" <th>8</th>\n",
|
|||
|
" <th>9</th>\n",
|
|||
|
" <th>10</th>\n",
|
|||
|
" <th>11</th>\n",
|
|||
|
" <th>12</th>\n",
|
|||
|
" <th>13</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>0.001282</td>\n",
|
|||
|
" <td>0.412811</td>\n",
|
|||
|
" <td>-0.256378</td>\n",
|
|||
|
" <td>0.000854</td>\n",
|
|||
|
" <td>-0.409180</td>\n",
|
|||
|
" <td>0.262451</td>\n",
|
|||
|
" <td>0.006683</td>\n",
|
|||
|
" <td>0.005463</td>\n",
|
|||
|
" <td>0.008514</td>\n",
|
|||
|
" <td>0.017822</td>\n",
|
|||
|
" <td>0.096680</td>\n",
|
|||
|
" <td>-0.016479</td>\n",
|
|||
|
" <td>-0.561079</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>-0.016144</td>\n",
|
|||
|
" <td>0.409729</td>\n",
|
|||
|
" <td>-0.254547</td>\n",
|
|||
|
" <td>0.009766</td>\n",
|
|||
|
" <td>-0.405640</td>\n",
|
|||
|
" <td>0.272217</td>\n",
|
|||
|
" <td>0.005798</td>\n",
|
|||
|
" <td>0.004425</td>\n",
|
|||
|
" <td>0.008240</td>\n",
|
|||
|
" <td>0.017822</td>\n",
|
|||
|
" <td>0.096680</td>\n",
|
|||
|
" <td>-0.016602</td>\n",
|
|||
|
" <td>-0.561079</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>-0.006683</td>\n",
|
|||
|
" <td>0.408264</td>\n",
|
|||
|
" <td>-0.252441</td>\n",
|
|||
|
" <td>0.011963</td>\n",
|
|||
|
" <td>-0.401001</td>\n",
|
|||
|
" <td>0.274780</td>\n",
|
|||
|
" <td>0.004425</td>\n",
|
|||
|
" <td>0.002502</td>\n",
|
|||
|
" <td>0.006561</td>\n",
|
|||
|
" <td>0.018555</td>\n",
|
|||
|
" <td>0.096558</td>\n",
|
|||
|
" <td>-0.016357</td>\n",
|
|||
|
" <td>-0.561079</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>-0.011963</td>\n",
|
|||
|
" <td>0.407440</td>\n",
|
|||
|
" <td>-0.253662</td>\n",
|
|||
|
" <td>-0.002930</td>\n",
|
|||
|
" <td>-0.400635</td>\n",
|
|||
|
" <td>0.282837</td>\n",
|
|||
|
" <td>0.002716</td>\n",
|
|||
|
" <td>0.001312</td>\n",
|
|||
|
" <td>0.005707</td>\n",
|
|||
|
" <td>0.018433</td>\n",
|
|||
|
" <td>0.096680</td>\n",
|
|||
|
" <td>-0.015991</td>\n",
|
|||
|
" <td>-0.561079</td>\n",
|
|||
|
" <td>3.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>-0.011292</td>\n",
|
|||
|
" <td>0.401672</td>\n",
|
|||
|
" <td>-0.246674</td>\n",
|
|||
|
" <td>-0.006226</td>\n",
|
|||
|
" <td>-0.399658</td>\n",
|
|||
|
" <td>0.289795</td>\n",
|
|||
|
" <td>0.000549</td>\n",
|
|||
|
" <td>0.001709</td>\n",
|
|||
|
" <td>0.005615</td>\n",
|
|||
|
" <td>0.018311</td>\n",
|
|||
|
" <td>0.096680</td>\n",
|
|||
|
" <td>-0.016357</td>\n",
|
|||
|
" <td>-0.561079</td>\n",
|
|||
|
" <td>4.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>70</th>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>71</th>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>72</th>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>73</th>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>74</th>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>75 rows × 14 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" 0 1 2 3 4 5 6 \\\n",
|
|||
|
"0 0.001282 0.412811 -0.256378 0.000854 -0.409180 0.262451 0.006683 \n",
|
|||
|
"1 -0.016144 0.409729 -0.254547 0.009766 -0.405640 0.272217 0.005798 \n",
|
|||
|
"2 -0.006683 0.408264 -0.252441 0.011963 -0.401001 0.274780 0.004425 \n",
|
|||
|
"3 -0.011963 0.407440 -0.253662 -0.002930 -0.400635 0.282837 0.002716 \n",
|
|||
|
"4 -0.011292 0.401672 -0.246674 -0.006226 -0.399658 0.289795 0.000549 \n",
|
|||
|
".. ... ... ... ... ... ... ... \n",
|
|||
|
"70 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
|
|||
|
"71 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
|
|||
|
"72 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
|
|||
|
"73 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
|
|||
|
"74 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
|
|||
|
"\n",
|
|||
|
" 7 8 9 10 11 12 13 \n",
|
|||
|
"0 0.005463 0.008514 0.017822 0.096680 -0.016479 -0.561079 0.0 \n",
|
|||
|
"1 0.004425 0.008240 0.017822 0.096680 -0.016602 -0.561079 1.0 \n",
|
|||
|
"2 0.002502 0.006561 0.018555 0.096558 -0.016357 -0.561079 2.0 \n",
|
|||
|
"3 0.001312 0.005707 0.018433 0.096680 -0.015991 -0.561079 3.0 \n",
|
|||
|
"4 0.001709 0.005615 0.018311 0.096680 -0.016357 -0.561079 4.0 \n",
|
|||
|
".. ... ... ... ... ... ... ... \n",
|
|||
|
"70 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 \n",
|
|||
|
"71 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 \n",
|
|||
|
"72 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 \n",
|
|||
|
"73 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 \n",
|
|||
|
"74 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[75 rows x 14 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 32,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pd.DataFrame(d.reshape(X_train[0].shape[0], X_train[0].shape[1]).T)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 33,
|
|||
|
"id": "8f2dc307",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <th>5</th>\n",
|
|||
|
" <th>6</th>\n",
|
|||
|
" <th>7</th>\n",
|
|||
|
" <th>8</th>\n",
|
|||
|
" <th>9</th>\n",
|
|||
|
" <th>10</th>\n",
|
|||
|
" <th>11</th>\n",
|
|||
|
" <th>12</th>\n",
|
|||
|
" <th>13</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.164418</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.460755</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.267381</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.322476</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.116835</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.358112</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.394801</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.405149</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.045922</td>\n",
|
|||
|
" <td>0.236685</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.052844</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.356213</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.403690</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.408704</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.045495</td>\n",
|
|||
|
" <td>0.236544</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.562665</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.355599</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.399693</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.410803</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.045657</td>\n",
|
|||
|
" <td>0.236834</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.072485</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.360133</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.382545</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.400446</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.045980</td>\n",
|
|||
|
" <td>0.237183</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.582306</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>70</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>71</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>72</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>73</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>74</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>75 rows × 14 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" 0 1 2 3 4 5 6 7 8 9 10 \\\n",
|
|||
|
"0 0.0 0.164418 0.0 0.460755 0.0 0.267381 0.0 0.0 0.0 0.0 0.000000 \n",
|
|||
|
"1 0.0 0.358112 0.0 0.394801 0.0 0.405149 0.0 0.0 0.0 0.0 0.045922 \n",
|
|||
|
"2 0.0 0.356213 0.0 0.403690 0.0 0.408704 0.0 0.0 0.0 0.0 0.045495 \n",
|
|||
|
"3 0.0 0.355599 0.0 0.399693 0.0 0.410803 0.0 0.0 0.0 0.0 0.045657 \n",
|
|||
|
"4 0.0 0.360133 0.0 0.382545 0.0 0.400446 0.0 0.0 0.0 0.0 0.045980 \n",
|
|||
|
".. ... ... ... ... ... ... ... ... ... ... ... \n",
|
|||
|
"70 0.0 0.000000 0.0 0.000000 0.0 0.000000 0.0 0.0 0.0 0.0 0.000000 \n",
|
|||
|
"71 0.0 0.000000 0.0 0.000000 0.0 0.000000 0.0 0.0 0.0 0.0 0.000000 \n",
|
|||
|
"72 0.0 0.000000 0.0 0.000000 0.0 0.000000 0.0 0.0 0.0 0.0 0.000000 \n",
|
|||
|
"73 0.0 0.000000 0.0 0.000000 0.0 0.000000 0.0 0.0 0.0 0.0 0.000000 \n",
|
|||
|
"74 0.0 0.000000 0.0 0.000000 0.0 0.000000 0.0 0.0 0.0 0.0 0.000000 \n",
|
|||
|
"\n",
|
|||
|
" 11 12 13 \n",
|
|||
|
"0 0.322476 0.0 0.116835 \n",
|
|||
|
"1 0.236685 0.0 0.052844 \n",
|
|||
|
"2 0.236544 0.0 0.562665 \n",
|
|||
|
"3 0.236834 0.0 1.072485 \n",
|
|||
|
"4 0.237183 0.0 1.582306 \n",
|
|||
|
".. ... ... ... \n",
|
|||
|
"70 0.000000 0.0 0.000000 \n",
|
|||
|
"71 0.000000 0.0 0.000000 \n",
|
|||
|
"72 0.000000 0.0 0.000000 \n",
|
|||
|
"73 0.000000 0.0 0.000000 \n",
|
|||
|
"74 0.000000 0.0 0.000000 \n",
|
|||
|
"\n",
|
|||
|
"[75 rows x 14 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 33,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pd.DataFrame(c.numpy().reshape(c.shape[1], c.shape[2]).T)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "73042b1b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 5
|
|||
|
}
|