XOR Classification with Keras

source: https://dev.to/jbahire/demystifying-the-xor-problem-1blk

Because the four XOR points cannot be separated by a single straight line, XOR is not a linearly separable problem: its decision boundary is nonlinear. A single-layer perceptron therefore cannot solve it, so we need a network with at least one hidden layer.
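To make this concrete, here is a small NumPy sketch (an illustration added here, not taken from the source article). It first brute-forces a grid of weights for a single linear threshold unit and finds none that reproduces XOR. It then uses hand-picked weights for a network with one hidden layer that does: XOR = AND(OR(x1, x2), NAND(x1, x2)). The step activation, the grid range, and the OR/NAND/AND weights are all choices made for this example; the Keras model below learns its own weights with ReLU and sigmoid activations.

```python
import itertools
import numpy as np

# The four XOR inputs and their targets, in the same order as the notebook.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])

def step(z):
    # Heaviside step activation: 1 where z > 0, else 0.
    return (z > 0).astype(int)

# 1) Spot check: search a grid of (w1, w2, b) for a single linear unit
#    step(w1*x1 + w2*x2 + b) that matches XOR. None exists on any grid.
grid = np.linspace(-2, 2, 41)
linear_solutions = [
    (w1, w2, b)
    for w1, w2, b in itertools.product(grid, repeat=3)
    if np.array_equal(step(X @ np.array([w1, w2]) + b), y)
]
print("linear solutions found:", len(linear_solutions))

# 2) One hidden layer is enough. Hidden column 0 acts as OR, column 1 as NAND.
W_hidden = np.array([[1.0, -1.0],
                     [1.0, -1.0]])
b_hidden = np.array([-0.5, 1.5])
# The output unit computes AND of the two hidden units.
w_out = np.array([1.0, 1.0])
b_out = -1.5

hidden = step(X @ W_hidden + b_hidden)
xor_pred = step(hidden @ w_out + b_out)
print("hand-built network:", xor_pred)
```

The grid search only samples weights. The general reason no linear unit works is algebraic: the two "1" cases require w1 + b > 0 and w2 + b > 0. Adding them gives w1 + w2 + 2b > 0. Since the [0, 0] case forces b ≤ 0, this implies w1 + w2 + b > 0, which contradicts the [1, 1] case.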

Check TensorFlow version

In [1]:
import tensorflow as tf
print(tf.__version__)
2.14.0

MLP for XOR Classification

In [2]:
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input, Model


with tf.device('/device:GPU:0'):

  # the four different states of the XOR gate
  training_data = np.array([[0,0],[0,1],[1,0],[1,1]], "float32")

  # the four expected results in the same order
  target_data = np.array([[0],[1],[1],[0]], "float32")

  # ------------------------ Sequential API ------------------------ #
  # model = Sequential()
  # model.add(Dense(32, input_dim=2, activation='relu'))
  # model.add(Dense(64, activation='relu'))
  # model.add(Dense(1, activation='sigmoid'))

  # ------------------------ Functional API ------------------------ #
  inputs = Input(shape=(2,))

  x = Dense(32, activation="relu")(inputs)
  x = Dense(64, activation="relu")(x)
  outputs = Dense(1, activation='sigmoid')(x)

  model = Model(inputs=inputs, outputs=outputs, name="xor_model")


  # ------------- Compile and fit methods are the same ------------- #
  model.compile(loss='mean_squared_error',
                optimizer='adam',
                metrics=['binary_accuracy'])

  start_training_time = time.time()
  model.fit(training_data, target_data, batch_size=4, epochs=1500, verbose=0)
  end_training_time = time.time()

  print(f'\nTraining time: {time.strftime("%H:%M:%S", time.gmtime(end_training_time - start_training_time))} sec\n')

  predictions = model.predict(training_data)

  print('  Data       Preds     True')
  # use `sample` rather than `x`, which already names a layer tensor above
  for sample, pred, y in zip(training_data, predictions, target_data):
    print('{} --> {:.5f}  |  {}'.format(sample, pred[0], int(y[0])))
Training time: 00:00:14 sec

1/1 [==============================] - 0s 72ms/step
  Data       Preds     True
[0. 0.] --> 0.01049  |  0
[0. 1.] --> 0.99346  |  1
[1. 0.] --> 0.99306  |  1
[1. 1.] --> 0.00648  |  0
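As a sanity check on this output, the sketch below converts the printed sigmoid outputs into class labels by thresholding at 0.5, the default threshold of Keras' `binary_accuracy` metric. It also recomputes the mean squared error, the loss the model was compiled with. The prediction values are copied from the table above, which rounds them to five decimals, so the loss is only approximate.

```python
import numpy as np

# Sigmoid outputs copied (rounded) from the printed table, plus the XOR targets.
predictions = np.array([[0.01049], [0.99346], [0.99306], [0.00648]], dtype="float32")
target_data = np.array([[0], [1], [1], [0]], dtype="float32")

# Threshold at 0.5: outputs strictly above it become class 1, the rest class 0.
labels = (predictions > 0.5).astype("float32")
accuracy = float(np.mean(labels == target_data))

# Mean squared error, the loss passed to model.compile above.
mse = float(np.mean((predictions - target_data) ** 2))

print("labels:  ", labels.ravel().astype(int).tolist())
print("accuracy:", accuracy)
print(f"mse:      {mse:.6f}")
```

All four labels match the targets, so the binary accuracy on the training set is 1.0. The small MSE shows that the outputs sit close to 0 and 1 rather than merely on the correct side of the threshold.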