RNNs for Text Classification with Keras¶

Word2Vec Embeddings¶

In [ ]:
import gensim.downloader as api
word2vec = api.load('word2vec-google-news-300')
[==================================================] 100.0% 1662.8/1662.8MB downloaded
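Later cells use tokenizer outputs (train_data, val_data) and an embedding_matrix that are built earlier in the notebook. A minimal sketch of that preprocessing, assuming a Keras Tokenizer capped at MAX_WORDS (the names here are assumptions, not the notebook's exact code):

import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

MAX_WORDS = 100000
MAX_SEQUENCE_LENGTH = 250
EMBEDDING_DIM = 300

# Fit a vocabulary on the training texts and pad/truncate to a fixed length.
# pad_sequences pads with 0 by default, which matches mask_zero=True below.
tokenizer = Tokenizer(num_words=MAX_WORDS, oov_token='<UNK>')
tokenizer.fit_on_texts(twenty_train.data)
train_data = pad_sequences(tokenizer.texts_to_sequences(twenty_train.data),
                           maxlen=MAX_SEQUENCE_LENGTH)

# Rows of the embedding layer: index 0 is padding, words found in Word2Vec
# get their pretrained vector, everything else stays zero.
embedding_matrix = np.zeros((MAX_WORDS + 2, EMBEDDING_DIM))
for word, i in tokenizer.word_index.items():
    if i < MAX_WORDS + 2 and word in word2vec:
        embedding_matrix[i] = word2vec[word]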

Create and train a BiGRU + MLP model (custom end-to-end embeddings, randomly initialized, or pre-trained Word2Vec embeddings)¶

In [ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Bidirectional, GRU, Embedding
from tensorflow.keras.optimizers import Adam

BATCH_SIZE = 256
EPOCHS = 30
GRU_SIZE = 64
DENSE = 32

MAX_WORDS = 100000
EMBEDDING_DIM = 300
MAX_SEQUENCE_LENGTH = 250

# create an empty sequential model
model = Sequential()

# Add an embedding layer
# Option 1: randomly initialized custom embedding layer, trained end-to-end with the model
# model.add(Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM,
#                     input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=True))
# Option 2 (used here): Word2Vec-initialized embedding layer, frozen (not trainable)
model.add(Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM, weights=[embedding_matrix],
                    input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=False))
# Could we train the Word2Vec embeddings further? What would it take to do this successfully?
# Hint: learning rate
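# One possible recipe (a sketch, not executed here): set trainable=True in the
# Embedding above and compile with a much smaller learning rate, e.g.
# Adam(learning_rate=1e-4), so early gradients from the randomly initialized
# layers only nudge the pretrained vectors instead of overwriting them.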

# Add a bidirectional GRU layer with 0.33 variational (recurrent) dropout
model.add(Bidirectional(GRU(GRU_SIZE, return_sequences=False, recurrent_dropout=0.33)))
# return_sequences=False: return only the last output of the sequence, not the full sequence

# Add a hidden MLP layer
model.add(Dropout(0.33))
model.add(Dense(DENSE, activation='relu'))

# Add the output MLP layer
model.add(Dropout(0.33))
model.add(Dense(len(twenty_train.target_names), activation='softmax'))
# Multi-class classification -> use softmax over all possible classes

# model.build((None, MAX_SEQUENCE_LENGTH))  # not needed: the Embedding layer fixes the input shape

model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(learning_rate=0.002),  # alternatively 0.001
              metrics=["accuracy"])

# Save model weights after each epoch with ModelCheckpoint
# (to save to Google Drive instead, use '/content/gdrive/My Drive/checkpoints'
#  here and the corresponding path in ModelCheckpoint below)
if not os.path.exists('/content/checkpoints'):
  os.makedirs('/content/checkpoints')


checkpoint = ModelCheckpoint('/content/checkpoints/BiGRUMLP.hdf5',
                              monitor='val_accuracy',
                              mode='max', verbose=2,
                              save_best_only=True,
                              save_weights_only=True)

history = model.fit(train_data,
                    y_train_1_hot,
                    validation_data=(val_data, y_val_1_hot),
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    shuffle=True,
                    callbacks=[Metrics(valid_data=(val_data, y_val_1_hot)),
                               checkpoint])
WARNING:tensorflow:Layer gru will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding_2 (Embedding)     (None, 250, 300)          30000600  
                                                                 
 bidirectional_2 (Bidirecti  (None, 128)               140544    
 onal)                                                           
                                                                 
 dropout_4 (Dropout)         (None, 128)               0         
                                                                 
 dense_4 (Dense)             (None, 32)                4128      
                                                                 
 dropout_5 (Dropout)         (None, 32)                0         
                                                                 
 dense_5 (Dense)             (None, 20)                660       
                                                                 
=================================================================
Total params: 30145932 (115.00 MB)
Trainable params: 145332 (567.70 KB)
Non-trainable params: 30000600 (114.44 MB)
_________________________________________________________________
Epoch 1/30
107/107 [==============================] - 18s 162ms/step
 — val_f1: 0.084604 — val_precision: 0.170068 — val_recall: 0.121060

Epoch 1: val_accuracy improved from -inf to 0.12106, saving model to /content/checkpoints/BiGRUMLP.hdf5
/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
31/31 [==============================] - 86s 3s/step - loss: 2.9734 - accuracy: 0.0744 - val_loss: 2.9132 - val_accuracy: 0.1211 - val_f1: 0.0846 - val_recall: 0.1211 - val_precision: 0.1701
Epoch 2/30
107/107 [==============================] - 17s 156ms/step
 — val_f1: 0.201904 — val_precision: 0.296786 — val_recall: 0.228571

Epoch 2: val_accuracy improved from 0.12106 to 0.22857, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 84s 3s/step - loss: 2.7441 - accuracy: 0.1327 - val_loss: 2.4504 - val_accuracy: 0.2286 - val_f1: 0.2019 - val_recall: 0.2286 - val_precision: 0.2968
Epoch 3/30
107/107 [==============================] - 16s 154ms/step
 — val_f1: 0.269124 — val_precision: 0.353126 — val_recall: 0.313697

Epoch 3: val_accuracy improved from 0.22857 to 0.31370, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 83s 3s/step - loss: 2.3429 - accuracy: 0.2199 - val_loss: 2.0197 - val_accuracy: 0.3137 - val_f1: 0.2691 - val_recall: 0.3137 - val_precision: 0.3531
Epoch 4/30
107/107 [==============================] - 16s 149ms/step
 — val_f1: 0.392721 — val_precision: 0.461777 — val_recall: 0.420619

Epoch 4: val_accuracy improved from 0.31370 to 0.42062, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 81s 3s/step - loss: 2.0135 - accuracy: 0.3185 - val_loss: 1.7367 - val_accuracy: 0.4206 - val_f1: 0.3927 - val_recall: 0.4206 - val_precision: 0.4618
Epoch 5/30
107/107 [==============================] - 16s 153ms/step
 — val_f1: 0.460487 — val_precision: 0.494728 — val_recall: 0.491311

Epoch 5: val_accuracy improved from 0.42062 to 0.49131, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 77s 3s/step - loss: 1.7313 - accuracy: 0.3855 - val_loss: 1.4422 - val_accuracy: 0.4913 - val_f1: 0.4605 - val_recall: 0.4913 - val_precision: 0.4947
Epoch 6/30
107/107 [==============================] - 16s 147ms/step
 — val_f1: 0.527518 — val_precision: 0.550191 — val_recall: 0.539028

Epoch 6: val_accuracy improved from 0.49131 to 0.53903, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 77s 3s/step - loss: 1.5178 - accuracy: 0.4528 - val_loss: 1.3347 - val_accuracy: 0.5390 - val_f1: 0.5275 - val_recall: 0.5390 - val_precision: 0.5502
Epoch 7/30
107/107 [==============================] - 16s 150ms/step
 — val_f1: 0.556103 — val_precision: 0.576500 — val_recall: 0.577909

Epoch 7: val_accuracy improved from 0.53903 to 0.57791, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 79s 3s/step - loss: 1.3964 - accuracy: 0.4936 - val_loss: 1.2127 - val_accuracy: 0.5779 - val_f1: 0.5561 - val_recall: 0.5779 - val_precision: 0.5765
Epoch 8/30
107/107 [==============================] - 17s 161ms/step
 — val_f1: 0.602474 — val_precision: 0.612069 — val_recall: 0.616200

Epoch 8: val_accuracy improved from 0.57791 to 0.61620, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 79s 3s/step - loss: 1.2914 - accuracy: 0.5324 - val_loss: 1.1069 - val_accuracy: 0.6162 - val_f1: 0.6025 - val_recall: 0.6162 - val_precision: 0.6121
Epoch 9/30
107/107 [==============================] - 17s 159ms/step
 — val_f1: 0.583099 — val_precision: 0.616101 — val_recall: 0.607364

Epoch 9: val_accuracy did not improve from 0.61620
31/31 [==============================] - 79s 3s/step - loss: 1.1917 - accuracy: 0.5738 - val_loss: 1.1118 - val_accuracy: 0.6074 - val_f1: 0.5831 - val_recall: 0.6074 - val_precision: 0.6161
Epoch 10/30
107/107 [==============================] - 16s 151ms/step
 — val_f1: 0.637699 — val_precision: 0.649250 — val_recall: 0.643594

Epoch 10: val_accuracy improved from 0.61620 to 0.64359, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 79s 3s/step - loss: 1.1319 - accuracy: 0.5987 - val_loss: 1.0278 - val_accuracy: 0.6436 - val_f1: 0.6377 - val_recall: 0.6436 - val_precision: 0.6493
Epoch 11/30
107/107 [==============================] - 17s 156ms/step
 — val_f1: 0.677581 — val_precision: 0.682764 — val_recall: 0.682474

Epoch 11: val_accuracy improved from 0.64359 to 0.68247, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 75s 2s/step - loss: 1.0398 - accuracy: 0.6266 - val_loss: 0.9757 - val_accuracy: 0.6825 - val_f1: 0.6776 - val_recall: 0.6825 - val_precision: 0.6828
Epoch 12/30
107/107 [==============================] - 17s 159ms/step
 — val_f1: 0.670787 — val_precision: 0.676946 — val_recall: 0.681296

Epoch 12: val_accuracy did not improve from 0.68247
31/31 [==============================] - 78s 3s/step - loss: 0.9923 - accuracy: 0.6431 - val_loss: 0.9432 - val_accuracy: 0.6813 - val_f1: 0.6708 - val_recall: 0.6813 - val_precision: 0.6769
Epoch 13/30
107/107 [==============================] - 16s 149ms/step
 — val_f1: 0.680558 — val_precision: 0.695104 — val_recall: 0.689249

Epoch 13: val_accuracy improved from 0.68247 to 0.68925, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 80s 3s/step - loss: 0.9340 - accuracy: 0.6662 - val_loss: 0.9289 - val_accuracy: 0.6892 - val_f1: 0.6806 - val_recall: 0.6892 - val_precision: 0.6951
Epoch 14/30
107/107 [==============================] - 16s 150ms/step
 — val_f1: 0.723236 — val_precision: 0.734991 — val_recall: 0.727835

Epoch 14: val_accuracy improved from 0.68925 to 0.72784, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 74s 2s/step - loss: 0.8663 - accuracy: 0.6949 - val_loss: 0.8569 - val_accuracy: 0.7278 - val_f1: 0.7232 - val_recall: 0.7278 - val_precision: 0.7350
Epoch 15/30
107/107 [==============================] - 16s 151ms/step
 — val_f1: 0.719180 — val_precision: 0.728528 — val_recall: 0.726951

Epoch 15: val_accuracy did not improve from 0.72784
31/31 [==============================] - 79s 3s/step - loss: 0.8322 - accuracy: 0.7104 - val_loss: 0.8276 - val_accuracy: 0.7270 - val_f1: 0.7192 - val_recall: 0.7270 - val_precision: 0.7285
Epoch 16/30
107/107 [==============================] - 17s 154ms/step
 — val_f1: 0.735873 — val_precision: 0.743401 — val_recall: 0.744624

Epoch 16: val_accuracy improved from 0.72784 to 0.74462, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 77s 3s/step - loss: 0.7928 - accuracy: 0.7218 - val_loss: 0.8120 - val_accuracy: 0.7446 - val_f1: 0.7359 - val_recall: 0.7446 - val_precision: 0.7434
Epoch 17/30
107/107 [==============================] - 16s 148ms/step
 — val_f1: 0.707631 — val_precision: 0.709445 — val_recall: 0.719293

Epoch 17: val_accuracy did not improve from 0.74462
31/31 [==============================] - 80s 3s/step - loss: 0.9062 - accuracy: 0.7000 - val_loss: 0.8624 - val_accuracy: 0.7193 - val_f1: 0.7076 - val_recall: 0.7193 - val_precision: 0.7094
Epoch 18/30
107/107 [==============================] - 16s 150ms/step
 — val_f1: 0.729827 — val_precision: 0.743000 — val_recall: 0.728130

Epoch 18: val_accuracy did not improve from 0.74462
31/31 [==============================] - 72s 2s/step - loss: 0.8237 - accuracy: 0.7190 - val_loss: 0.8516 - val_accuracy: 0.7281 - val_f1: 0.7298 - val_recall: 0.7281 - val_precision: 0.7430
Epoch 19/30
107/107 [==============================] - 16s 151ms/step
 — val_f1: 0.750211 — val_precision: 0.757480 — val_recall: 0.756112

Epoch 19: val_accuracy improved from 0.74462 to 0.75611, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 77s 3s/step - loss: 0.7466 - accuracy: 0.7430 - val_loss: 0.7737 - val_accuracy: 0.7561 - val_f1: 0.7502 - val_recall: 0.7561 - val_precision: 0.7575
Epoch 20/30
107/107 [==============================] - 16s 147ms/step
 — val_f1: 0.760419 — val_precision: 0.767420 — val_recall: 0.762592

Epoch 20: val_accuracy improved from 0.75611 to 0.76259, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 77s 3s/step - loss: 0.6677 - accuracy: 0.7726 - val_loss: 0.7880 - val_accuracy: 0.7626 - val_f1: 0.7604 - val_recall: 0.7626 - val_precision: 0.7674
Epoch 21/30
107/107 [==============================] - 16s 147ms/step
 — val_f1: 0.765542 — val_precision: 0.769955 — val_recall: 0.766127

Epoch 21: val_accuracy improved from 0.76259 to 0.76613, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 79s 3s/step - loss: 0.6396 - accuracy: 0.7793 - val_loss: 0.7768 - val_accuracy: 0.7661 - val_f1: 0.7655 - val_recall: 0.7661 - val_precision: 0.7700
Epoch 22/30
107/107 [==============================] - 16s 149ms/step
 — val_f1: 0.764978 — val_precision: 0.766824 — val_recall: 0.768778

Epoch 22: val_accuracy improved from 0.76613 to 0.76878, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 73s 2s/step - loss: 0.6121 - accuracy: 0.7962 - val_loss: 0.7495 - val_accuracy: 0.7688 - val_f1: 0.7650 - val_recall: 0.7688 - val_precision: 0.7668
Epoch 23/30
107/107 [==============================] - 16s 146ms/step
 — val_f1: 0.772570 — val_precision: 0.779288 — val_recall: 0.773785

Epoch 23: val_accuracy improved from 0.76878 to 0.77378, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 74s 2s/step - loss: 0.6190 - accuracy: 0.7924 - val_loss: 0.7745 - val_accuracy: 0.7738 - val_f1: 0.7726 - val_recall: 0.7738 - val_precision: 0.7793
Epoch 24/30
107/107 [==============================] - 17s 155ms/step
 — val_f1: 0.771834 — val_precision: 0.778802 — val_recall: 0.772312

Epoch 24: val_accuracy did not improve from 0.77378
31/31 [==============================] - 77s 3s/step - loss: 0.5728 - accuracy: 0.8058 - val_loss: 0.7703 - val_accuracy: 0.7723 - val_f1: 0.7718 - val_recall: 0.7723 - val_precision: 0.7788
Epoch 25/30
107/107 [==============================] - 16s 147ms/step
 — val_f1: 0.778215 — val_precision: 0.782727 — val_recall: 0.781443

Epoch 25: val_accuracy improved from 0.77378 to 0.78144, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 73s 2s/step - loss: 0.5385 - accuracy: 0.8177 - val_loss: 0.7427 - val_accuracy: 0.7814 - val_f1: 0.7782 - val_recall: 0.7814 - val_precision: 0.7827
Epoch 26/30
107/107 [==============================] - 16s 147ms/step
 — val_f1: 0.786600 — val_precision: 0.788118 — val_recall: 0.787629

Epoch 26: val_accuracy improved from 0.78144 to 0.78763, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 78s 3s/step - loss: 0.4997 - accuracy: 0.8339 - val_loss: 0.7299 - val_accuracy: 0.7876 - val_f1: 0.7866 - val_recall: 0.7876 - val_precision: 0.7881
Epoch 27/30
107/107 [==============================] - 16s 149ms/step
 — val_f1: 0.786458 — val_precision: 0.787760 — val_recall: 0.786745

Epoch 27: val_accuracy did not improve from 0.78763
31/31 [==============================] - 72s 2s/step - loss: 0.4757 - accuracy: 0.8392 - val_loss: 0.7559 - val_accuracy: 0.7867 - val_f1: 0.7865 - val_recall: 0.7867 - val_precision: 0.7878
Epoch 28/30
107/107 [==============================] - 16s 148ms/step
 — val_f1: 0.781311 — val_precision: 0.784836 — val_recall: 0.782032

Epoch 28: val_accuracy did not improve from 0.78763
31/31 [==============================] - 78s 3s/step - loss: 0.4602 - accuracy: 0.8506 - val_loss: 0.7549 - val_accuracy: 0.7820 - val_f1: 0.7813 - val_recall: 0.7820 - val_precision: 0.7848
Epoch 29/30
107/107 [==============================] - 16s 150ms/step
 — val_f1: 0.793178 — val_precision: 0.796680 — val_recall: 0.794698

Epoch 29: val_accuracy improved from 0.78763 to 0.79470, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 78s 3s/step - loss: 0.4433 - accuracy: 0.8493 - val_loss: 0.7585 - val_accuracy: 0.7947 - val_f1: 0.7932 - val_recall: 0.7947 - val_precision: 0.7967
Epoch 30/30
107/107 [==============================] - 16s 146ms/step
 — val_f1: 0.795004 — val_precision: 0.800514 — val_recall: 0.796760

Epoch 30: val_accuracy improved from 0.79470 to 0.79676, saving model to /content/checkpoints/BiGRUMLP.hdf5
31/31 [==============================] - 77s 3s/step - loss: 0.4220 - accuracy: 0.8568 - val_loss: 0.7976 - val_accuracy: 0.7968 - val_f1: 0.7950 - val_recall: 0.7968 - val_precision: 0.8005
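The Metrics callback passed to fit above is defined earlier in the notebook. A minimal sketch of such a callback, assuming weighted averaging (consistent with val_recall matching val_accuracy in the logs):

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.keras.callbacks import Callback

class Metrics(Callback):
    # Sketch: compute validation F1/precision/recall at the end of each epoch
    # and inject them into `logs` so ModelCheckpoint and History can see them.
    def __init__(self, valid_data):
        super().__init__()
        self.x_val, self.y_val_1_hot = valid_data

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        y_pred = np.argmax(self.model.predict(self.x_val), axis=-1)
        y_true = np.argmax(self.y_val_1_hot, axis=-1)
        logs['val_f1'] = f1_score(y_true, y_pred, average='weighted')
        logs['val_precision'] = precision_score(y_true, y_pred, average='weighted')
        logs['val_recall'] = recall_score(y_true, y_pred, average='weighted')
        print(f" — val_f1: {logs['val_f1']:.6f} — val_precision: {logs['val_precision']:.6f}"
              f" — val_recall: {logs['val_recall']:.6f}")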

Visualize Model's Training History¶

In [ ]:
%matplotlib inline
import matplotlib.pyplot as plt

# summarize history for accuracy
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper left')
plt.show()

# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper right')
plt.show()

Evaluate performance of BiGRU + MLP model on dev data¶

In [ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Bidirectional, GRU, Embedding
from sklearn.metrics import classification_report


GRU_SIZE = 64
DENSE = 32

with tf.device('/device:GPU:0'):

  model = Sequential()

  model.add(Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM, weights=[embedding_matrix],
                      input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=False))
  model.add(Bidirectional(GRU(GRU_SIZE, return_sequences=False, recurrent_dropout=0.33)))
  model.add(Dense(DENSE, activation='relu'))
  model.add(Dense(len(twenty_train.target_names), activation='softmax'))
  # (The Dropout layers are omitted here; they have no weights, so the checkpoint still loads.)

  # Load weights from the pre-trained model
  model.load_weights("/content/checkpoints/BiGRUMLP.hdf5")
  # model.load_weights("/content/gdrive/My Drive/checkpoints/BiGRUMLP.hdf5")

  predictions = np.argmax(model.predict(val_data), -1)
  print(classification_report(y_val, predictions, target_names=twenty_train.target_names))
WARNING:tensorflow:Layer gru_1 will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
107/107 [==============================] - 16s 148ms/step
                          precision    recall  f1-score   support

             alt.atheism       0.80      0.77      0.79       160
           comp.graphics       0.59      0.76      0.66       165
 comp.os.ms-windows.misc       0.74      0.77      0.76       189
comp.sys.ibm.pc.hardware       0.66      0.46      0.54       168
   comp.sys.mac.hardware       0.64      0.62      0.63       182
          comp.windows.x       0.87      0.71      0.78       168
            misc.forsale       0.79      0.71      0.75       182
               rec.autos       0.87      0.90      0.88       181
         rec.motorcycles       0.88      0.85      0.86       184
      rec.sport.baseball       0.90      0.90      0.90       169
        rec.sport.hockey       0.91      0.92      0.92       175
               sci.crypt       0.94      0.92      0.93       177
         sci.electronics       0.67      0.82      0.74       173
                 sci.med       0.89      0.95      0.92       181
               sci.space       0.88      0.92      0.90       181
  soc.religion.christian       0.74      0.86      0.80       177
      talk.politics.guns       0.88      0.88      0.88       177
   talk.politics.mideast       0.96      0.79      0.87       170
      talk.politics.misc       0.68      0.84      0.75       135
      talk.religion.misc       0.58      0.45      0.50       101

                accuracy                           0.80      3395
               macro avg       0.79      0.79      0.79      3395
            weighted avg       0.80      0.80      0.80      3395
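The report shows the comp.* groups and talk.religion.misc are the hardest classes; a quick confusion-matrix check (a sketch, reusing the predictions computed above) shows where the misclassifications go:

import numpy as np
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_val, predictions)
off_diag = cm - np.diag(np.diag(cm))          # zero out the correct predictions
i, j = np.unravel_index(off_diag.argmax(), off_diag.shape)
print(f'Most frequent confusion: {twenty_train.target_names[i]} '
      f'-> {twenty_train.target_names[j]} ({off_diag[i, j]} documents)')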

In [ ]:
from sklearn.metrics import accuracy_score
predictions = np.argmax(model.predict(val_data), -1)
print(f'Validation Accuracy: {accuracy_score(y_val, predictions)*100:.2f}%')

predictions = np.argmax(model.predict(test_data), -1)
print(f'Test Accuracy: {accuracy_score(y_test, predictions)*100:.2f}%')
107/107 [==============================] - 16s 149ms/step
Validation Accuracy: 79.68%
32/32 [==============================] - 5s 166ms/step
Test Accuracy: 71.20%

Create and train a BiLSTM + deep self-attention + MLP model¶

Architecture:
  • RNN (BiLSTM) layers
  • Attention mechanism
  • MLP layer
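The cell below pools the BiLSTM outputs with a custom DeepAttention layer (a LinearAttention variant is also referenced); both are defined earlier in the notebook. A minimal sketch of what such a layer could look like, consistent with the 361,201 parameters reported in the model summary (a tanh projection, 600·600 + 600 weights, plus a scalar scoring layer, 600 + 1):

import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

class DeepAttention(Layer):
    # Sketch of deep self-attention pooling: a tanh MLP scores each timestep,
    # a softmax over time turns the scores into attention weights, and the
    # output is the attention-weighted sum of the timestep encodings.
    def __init__(self, return_attention=False, **kwargs):
        super().__init__(**kwargs)
        self.return_attention = return_attention
        self.supports_masking = True

    def build(self, input_shape):
        units = int(input_shape[-1])
        self.proj = Dense(units, activation='tanh')  # per-timestep hidden layer
        self.score = Dense(1)                        # scalar energy per timestep
        super().build(input_shape)

    def call(self, inputs, mask=None):
        energies = self.score(self.proj(inputs))               # (batch, time, 1)
        if mask is not None:                                   # ignore padded steps
            energies += (1.0 - tf.cast(mask[:, :, None], energies.dtype)) * -1e9
        attn = tf.nn.softmax(energies, axis=1)                 # weights over time
        context = tf.reduce_sum(attn * inputs, axis=1)         # (batch, features)
        return [context, attn] if self.return_attention else context

    def compute_mask(self, inputs, mask=None):
        return None  # the pooled vector no longer carries a time mask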
In [ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dense, Dropout, Bidirectional, LSTM, Embedding, Input
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K
from tensorflow.keras import Model

# !pip install keras-self-attention
# from keras_self_attention import SeqSelfAttention

LSTM_SIZE = 300
DENSE = 1000

with tf.device('/device:GPU:0'):

  inputs = Input((MAX_SEQUENCE_LENGTH,))

  # Define the Embedding layer, initialized with the pre-trained Word2Vec weights
  embeddings = Embedding(input_dim=MAX_WORDS+2, output_dim=EMBEDDING_DIM, weights=[embedding_matrix],
                         input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=False)(inputs)
  drop_emb = Dropout(0.33)(embeddings)

  # Define a bidirectional RNN with LSTM cells
  bilstm = Bidirectional(LSTM(units=LSTM_SIZE, return_sequences=True, recurrent_dropout=0.33))(drop_emb)
  drop_encodings = Dropout(0.33)(bilstm)

  # Pass the encoding through an Attention Layer
  x, attn = DeepAttention(return_attention=True)(drop_encodings)
  # x, attn = LinearAttention(return_attention=True)(drop_encodings)


  # Alternatively use keras package for self-attention
  #x, attn = SeqSelfAttention(return_attention=True)(drop_encodings)

  # Apply Dropout to the encoding produced by the attention mechanism
  drop_x = Dropout(0.33)(x)

  # Pass through a Dense Layer
  hidden = Dense(units=DENSE, activation="relu")(drop_x)

  # Apply Dropout to the output of the Dense Layer
  drop_out = Dropout(0.33)(hidden)

  # Final Dense layer with softmax activation to produce a probability distribution over the classes
  out = Dense(units=len(twenty_train.target_names), activation="softmax")(drop_out)

  # Wrap everything into a Model (Functional API)
  model2 = Model(inputs=inputs, outputs=out)
  model2.summary()

  model2.compile(loss='categorical_crossentropy',
                 optimizer=Adam(learning_rate=0.001),
                 metrics=["accuracy"])

  if not os.path.exists('/content/checkpoints'):
    os.makedirs('/content/checkpoints')

  checkpoint = ModelCheckpoint('/content/checkpoints/BiLSTM_attn.hdf5',
                               monitor='val_accuracy',
                               mode='max', verbose=2,
                               save_best_only=True,
                               save_weights_only=True)

  history2 = model2.fit(train_data, y_train_1_hot,
                        validation_data=(val_data, y_val_1_hot),
                        batch_size=128,
                        epochs=30,
                        shuffle=True,
                        callbacks=[Metrics(valid_data=(val_data, y_val_1_hot)),
                                   checkpoint])
WARNING:tensorflow:Layer lstm will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
/usr/local/lib/python3.10/dist-packages/keras/src/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
  warnings.warn(
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 250)]             0         
                                                                 
 embedding (Embedding)       (None, 250, 300)          30000600  
                                                                 
 dropout (Dropout)           (None, 250, 300)          0         
                                                                 
 bidirectional (Bidirection  (None, 250, 600)          1442400   
 al)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 250, 600)          0         
                                                                 
 deep_attention (DeepAttent  [(None, 600),             361201    
 ion)                         (None, 250, 1)]                    
                                                                 
 dropout_2 (Dropout)         (None, 600)               0         
                                                                 
 dense (Dense)               (None, 1000)              601000    
                                                                 
 dropout_3 (Dropout)         (None, 1000)              0         
                                                                 
 dense_1 (Dense)             (None, 20)                20020     
                                                                 
=================================================================
Total params: 32425221 (123.69 MB)
Trainable params: 2424621 (9.25 MB)
Non-trainable params: 30000600 (114.44 MB)
_________________________________________________________________
Epoch 1/30
107/107 [==============================] - 22s 198ms/step
 — val_f1: 0.521783 — val_precision: 0.606270 — val_recall: 0.548748

Epoch 1: val_accuracy improved from -inf to 0.54875, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 173s 3s/step - loss: 1.9925 - accuracy: 0.3237 - val_loss: 1.3148 - val_accuracy: 0.5487 - val_f1: 0.5218 - val_recall: 0.5487 - val_precision: 0.6063
Epoch 2/30
107/107 [==============================] - 20s 191ms/step
 — val_f1: 0.700132 — val_precision: 0.710411 — val_recall: 0.707511

Epoch 2: val_accuracy improved from 0.54875 to 0.70751, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 175s 3s/step - loss: 1.1947 - accuracy: 0.6001 - val_loss: 0.9568 - val_accuracy: 0.7075 - val_f1: 0.7001 - val_recall: 0.7075 - val_precision: 0.7104
Epoch 3/30
107/107 [==============================] - 20s 187ms/step
 — val_f1: 0.718428 — val_precision: 0.740916 — val_recall: 0.725479

Epoch 3: val_accuracy improved from 0.70751 to 0.72548, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 158s 3s/step - loss: 0.9418 - accuracy: 0.6962 - val_loss: 0.8615 - val_accuracy: 0.7255 - val_f1: 0.7184 - val_recall: 0.7255 - val_precision: 0.7409
Epoch 4/30
107/107 [==============================] - 21s 199ms/step
 — val_f1: 0.748557 — val_precision: 0.769423 — val_recall: 0.751105

Epoch 4: val_accuracy improved from 0.72548 to 0.75110, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 178s 3s/step - loss: 0.8189 - accuracy: 0.7301 - val_loss: 0.7762 - val_accuracy: 0.7511 - val_f1: 0.7486 - val_recall: 0.7511 - val_precision: 0.7694
Epoch 5/30
107/107 [==============================] - 21s 193ms/step
 — val_f1: 0.770434 — val_precision: 0.783276 — val_recall: 0.769956

Epoch 5: val_accuracy improved from 0.75110 to 0.76996, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 160s 3s/step - loss: 0.7321 - accuracy: 0.7586 - val_loss: 0.7425 - val_accuracy: 0.7700 - val_f1: 0.7704 - val_recall: 0.7700 - val_precision: 0.7833
Epoch 6/30
107/107 [==============================] - 21s 199ms/step
 — val_f1: 0.779297 — val_precision: 0.792297 — val_recall: 0.778498

Epoch 6: val_accuracy improved from 0.76996 to 0.77850, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 179s 3s/step - loss: 0.6830 - accuracy: 0.7693 - val_loss: 0.7044 - val_accuracy: 0.7785 - val_f1: 0.7793 - val_recall: 0.7785 - val_precision: 0.7923
Epoch 7/30
107/107 [==============================] - 21s 195ms/step
 — val_f1: 0.785049 — val_precision: 0.808186 — val_recall: 0.786156

Epoch 7: val_accuracy improved from 0.77850 to 0.78616, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 178s 3s/step - loss: 0.6185 - accuracy: 0.7933 - val_loss: 0.6838 - val_accuracy: 0.7862 - val_f1: 0.7850 - val_recall: 0.7862 - val_precision: 0.8082
Epoch 8/30
107/107 [==============================] - 22s 202ms/step
 — val_f1: 0.811010 — val_precision: 0.817126 — val_recall: 0.810604

Epoch 8: val_accuracy improved from 0.78616 to 0.81060, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 159s 3s/step - loss: 0.5771 - accuracy: 0.8092 - val_loss: 0.6231 - val_accuracy: 0.8106 - val_f1: 0.8110 - val_recall: 0.8106 - val_precision: 0.8171
Epoch 9/30
107/107 [==============================] - 20s 189ms/step
 — val_f1: 0.796695 — val_precision: 0.804522 — val_recall: 0.801178

Epoch 9: val_accuracy did not improve from 0.81060
62/62 [==============================] - 158s 3s/step - loss: 0.5386 - accuracy: 0.8260 - val_loss: 0.6250 - val_accuracy: 0.8012 - val_f1: 0.7967 - val_recall: 0.8012 - val_precision: 0.8045
Epoch 10/30
107/107 [==============================] - 21s 197ms/step
 — val_f1: 0.821066 — val_precision: 0.827142 — val_recall: 0.819440

Epoch 10: val_accuracy improved from 0.81060 to 0.81944, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 157s 3s/step - loss: 0.4968 - accuracy: 0.8341 - val_loss: 0.5832 - val_accuracy: 0.8194 - val_f1: 0.8211 - val_recall: 0.8194 - val_precision: 0.8271
Epoch 11/30
107/107 [==============================] - 21s 197ms/step
 — val_f1: 0.814973 — val_precision: 0.823588 — val_recall: 0.817673

Epoch 11: val_accuracy did not improve from 0.81944
62/62 [==============================] - 155s 3s/step - loss: 0.4499 - accuracy: 0.8516 - val_loss: 0.6115 - val_accuracy: 0.8177 - val_f1: 0.8150 - val_recall: 0.8177 - val_precision: 0.8236
Epoch 12/30
107/107 [==============================] - 20s 187ms/step
 — val_f1: 0.827455 — val_precision: 0.832135 — val_recall: 0.829455

Epoch 12: val_accuracy improved from 0.81944 to 0.82946, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 159s 3s/step - loss: 0.4310 - accuracy: 0.8564 - val_loss: 0.5865 - val_accuracy: 0.8295 - val_f1: 0.8275 - val_recall: 0.8295 - val_precision: 0.8321
Epoch 13/30
107/107 [==============================] - 21s 195ms/step
 — val_f1: 0.827546 — val_precision: 0.831140 — val_recall: 0.829161

Epoch 13: val_accuracy did not improve from 0.82946
62/62 [==============================] - 156s 3s/step - loss: 0.4049 - accuracy: 0.8621 - val_loss: 0.5805 - val_accuracy: 0.8292 - val_f1: 0.8275 - val_recall: 0.8292 - val_precision: 0.8311
Epoch 14/30
107/107 [==============================] - 21s 195ms/step
 — val_f1: 0.835334 — val_precision: 0.840626 — val_recall: 0.835935

Epoch 14: val_accuracy improved from 0.82946 to 0.83594, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 154s 2s/step - loss: 0.3655 - accuracy: 0.8752 - val_loss: 0.5747 - val_accuracy: 0.8359 - val_f1: 0.8353 - val_recall: 0.8359 - val_precision: 0.8406
Epoch 15/30
107/107 [==============================] - 21s 192ms/step
 — val_f1: 0.830022 — val_precision: 0.836516 — val_recall: 0.830044

Epoch 15: val_accuracy did not improve from 0.83594
62/62 [==============================] - 181s 3s/step - loss: 0.3433 - accuracy: 0.8848 - val_loss: 0.5721 - val_accuracy: 0.8300 - val_f1: 0.8300 - val_recall: 0.8300 - val_precision: 0.8365
Epoch 16/30
107/107 [==============================] - 20s 186ms/step
 — val_f1: 0.832501 — val_precision: 0.838497 — val_recall: 0.832401

Epoch 16: val_accuracy did not improve from 0.83594
62/62 [==============================] - 155s 3s/step - loss: 0.3155 - accuracy: 0.8963 - val_loss: 0.5820 - val_accuracy: 0.8324 - val_f1: 0.8325 - val_recall: 0.8324 - val_precision: 0.8385
Epoch 17/30
107/107 [==============================] - 21s 192ms/step
 — val_f1: 0.835755 — val_precision: 0.841986 — val_recall: 0.837113

Epoch 17: val_accuracy improved from 0.83594 to 0.83711, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 177s 3s/step - loss: 0.2877 - accuracy: 0.9076 - val_loss: 0.5864 - val_accuracy: 0.8371 - val_f1: 0.8358 - val_recall: 0.8371 - val_precision: 0.8420
Epoch 18/30
107/107 [==============================] - 21s 197ms/step
 — val_f1: 0.837591 — val_precision: 0.842126 — val_recall: 0.838586

Epoch 18: val_accuracy improved from 0.83711 to 0.83859, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 155s 3s/step - loss: 0.2893 - accuracy: 0.9016 - val_loss: 0.5857 - val_accuracy: 0.8386 - val_f1: 0.8376 - val_recall: 0.8386 - val_precision: 0.8421
Epoch 19/30
107/107 [==============================] - 20s 184ms/step
 — val_f1: 0.850468 — val_precision: 0.853581 — val_recall: 0.849779

Epoch 19: val_accuracy improved from 0.83859 to 0.84978, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 157s 3s/step - loss: 0.2678 - accuracy: 0.9115 - val_loss: 0.5446 - val_accuracy: 0.8498 - val_f1: 0.8505 - val_recall: 0.8498 - val_precision: 0.8536
Epoch 20/30
107/107 [==============================] - 21s 195ms/step
 — val_f1: 0.845820 — val_precision: 0.850999 — val_recall: 0.845361

Epoch 20: val_accuracy did not improve from 0.84978
62/62 [==============================] - 156s 3s/step - loss: 0.2422 - accuracy: 0.9202 - val_loss: 0.5782 - val_accuracy: 0.8454 - val_f1: 0.8458 - val_recall: 0.8454 - val_precision: 0.8510
Epoch 21/30
107/107 [==============================] - 21s 197ms/step
 — val_f1: 0.842152 — val_precision: 0.845579 — val_recall: 0.842710

Epoch 21: val_accuracy did not improve from 0.84978
62/62 [==============================] - 176s 3s/step - loss: 0.2212 - accuracy: 0.9236 - val_loss: 0.6202 - val_accuracy: 0.8427 - val_f1: 0.8422 - val_recall: 0.8427 - val_precision: 0.8456
Epoch 22/30
107/107 [==============================] - 22s 202ms/step
 — val_f1: 0.846921 — val_precision: 0.853436 — val_recall: 0.847128

Epoch 22: val_accuracy did not improve from 0.84978
62/62 [==============================] - 175s 3s/step - loss: 0.2198 - accuracy: 0.9252 - val_loss: 0.6041 - val_accuracy: 0.8471 - val_f1: 0.8469 - val_recall: 0.8471 - val_precision: 0.8534
Epoch 23/30
107/107 [==============================] - 21s 196ms/step
 — val_f1: 0.842242 — val_precision: 0.847028 — val_recall: 0.843888

Epoch 23: val_accuracy did not improve from 0.84978
62/62 [==============================] - 174s 3s/step - loss: 0.2055 - accuracy: 0.9342 - val_loss: 0.5751 - val_accuracy: 0.8439 - val_f1: 0.8422 - val_recall: 0.8439 - val_precision: 0.8470
Epoch 24/30
107/107 [==============================] - 21s 195ms/step
 — val_f1: 0.849301 — val_precision: 0.854896 — val_recall: 0.849485

Epoch 24: val_accuracy did not improve from 0.84978
62/62 [==============================] - 163s 3s/step - loss: 0.1860 - accuracy: 0.9380 - val_loss: 0.5997 - val_accuracy: 0.8495 - val_f1: 0.8493 - val_recall: 0.8495 - val_precision: 0.8549
Epoch 25/30
107/107 [==============================] - 22s 206ms/step
 — val_f1: 0.849716 — val_precision: 0.855865 — val_recall: 0.848601

Epoch 25: val_accuracy did not improve from 0.84978
62/62 [==============================] - 185s 3s/step - loss: 0.1528 - accuracy: 0.9482 - val_loss: 0.6142 - val_accuracy: 0.8486 - val_f1: 0.8497 - val_recall: 0.8486 - val_precision: 0.8559
Epoch 26/30
107/107 [==============================] - 22s 203ms/step
 — val_f1: 0.851611 — val_precision: 0.854770 — val_recall: 0.850957

Epoch 26: val_accuracy improved from 0.84978 to 0.85096, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 181s 3s/step - loss: 0.1583 - accuracy: 0.9472 - val_loss: 0.6040 - val_accuracy: 0.8510 - val_f1: 0.8516 - val_recall: 0.8510 - val_precision: 0.8548
Epoch 27/30
107/107 [==============================] - 22s 203ms/step
 — val_f1: 0.855303 — val_precision: 0.858370 — val_recall: 0.855081

Epoch 27: val_accuracy improved from 0.85096 to 0.85508, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 180s 3s/step - loss: 0.1467 - accuracy: 0.9500 - val_loss: 0.6020 - val_accuracy: 0.8551 - val_f1: 0.8553 - val_recall: 0.8551 - val_precision: 0.8584
Epoch 28/30
107/107 [==============================] - 21s 198ms/step
 — val_f1: 0.844082 — val_precision: 0.852072 — val_recall: 0.844477

Epoch 28: val_accuracy did not improve from 0.85508
62/62 [==============================] - 179s 3s/step - loss: 0.1445 - accuracy: 0.9510 - val_loss: 0.6639 - val_accuracy: 0.8445 - val_f1: 0.8441 - val_recall: 0.8445 - val_precision: 0.8521
Epoch 29/30
107/107 [==============================] - 21s 200ms/step
 — val_f1: 0.849131 — val_precision: 0.852903 — val_recall: 0.848601

Epoch 29: val_accuracy did not improve from 0.85508
62/62 [==============================] - 178s 3s/step - loss: 0.1379 - accuracy: 0.9523 - val_loss: 0.6117 - val_accuracy: 0.8486 - val_f1: 0.8491 - val_recall: 0.8486 - val_precision: 0.8529
Epoch 30/30
107/107 [==============================] - 21s 199ms/step
 — val_f1: 0.857764 — val_precision: 0.861298 — val_recall: 0.857732

Epoch 30: val_accuracy improved from 0.85508 to 0.85773, saving model to /content/checkpoints/BiLSTM_attn.hdf5
62/62 [==============================] - 158s 3s/step - loss: 0.1228 - accuracy: 0.9601 - val_loss: 0.6177 - val_accuracy: 0.8577 - val_f1: 0.8578 - val_recall: 0.8577 - val_precision: 0.8613

Visualize Model's Training History¶

In [ ]:
%matplotlib inline
import matplotlib.pyplot as plt

# summarize history for accuracy
plt.plot(history2.history['accuracy'])
plt.plot(history2.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper left')
plt.show()

# summarize history for loss
plt.plot(history2.history['loss'])
plt.plot(history2.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'dev'], loc='upper right')
plt.show()
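The attention model's entries in the table below can be reproduced by reloading its best checkpoint and scoring the dev and test sets, analogous to the BiGRU evaluation above (a sketch):

import numpy as np
from sklearn.metrics import accuracy_score

# Reload the best weights saved by ModelCheckpoint during training
model2.load_weights('/content/checkpoints/BiLSTM_attn.hdf5')

predictions = np.argmax(model2.predict(val_data), -1)
print(f'Validation Accuracy: {accuracy_score(y_val, predictions)*100:.2f}%')

predictions = np.argmax(model2.predict(test_data), -1)
print(f'Test Accuracy: {accuracy_score(y_test, predictions)*100:.2f}%')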

Final Results on 20 Newsgroups¶

Model Name                               Val Accuracy   Test Accuracy
---------------------------------------  -------------  -------------
Logistic Regression + TF-IDF             83.74%         76.83%
MLP + TF-IDF                             86.95%         77.10%
MLP + Word2Vec Centroids                 79.73%         70.61%

RNN custom embeddings                    78.76%         67.70%
RNN Word2Vec                             72.43%         67.10%
RNN Word2Vec + tuning (?)                79.68%         71.20%
RNN custom embeddings + self-attention   83.33%         71.70%
RNN Word2Vec + self-attention            85.77%         79.00%

Resources¶

  • https://github.com/keras-team/keras/tree/master/examples
  • https://keras.io/