import gensim.downloader as api
word2vec = api.load('word2vec-google-news-300')
[==================================================] 100.0% 1662.8/1662.8MB downloaded
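# Quick sanity check of the loaded vectors (illustrative; gensim KeyedVectors
# supports item lookup and similarity queries):
print(word2vec['computer'].shape)                 # (300,)
print(word2vec.most_similar('computer', topn=3))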
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Bidirectional, GRU, Embedding
from tensorflow.keras.optimizers import Adam
BATCH_SIZE = 256
EPOCHS = 30
GRU_SIZE = 64
DENSE = 32
MAX_WORDS = 100000
EMBEDDING_DIM = 300
MAX_SEQUENCE_LENGTH = 250
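# The cells below assume tokenized, padded integer inputs (train_data, val_data,
# test_data) and an embedding_matrix whose rows are aligned with the tokenizer's
# word index; both come from earlier preprocessing. A minimal sketch of how they
# could be built (names such as `tokenizer` are illustrative assumptions):
# from tensorflow.keras.preprocessing.text import Tokenizer
# from tensorflow.keras.preprocessing.sequence import pad_sequences
# import numpy as np
# tokenizer = Tokenizer(num_words=MAX_WORDS, oov_token='<UNK>')
# tokenizer.fit_on_texts(twenty_train.data)
# train_data = pad_sequences(tokenizer.texts_to_sequences(twenty_train.data),
#                            maxlen=MAX_SEQUENCE_LENGTH)
# embedding_matrix = np.zeros((MAX_WORDS + 2, EMBEDDING_DIM))  # padding / OOV rows stay zero
# for word, idx in tokenizer.word_index.items():
#     if idx <= MAX_WORDS and word in word2vec:
#         embedding_matrix[idx] = word2vec[word]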
# Create an empty sequential model
model = Sequential()
# Add an embedding layer.
# Option 1: a randomly initialized embedding layer, trained end-to-end with the model:
# model.add(Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM,
#                     input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=True))
# Option 2: a Word2Vec-initialized embedding layer, frozen (not trainable):
model.add(Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM, weights=[embedding_matrix],
                    input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=False))
# Could we train the Word2Vec embeddings further, and what would it take to do so successfully?
# Hint: the learning rate. See the sketch below.
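# A possible recipe (sketch, not run here): keep the Word2Vec initialization but
# set trainable=True, and compile with a much smaller learning rate (e.g.
# Adam(learning_rate=1e-4) instead of 2e-3) so the pre-trained vectors are only
# gently adjusted; a common variant is to train with the embeddings frozen first,
# then unfreeze them and fine-tune for a few more epochs.
# model.add(Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM,
#                     weights=[embedding_matrix],
#                     input_length=MAX_SEQUENCE_LENGTH, mask_zero=True,
#                     trainable=True))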
# Add a bidirectional GRU layer with 0.33 variational (recurrent) dropout.
# Note: recurrent_dropout > 0 disqualifies the layer from the fast cuDNN kernel (hence the warnings below).
model.add(Bidirectional(GRU(GRU_SIZE, return_sequences=False, recurrent_dropout=0.33)))
# return_sequences=False: return only the last output of the sequence rather than the full sequence.
# Add a hidden MLP layer
model.add(Dropout(0.33))
model.add(Dense(DENSE, activation='relu' ))
# Add the output MLP layer
model.add(Dropout(0.33))
model.add(Dense(len(twenty_train.target_names), activation='softmax'))
# Multi-class classification -> Use softmax over all possible classes
# model.build((None, EMBEDDING_DIM, VECTOR_DIMENSION))
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(learning_rate=0.002),  # default would be 0.001
              metrics=["accuracy"])
# Save the model weights with ModelCheckpoint (save_best_only=True below keeps
# only the weights of the epoch with the highest validation accuracy)
# To persist checkpoints to Google Drive, use '/content/gdrive/My Drive/checkpoints' instead
if not os.path.exists('/content/checkpoints'):
    os.makedirs('/content/checkpoints')
# '/content/gdrive/My Drive/checkpoints/BiGRUMLP.hdf5'
checkpoint = ModelCheckpoint('/content/checkpoints/BiGRUMLP.hdf5',
                             monitor='val_accuracy',
                             mode='max', verbose=2,
                             save_best_only=True,
                             save_weights_only=True)
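# The custom Metrics callback passed to fit() below is defined earlier in the
# notebook; a minimal sketch of what it could look like (assumption: weighted-
# average validation scores computed with scikit-learn after every epoch):
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

class Metrics(tf.keras.callbacks.Callback):
    def __init__(self, valid_data):
        super().__init__()
        self.valid_data = valid_data

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        x_val, y_val_one_hot = self.valid_data
        y_pred = np.argmax(self.model.predict(x_val), axis=-1)
        y_true = np.argmax(y_val_one_hot, axis=-1)
        # Written into `logs` so the scores appear in the epoch summary and history
        logs['val_f1'] = f1_score(y_true, y_pred, average='weighted')
        logs['val_recall'] = recall_score(y_true, y_pred, average='weighted')
        logs['val_precision'] = precision_score(y_true, y_pred, average='weighted')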
history = model.fit(train_data,
                    y_train_1_hot,
                    validation_data=(val_data, y_val_1_hot),
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    shuffle=True,
                    callbacks=[Metrics(valid_data=(val_data, y_val_1_hot)),
                               checkpoint])
WARNING:tensorflow:Layer gru will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding_2 (Embedding) (None, 250, 300) 30000600 bidirectional_2 (Bidirecti (None, 128) 140544 onal) dropout_4 (Dropout) (None, 128) 0 dense_4 (Dense) (None, 32) 4128 dropout_5 (Dropout) (None, 32) 0 dense_5 (Dense) (None, 20) 660 ================================================================= Total params: 30145932 (115.00 MB) Trainable params: 145332 (567.70 KB) Non-trainable params: 30000600 (114.44 MB) _________________________________________________________________ None Epoch 1/30 107/107 [==============================] - 18s 162ms/step — val_f1: 0.084604 — val_precision: 0.170068 — val_recall: 0.121060 Epoch 1: val_accuracy improved from -inf to 0.12106, saving model to /content/checkpoints/BiGRUMLP.hdf5
/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result))
31/31 [==============================] - 86s 3s/step - loss: 2.9734 - accuracy: 0.0744 - val_loss: 2.9132 - val_accuracy: 0.1211 - val_f1: 0.0846 - val_recall: 0.1211 - val_precision: 0.1701 Epoch 2/30 107/107 [==============================] - 17s 156ms/step — val_f1: 0.201904 — val_precision: 0.296786 — val_recall: 0.228571 Epoch 2: val_accuracy improved from 0.12106 to 0.22857, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 84s 3s/step - loss: 2.7441 - accuracy: 0.1327 - val_loss: 2.4504 - val_accuracy: 0.2286 - val_f1: 0.2019 - val_recall: 0.2286 - val_precision: 0.2968 Epoch 3/30 107/107 [==============================] - 16s 154ms/step — val_f1: 0.269124 — val_precision: 0.353126 — val_recall: 0.313697 Epoch 3: val_accuracy improved from 0.22857 to 0.31370, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 83s 3s/step - loss: 2.3429 - accuracy: 0.2199 - val_loss: 2.0197 - val_accuracy: 0.3137 - val_f1: 0.2691 - val_recall: 0.3137 - val_precision: 0.3531 Epoch 4/30 107/107 [==============================] - 16s 149ms/step — val_f1: 0.392721 — val_precision: 0.461777 — val_recall: 0.420619 Epoch 4: val_accuracy improved from 0.31370 to 0.42062, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 81s 3s/step - loss: 2.0135 - accuracy: 0.3185 - val_loss: 1.7367 - val_accuracy: 0.4206 - val_f1: 0.3927 - val_recall: 0.4206 - val_precision: 0.4618 Epoch 5/30 107/107 [==============================] - 16s 153ms/step — val_f1: 0.460487 — val_precision: 0.494728 — val_recall: 0.491311 Epoch 5: val_accuracy improved from 0.42062 to 0.49131, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 77s 3s/step - loss: 1.7313 - accuracy: 0.3855 - val_loss: 1.4422 - val_accuracy: 0.4913 - val_f1: 0.4605 - val_recall: 0.4913 - val_precision: 0.4947 Epoch 6/30 107/107 [==============================] - 16s 147ms/step — val_f1: 0.527518 — val_precision: 0.550191 — val_recall: 0.539028 Epoch 6: val_accuracy improved from 0.49131 to 0.53903, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 77s 3s/step - loss: 1.5178 - accuracy: 0.4528 - val_loss: 1.3347 - val_accuracy: 0.5390 - val_f1: 0.5275 - val_recall: 0.5390 - val_precision: 0.5502 Epoch 7/30 107/107 [==============================] - 16s 150ms/step — val_f1: 0.556103 — val_precision: 0.576500 — val_recall: 0.577909 Epoch 7: val_accuracy improved from 0.53903 to 0.57791, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 79s 3s/step - loss: 1.3964 - accuracy: 0.4936 - val_loss: 1.2127 - val_accuracy: 0.5779 - val_f1: 0.5561 - val_recall: 0.5779 - val_precision: 0.5765 Epoch 8/30 107/107 [==============================] - 17s 161ms/step — val_f1: 0.602474 — val_precision: 0.612069 — val_recall: 0.616200 Epoch 8: val_accuracy improved from 0.57791 to 0.61620, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 79s 3s/step - loss: 1.2914 - accuracy: 0.5324 - val_loss: 1.1069 - val_accuracy: 0.6162 - val_f1: 0.6025 - val_recall: 0.6162 - val_precision: 0.6121 Epoch 9/30 107/107 [==============================] - 17s 159ms/step — val_f1: 0.583099 — val_precision: 0.616101 — val_recall: 0.607364 Epoch 9: val_accuracy did not improve from 0.61620 31/31 [==============================] - 79s 3s/step - loss: 1.1917 - accuracy: 0.5738 - val_loss: 1.1118 - val_accuracy: 0.6074 - 
val_f1: 0.5831 - val_recall: 0.6074 - val_precision: 0.6161 Epoch 10/30 107/107 [==============================] - 16s 151ms/step — val_f1: 0.637699 — val_precision: 0.649250 — val_recall: 0.643594 Epoch 10: val_accuracy improved from 0.61620 to 0.64359, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 79s 3s/step - loss: 1.1319 - accuracy: 0.5987 - val_loss: 1.0278 - val_accuracy: 0.6436 - val_f1: 0.6377 - val_recall: 0.6436 - val_precision: 0.6493 Epoch 11/30 107/107 [==============================] - 17s 156ms/step — val_f1: 0.677581 — val_precision: 0.682764 — val_recall: 0.682474 Epoch 11: val_accuracy improved from 0.64359 to 0.68247, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 75s 2s/step - loss: 1.0398 - accuracy: 0.6266 - val_loss: 0.9757 - val_accuracy: 0.6825 - val_f1: 0.6776 - val_recall: 0.6825 - val_precision: 0.6828 Epoch 12/30 107/107 [==============================] - 17s 159ms/step — val_f1: 0.670787 — val_precision: 0.676946 — val_recall: 0.681296 Epoch 12: val_accuracy did not improve from 0.68247 31/31 [==============================] - 78s 3s/step - loss: 0.9923 - accuracy: 0.6431 - val_loss: 0.9432 - val_accuracy: 0.6813 - val_f1: 0.6708 - val_recall: 0.6813 - val_precision: 0.6769 Epoch 13/30 107/107 [==============================] - 16s 149ms/step — val_f1: 0.680558 — val_precision: 0.695104 — val_recall: 0.689249 Epoch 13: val_accuracy improved from 0.68247 to 0.68925, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 80s 3s/step - loss: 0.9340 - accuracy: 0.6662 - val_loss: 0.9289 - val_accuracy: 0.6892 - val_f1: 0.6806 - val_recall: 0.6892 - val_precision: 0.6951 Epoch 14/30 107/107 [==============================] - 16s 150ms/step — val_f1: 0.723236 — val_precision: 0.734991 — val_recall: 0.727835 Epoch 14: val_accuracy improved from 0.68925 to 0.72784, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 74s 2s/step - loss: 0.8663 - accuracy: 0.6949 - val_loss: 0.8569 - val_accuracy: 0.7278 - val_f1: 0.7232 - val_recall: 0.7278 - val_precision: 0.7350 Epoch 15/30 107/107 [==============================] - 16s 151ms/step — val_f1: 0.719180 — val_precision: 0.728528 — val_recall: 0.726951 Epoch 15: val_accuracy did not improve from 0.72784 31/31 [==============================] - 79s 3s/step - loss: 0.8322 - accuracy: 0.7104 - val_loss: 0.8276 - val_accuracy: 0.7270 - val_f1: 0.7192 - val_recall: 0.7270 - val_precision: 0.7285 Epoch 16/30 107/107 [==============================] - 17s 154ms/step — val_f1: 0.735873 — val_precision: 0.743401 — val_recall: 0.744624 Epoch 16: val_accuracy improved from 0.72784 to 0.74462, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 77s 3s/step - loss: 0.7928 - accuracy: 0.7218 - val_loss: 0.8120 - val_accuracy: 0.7446 - val_f1: 0.7359 - val_recall: 0.7446 - val_precision: 0.7434 Epoch 17/30 107/107 [==============================] - 16s 148ms/step — val_f1: 0.707631 — val_precision: 0.709445 — val_recall: 0.719293 Epoch 17: val_accuracy did not improve from 0.74462 31/31 [==============================] - 80s 3s/step - loss: 0.9062 - accuracy: 0.7000 - val_loss: 0.8624 - val_accuracy: 0.7193 - val_f1: 0.7076 - val_recall: 0.7193 - val_precision: 0.7094 Epoch 18/30 107/107 [==============================] - 16s 150ms/step — val_f1: 0.729827 — val_precision: 0.743000 — val_recall: 0.728130 Epoch 18: val_accuracy did 
not improve from 0.74462 31/31 [==============================] - 72s 2s/step - loss: 0.8237 - accuracy: 0.7190 - val_loss: 0.8516 - val_accuracy: 0.7281 - val_f1: 0.7298 - val_recall: 0.7281 - val_precision: 0.7430 Epoch 19/30 107/107 [==============================] - 16s 151ms/step — val_f1: 0.750211 — val_precision: 0.757480 — val_recall: 0.756112 Epoch 19: val_accuracy improved from 0.74462 to 0.75611, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 77s 3s/step - loss: 0.7466 - accuracy: 0.7430 - val_loss: 0.7737 - val_accuracy: 0.7561 - val_f1: 0.7502 - val_recall: 0.7561 - val_precision: 0.7575 Epoch 20/30 107/107 [==============================] - 16s 147ms/step — val_f1: 0.760419 — val_precision: 0.767420 — val_recall: 0.762592 Epoch 20: val_accuracy improved from 0.75611 to 0.76259, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 77s 3s/step - loss: 0.6677 - accuracy: 0.7726 - val_loss: 0.7880 - val_accuracy: 0.7626 - val_f1: 0.7604 - val_recall: 0.7626 - val_precision: 0.7674 Epoch 21/30 107/107 [==============================] - 16s 147ms/step — val_f1: 0.765542 — val_precision: 0.769955 — val_recall: 0.766127 Epoch 21: val_accuracy improved from 0.76259 to 0.76613, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 79s 3s/step - loss: 0.6396 - accuracy: 0.7793 - val_loss: 0.7768 - val_accuracy: 0.7661 - val_f1: 0.7655 - val_recall: 0.7661 - val_precision: 0.7700 Epoch 22/30 107/107 [==============================] - 16s 149ms/step — val_f1: 0.764978 — val_precision: 0.766824 — val_recall: 0.768778 Epoch 22: val_accuracy improved from 0.76613 to 0.76878, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 73s 2s/step - loss: 0.6121 - accuracy: 0.7962 - val_loss: 0.7495 - val_accuracy: 0.7688 - val_f1: 0.7650 - val_recall: 0.7688 - val_precision: 0.7668 Epoch 23/30 107/107 [==============================] - 16s 146ms/step — val_f1: 0.772570 — val_precision: 0.779288 — val_recall: 0.773785 Epoch 23: val_accuracy improved from 0.76878 to 0.77378, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 74s 2s/step - loss: 0.6190 - accuracy: 0.7924 - val_loss: 0.7745 - val_accuracy: 0.7738 - val_f1: 0.7726 - val_recall: 0.7738 - val_precision: 0.7793 Epoch 24/30 107/107 [==============================] - 17s 155ms/step — val_f1: 0.771834 — val_precision: 0.778802 — val_recall: 0.772312 Epoch 24: val_accuracy did not improve from 0.77378 31/31 [==============================] - 77s 3s/step - loss: 0.5728 - accuracy: 0.8058 - val_loss: 0.7703 - val_accuracy: 0.7723 - val_f1: 0.7718 - val_recall: 0.7723 - val_precision: 0.7788 Epoch 25/30 107/107 [==============================] - 16s 147ms/step — val_f1: 0.778215 — val_precision: 0.782727 — val_recall: 0.781443 Epoch 25: val_accuracy improved from 0.77378 to 0.78144, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 73s 2s/step - loss: 0.5385 - accuracy: 0.8177 - val_loss: 0.7427 - val_accuracy: 0.7814 - val_f1: 0.7782 - val_recall: 0.7814 - val_precision: 0.7827 Epoch 26/30 107/107 [==============================] - 16s 147ms/step — val_f1: 0.786600 — val_precision: 0.788118 — val_recall: 0.787629 Epoch 26: val_accuracy improved from 0.78144 to 0.78763, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 78s 3s/step - loss: 0.4997 - accuracy: 0.8339 - 
val_loss: 0.7299 - val_accuracy: 0.7876 - val_f1: 0.7866 - val_recall: 0.7876 - val_precision: 0.7881 Epoch 27/30 107/107 [==============================] - 16s 149ms/step — val_f1: 0.786458 — val_precision: 0.787760 — val_recall: 0.786745 Epoch 27: val_accuracy did not improve from 0.78763 31/31 [==============================] - 72s 2s/step - loss: 0.4757 - accuracy: 0.8392 - val_loss: 0.7559 - val_accuracy: 0.7867 - val_f1: 0.7865 - val_recall: 0.7867 - val_precision: 0.7878 Epoch 28/30 107/107 [==============================] - 16s 148ms/step — val_f1: 0.781311 — val_precision: 0.784836 — val_recall: 0.782032 Epoch 28: val_accuracy did not improve from 0.78763 31/31 [==============================] - 78s 3s/step - loss: 0.4602 - accuracy: 0.8506 - val_loss: 0.7549 - val_accuracy: 0.7820 - val_f1: 0.7813 - val_recall: 0.7820 - val_precision: 0.7848 Epoch 29/30 107/107 [==============================] - 16s 150ms/step — val_f1: 0.793178 — val_precision: 0.796680 — val_recall: 0.794698 Epoch 29: val_accuracy improved from 0.78763 to 0.79470, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 78s 3s/step - loss: 0.4433 - accuracy: 0.8493 - val_loss: 0.7585 - val_accuracy: 0.7947 - val_f1: 0.7932 - val_recall: 0.7947 - val_precision: 0.7967 Epoch 30/30 107/107 [==============================] - 16s 146ms/step — val_f1: 0.795004 — val_precision: 0.800514 — val_recall: 0.796760 Epoch 30: val_accuracy improved from 0.79470 to 0.79676, saving model to /content/checkpoints/BiGRUMLP.hdf5 31/31 [==============================] - 77s 3s/step - loss: 0.4220 - accuracy: 0.8568 - val_loss: 0.7976 - val_accuracy: 0.7968 - val_f1: 0.7950 - val_recall: 0.7968 - val_precision: 0.8005
%matplotlib inline
import matplotlib.pyplot as plt
# summarize history for accuracy
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper right')
plt.show()
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Bidirectional, GRU, Embedding
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report
import numpy as np
GRU_SIZE = 64
DENSE = 32
with tf.device('/device:GPU:0'):
    model = Sequential()
    model.add(Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM, weights=[embedding_matrix],
                        input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=False))
    model.add(Bidirectional(GRU(GRU_SIZE, return_sequences=False, recurrent_dropout=0.33)))
    model.add(Dense(DENSE, activation='relu'))
    model.add(Dense(len(twenty_train.target_names), activation='softmax'))
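# Note (assumption): the Dropout layers used during training are omitted here;
# they hold no weights, so load_weights below still matches the checkpointed
# layers, and dropout is inactive at inference time anyway.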
# Load weights from the pre-trained model
model.load_weights("/content/checkpoints/BiGRUMLP.hdf5")
# model.load_weights("/content/gdrive/My Drive/checkpoints/BiGRUMLP.hdf5")
predictions = np.argmax(model.predict(val_data), -1)
print(classification_report(y_val, predictions, target_names=twenty_train.target_names))
WARNING:tensorflow:Layer gru_1 will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
107/107 [==============================] - 16s 148ms/step
                          precision    recall  f1-score   support

             alt.atheism       0.80      0.77      0.79       160
           comp.graphics       0.59      0.76      0.66       165
 comp.os.ms-windows.misc       0.74      0.77      0.76       189
comp.sys.ibm.pc.hardware       0.66      0.46      0.54       168
   comp.sys.mac.hardware       0.64      0.62      0.63       182
          comp.windows.x       0.87      0.71      0.78       168
            misc.forsale       0.79      0.71      0.75       182
               rec.autos       0.87      0.90      0.88       181
         rec.motorcycles       0.88      0.85      0.86       184
      rec.sport.baseball       0.90      0.90      0.90       169
        rec.sport.hockey       0.91      0.92      0.92       175
               sci.crypt       0.94      0.92      0.93       177
         sci.electronics       0.67      0.82      0.74       173
                 sci.med       0.89      0.95      0.92       181
               sci.space       0.88      0.92      0.90       181
  soc.religion.christian       0.74      0.86      0.80       177
      talk.politics.guns       0.88      0.88      0.88       177
   talk.politics.mideast       0.96      0.79      0.87       170
      talk.politics.misc       0.68      0.84      0.75       135
      talk.religion.misc       0.58      0.45      0.50       101

                accuracy                           0.80      3395
               macro avg       0.79      0.79      0.79      3395
            weighted avg       0.80      0.80      0.80      3395
from sklearn.metrics import accuracy_score
predictions = np.argmax(model.predict(val_data), -1)
print(f'Validation Accuracy: {accuracy_score(y_val, predictions)*100:.2f}%')
predictions = np.argmax(model.predict(test_data), -1)
print(f'Test Accuracy: {accuracy_score(y_test, predictions)*100:.2f}%')
107/107 [==============================] - 16s 149ms/step
Validation Accuracy: 79.68%
32/32 [==============================] - 5s 166ms/step
Test Accuracy: 71.20%
Architecture |
---|
RNN (BiLSTM) layers |
Attention Mechanism |
MLP layer |
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dense, Dropout, Bidirectional, LSTM, Embedding, Input
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K
from tensorflow.keras import Model
# !pip install keras-self-attention
# from keras_self_attention import SeqSelfAttention
LSTM_SIZE = 300
DENSE = 1000
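# DeepAttention (and the alternative LinearAttention) are custom layers defined
# earlier in the notebook. A minimal sketch of what such an additive
# self-attention layer could look like (an assumption, not the notebook's exact
# implementation): score each timestep with a small MLP, softmax over time while
# respecting the padding mask, and return the weighted sum plus the weights.
from tensorflow.keras.layers import Layer

class DeepAttention(Layer):
    def __init__(self, return_attention=False, **kwargs):
        super().__init__(**kwargs)
        self.return_attention = return_attention
        self.supports_masking = True

    def build(self, input_shape):
        dim = int(input_shape[-1])
        self.W = self.add_weight(name='W', shape=(dim, dim), initializer='glorot_uniform')
        self.b = self.add_weight(name='b', shape=(dim,), initializer='zeros')
        self.u = self.add_weight(name='u', shape=(dim, 1), initializer='glorot_uniform')
        self.c = self.add_weight(name='c', shape=(1,), initializer='zeros')
        # For dim=600 this gives 361201 parameters, matching the summary below.
        super().build(input_shape)

    def call(self, x, mask=None):
        # Additive attention scores, one per timestep: (batch, time, 1)
        scores = tf.matmul(tf.tanh(tf.matmul(x, self.W) + self.b), self.u) + self.c
        if mask is not None:
            # Push padded positions towards -inf before the softmax
            scores += (1.0 - tf.cast(mask, scores.dtype))[:, :, None] * -1e9
        attn = tf.nn.softmax(scores, axis=1)
        weighted = tf.reduce_sum(x * attn, axis=1)  # (batch, dim)
        return [weighted, attn] if self.return_attention else weighted

    def compute_mask(self, inputs, mask=None):
        return None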
with tf.device('/device:GPU:0'):
    inputs = Input((MAX_SEQUENCE_LENGTH,))
    # Define the embedding layer, initialized with the pre-trained Word2Vec weights
    embeddings = Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM, weights=[embedding_matrix],
                           input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=False)(inputs)
    drop_emb = Dropout(0.33)(embeddings)
    # Define a bidirectional RNN with LSTM cells
    bilstm = Bidirectional(LSTM(units=LSTM_SIZE, return_sequences=True, recurrent_dropout=0.33))(drop_emb)
    drop_encodings = Dropout(0.33)(bilstm)
    # Pass the encodings through an attention layer
    x, attn = DeepAttention(return_attention=True)(drop_encodings)
    # x, attn = LinearAttention(return_attention=True)(drop_encodings)
    # Alternatively, use the keras-self-attention package:
    # x, attn = SeqSelfAttention(return_attention=True)(drop_encodings)
    # Apply dropout to the encoding produced by the attention mechanism
    drop_x = Dropout(0.33)(x)
    # Pass through a dense hidden layer
    hidden = Dense(units=DENSE, activation="relu")(drop_x)
    # Apply dropout to the output of the dense layer
    drop_out = Dropout(0.33)(hidden)
    # Final dense layer with softmax activation to produce a probability distribution over the classes
    out = Dense(units=len(twenty_train.target_names), activation="softmax")(drop_out)
    # Wrap everything into a Model (Functional API)
    model2 = Model(inputs=inputs, outputs=out)
model2.summary()
model2.compile(loss='categorical_crossentropy',
               optimizer=Adam(learning_rate=0.001),  # `lr` is deprecated; use `learning_rate`
               metrics=["accuracy"])
if not os.path.exists('/content/checkpoints'):
    os.makedirs('/content/checkpoints')
checkpoint = ModelCheckpoint('/content/checkpoints/BiLSTM_attn.hdf5',
                             monitor='val_accuracy',
                             mode='max', verbose=2,
                             save_best_only=True,
                             save_weights_only=True)
history2 = model2.fit(train_data, y_train_1_hot,
                      validation_data=(val_data, y_val_1_hot),
                      batch_size=128,
                      epochs=30,
                      shuffle=True,
                      callbacks=[Metrics(valid_data=(val_data, y_val_1_hot)),
                                 checkpoint])
WARNING:tensorflow:Layer lstm will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
/usr/local/lib/python3.10/dist-packages/keras/src/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 250)] 0 embedding (Embedding) (None, 250, 300) 30000600 dropout (Dropout) (None, 250, 300) 0 bidirectional (Bidirection (None, 250, 600) 1442400 al) dropout_1 (Dropout) (None, 250, 600) 0 deep_attention (DeepAttent [(None, 600), 361201 ion) (None, 250, 1)] dropout_2 (Dropout) (None, 600) 0 dense (Dense) (None, 1000) 601000 dropout_3 (Dropout) (None, 1000) 0 dense_1 (Dense) (None, 20) 20020 ================================================================= Total params: 32425221 (123.69 MB) Trainable params: 2424621 (9.25 MB) Non-trainable params: 30000600 (114.44 MB) _________________________________________________________________
Epoch  1/30 - loss: 1.9925 - accuracy: 0.3237 - val_loss: 1.3148 - val_accuracy: 0.5487 - val_f1: 0.5218 - val_recall: 0.5487 - val_precision: 0.6063 [checkpoint saved]
Epoch  2/30 - loss: 1.1947 - accuracy: 0.6001 - val_loss: 0.9568 - val_accuracy: 0.7075 - val_f1: 0.7001 - val_recall: 0.7075 - val_precision: 0.7104 [checkpoint saved]
Epoch  3/30 - loss: 0.9418 - accuracy: 0.6962 - val_loss: 0.8615 - val_accuracy: 0.7255 - val_f1: 0.7184 - val_recall: 0.7255 - val_precision: 0.7409 [checkpoint saved]
Epoch  4/30 - loss: 0.8189 - accuracy: 0.7301 - val_loss: 0.7762 - val_accuracy: 0.7511 - val_f1: 0.7486 - val_recall: 0.7511 - val_precision: 0.7694 [checkpoint saved]
Epoch  5/30 - loss: 0.7321 - accuracy: 0.7586 - val_loss: 0.7425 - val_accuracy: 0.7700 - val_f1: 0.7704 - val_recall: 0.7700 - val_precision: 0.7833 [checkpoint saved]
Epoch  6/30 - loss: 0.6830 - accuracy: 0.7693 - val_loss: 0.7044 - val_accuracy: 0.7785 - val_f1: 0.7793 - val_recall: 0.7785 - val_precision: 0.7923 [checkpoint saved]
Epoch  7/30 - loss: 0.6185 - accuracy: 0.7933 - val_loss: 0.6838 - val_accuracy: 0.7862 - val_f1: 0.7850 - val_recall: 0.7862 - val_precision: 0.8082 [checkpoint saved]
Epoch  8/30 - loss: 0.5771 - accuracy: 0.8092 - val_loss: 0.6231 - val_accuracy: 0.8106 - val_f1: 0.8110 - val_recall: 0.8106 - val_precision: 0.8171 [checkpoint saved]
Epoch  9/30 - loss: 0.5386 - accuracy: 0.8260 - val_loss: 0.6250 - val_accuracy: 0.8012 - val_f1: 0.7967 - val_recall: 0.8012 - val_precision: 0.8045 [no improvement]
Epoch 10/30 - loss: 0.4968 - accuracy: 0.8341 - val_loss: 0.5832 - val_accuracy: 0.8194 - val_f1: 0.8211 - val_recall: 0.8194 - val_precision: 0.8271 [checkpoint saved]
Epoch 11/30 - loss: 0.4499 - accuracy: 0.8516 - val_loss: 0.6115 - val_accuracy: 0.8177 - val_f1: 0.8150 - val_recall: 0.8177 - val_precision: 0.8236 [no improvement]
Epoch 12/30 - loss: 0.4310 - accuracy: 0.8564 - val_loss: 0.5865 - val_accuracy: 0.8295 - val_f1: 0.8275 - val_recall: 0.8295 - val_precision: 0.8321 [checkpoint saved]
Epoch 13/30 - loss: 0.4049 - accuracy: 0.8621 - val_loss: 0.5805 - val_accuracy: 0.8292 - val_f1: 0.8275 - val_recall: 0.8292 - val_precision: 0.8311 [no improvement]
Epoch 14/30 - loss: 0.3655 - accuracy: 0.8752 - val_loss: 0.5747 - val_accuracy: 0.8359 - val_f1: 0.8353 - val_recall: 0.8359 - val_precision: 0.8406 [checkpoint saved]
Epoch 15/30 - loss: 0.3433 - accuracy: 0.8848 - val_loss: 0.5721 - val_accuracy: 0.8300 - val_f1: 0.8300 - val_recall: 0.8300 - val_precision: 0.8365 [no improvement]
Epoch 16/30 - loss: 0.3155 - accuracy: 0.8963 - val_loss: 0.5820 - val_accuracy: 0.8324 - val_f1: 0.8325 - val_recall: 0.8324 - val_precision: 0.8385 [no improvement]
Epoch 17/30 - loss: 0.2877 - accuracy: 0.9076 - val_loss: 0.5864 - val_accuracy: 0.8371 - val_f1: 0.8358 - val_recall: 0.8371 - val_precision: 0.8420 [checkpoint saved]
Epoch 18/30 - loss: 0.2893 - accuracy: 0.9016 - val_loss: 0.5857 - val_accuracy: 0.8386 - val_f1: 0.8376 - val_recall: 0.8386 - val_precision: 0.8421 [checkpoint saved]
Epoch 19/30 - loss: 0.2678 - accuracy: 0.9115 - val_loss: 0.5446 - val_accuracy: 0.8498 - val_f1: 0.8505 - val_recall: 0.8498 - val_precision: 0.8536 [checkpoint saved]
Epoch 20/30 - loss: 0.2422 - accuracy: 0.9202 - val_loss: 0.5782 - val_accuracy: 0.8454 - val_f1: 0.8458 - val_recall: 0.8454 - val_precision: 0.8510 [no improvement]
Epoch 21/30 - loss: 0.2212 - accuracy: 0.9236 - val_loss: 0.6202 - val_accuracy: 0.8427 - val_f1: 0.8422 - val_recall: 0.8427 - val_precision: 0.8456 [no improvement]
Epoch 22/30 - loss: 0.2198 - accuracy: 0.9252 - val_loss: 0.6041 - val_accuracy: 0.8471 - val_f1: 0.8469 - val_recall: 0.8471 - val_precision: 0.8534 [no improvement]
Epoch 23/30 - loss: 0.2055 - accuracy: 0.9342 - val_loss: 0.5751 - val_accuracy: 0.8439 - val_f1: 0.8422 - val_recall: 0.8439 - val_precision: 0.8470 [no improvement]
Epoch 24/30 - loss: 0.1860 - accuracy: 0.9380 - val_loss: 0.5997 - val_accuracy: 0.8495 - val_f1: 0.8493 - val_recall: 0.8495 - val_precision: 0.8549 [no improvement]
Epoch 25/30 - loss: 0.1528 - accuracy: 0.9482 - val_loss: 0.6142 - val_accuracy: 0.8486 - val_f1: 0.8497 - val_recall: 0.8486 - val_precision: 0.8559 [no improvement]
Epoch 26/30 - loss: 0.1583 - accuracy: 0.9472 - val_loss: 0.6040 - val_accuracy: 0.8510 - val_f1: 0.8516 - val_recall: 0.8510 - val_precision: 0.8548 [checkpoint saved]
Epoch 27/30 - loss: 0.1467 - accuracy: 0.9500 - val_loss: 0.6020 - val_accuracy: 0.8551 - val_f1: 0.8553 - val_recall: 0.8551 - val_precision: 0.8584 [checkpoint saved]
Epoch 28/30 - loss: 0.1445 - accuracy: 0.9510 - val_loss: 0.6639 - val_accuracy: 0.8445 - val_f1: 0.8441 - val_recall: 0.8445 - val_precision: 0.8521 [no improvement]
Epoch 29/30 - loss: 0.1379 - accuracy: 0.9523 - val_loss: 0.6117 - val_accuracy: 0.8486 - val_f1: 0.8491 - val_recall: 0.8486 - val_precision: 0.8529 [no improvement]
Epoch 30/30 - loss: 0.1228 - accuracy: 0.9601 - val_loss: 0.6177 - val_accuracy: 0.8577 - val_f1: 0.8578 - val_recall: 0.8577 - val_precision: 0.8613 [checkpoint saved]
%matplotlib inline
import matplotlib.pyplot as plt
# summarize history for accuracy
plt.plot(history2.history['accuracy'])
plt.plot(history2.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history2.history['loss'])
plt.plot(history2.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper right')
plt.show()
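# For the summary table below, the attention model can be evaluated the same way
# as the BiGRU above (sketch; mirrors the earlier evaluation cell):
model2.load_weights('/content/checkpoints/BiLSTM_attn.hdf5')
predictions = np.argmax(model2.predict(val_data), -1)
print(f'Validation Accuracy: {accuracy_score(y_val, predictions)*100:.2f}%')
predictions = np.argmax(model2.predict(test_data), -1)
print(f'Test Accuracy: {accuracy_score(y_test, predictions)*100:.2f}%')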
Model Name | Val Accuracy | Test Accuracy |
---|---|---|
Logistic Regression + TF-IDF | 83.74% | 76.83% |
MLP + TF-IDF | 86.95% | 77.10% |
MLP + Word2Vec Centroids | 79.73% | 70.61% |
############################ | ##### | ##### |
RNN custom embeddings | 78.76% | 67.70% |
RNN Word2Vec | 72.43% | 67.10% |
RNN Word2Vec + tuning (?) | 79.68% | 71.20% |
RNN custom embeddings + self-attention | 83.33% | 71.70% |
RNN Word2Vec + self-attention | 85.77% | 79.00% |