Recurrent Neural Networks (RNNs) and their variants like LSTM and GRU are the most commonly preferred networks for ML tasks that involve some kind of sequence data (time-series, text, speech, etc). RNNs are very good at capturing long sequences, but they take time to train. Recent studies have shown that CNNs consisting of 1D convolution layers are also good at capturing sequences. Hence, we can use a network of 1D convolutions for NLP tasks like text classification, text generation, etc. The main benefit of using CNNs is that they have fewer parameters to train compared to RNNs and hence train faster.
As a part of this tutorial, we have designed CNNs with 1D convolution layers using Flax for a text classification task. Flax is a Python deep learning library built on top of JAX for designing deep neural networks. We have used the AG NEWS dataset, which has text documents for 4 different categories of news. We have tried different approaches to using Conv1D layers for solving the text classification task. For encoding text data before giving it to the convolution layer, we have used the word embeddings approach. After training the networks, we have evaluated their performance by calculating various ML metrics and explained their predictions using the LIME algorithm.
Below, we have listed the important sections of the tutorial to give an overview of the material covered.
Below, we have imported the necessary Python libraries that we have used in our tutorial and printed their versions as well.
import jax
print("JAX Version : {}".format(jax.__version__))
import flax
print("FLAX Version : {}".format(flax.__version__))
import optax
print("OPTAX Version : {}".format(optax.__version__))
import torchtext
print("Torchtext Version : {}".format(torchtext.__version__))
from tensorflow import keras
print("Keras Version : {}".format(keras.__version__))
import warnings
warnings.filterwarnings("ignore")
In this section, we are preparing data to be given to the neural network. As our raw data is text and a neural network works on real-valued data, we need to convert our text data to real-valued data.
In this tutorial, we are going to use the word embeddings approach, where we break text into a list of tokens (words, punctuation marks, etc) and assign a real-valued vector to each token. We'll achieve this in two steps.

We'll complete the first step in this section. The second step will be implemented in the neural network, where we include an Embedding layer as the first layer, which is responsible for mapping the integer index of each token to its respective embedding. These embeddings get updated during the training of the network to better capture the meaning of each token.
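Below is a minimal conceptual sketch of those two steps, using a made-up toy vocabulary and a random embedding matrix rather than the tutorial's data. The network's Embedding layer performs the second step (the lookup) internally.

import numpy as np

vocab = {"<pad>": 0, "how": 1, "are": 2, "you": 3}   ## Step 1: token -> integer index
indexes = np.array([1, 2, 3])                        ## "how are you" as indexes

embedding_matrix = np.random.rand(len(vocab), 4)     ## Shape (vocab_len, embed_len)
print(embedding_matrix[indexes].shape)               ## Step 2: indexes -> (3, 4) embeddings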
In this section, we have loaded the AG NEWS dataset that we are going to use for our case. It has text documents for 4 different news categories (["World", "Sports", "Business", "Sci/Tech"]) which our trained network will try to classify. The dataset is already divided into train and test sets.
import numpy as np
train_dataset, test_dataset = torchtext.datasets.AG_NEWS()
X_train_text, Y_train = [], []
for Y, X in train_dataset:
    X_train_text.append(X)
    Y_train.append(Y)

X_test_text, Y_test = [], []
for Y, X in test_dataset:
    X_test_text.append(X)
    Y_test.append(Y)
unique_classes = list(set(Y_train))
target_classes = ["World", "Sports", "Business", "Sci/Tech"]
## Subtracted 1 from labels to bring range from 1-4 to 0-3
Y_train, Y_test = np.array(Y_train) - 1, np.array(Y_test) - 1
len(X_train_text), len(X_test_text)
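As an optional sanity check, the short snippet below counts examples per label to see how the classes are distributed in the train and test sets (it uses only the arrays created above).

import numpy as np

print(np.unique(Y_train, return_counts=True))
print(np.unique(Y_test, return_counts=True))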
In this section, we are vectorizing our text data.
First, we have initialized a Tokenizer object available from keras. After initializing the tokenizer, we have trained it by calling the fit_on_texts() method on it with train and test text examples. The call to this method will populate the vocabulary inside of the Tokenizer object, which will have all unique tokens (words, punctuation marks, etc). The vocabulary is a simple mapping from a token to a unique integer index. Each token is assigned a unique integer index starting from 1; index 0 is reserved and is used for padding, and tokens not present in the vocabulary are skipped by default (unless an explicit OOV token is configured).
After the vocabulary is populated with tokens, we can call the texts_to_sequences() method on the tokenizer object with a list of text documents. It'll tokenize them and return the indexes of tokens for each text example. We have called the texts_to_sequences() method on train and test text documents and wrapped it with the pad_sequences() method. This method is used to bring the tokens of all text examples to the same length. Different documents can have different numbers of tokens, but we have decided to keep a maximum of 50 tokens per text example. All tokens beyond 50 will be truncated, and examples that have fewer than 50 tokens will be padded with 0s.
After vectorizing text data to indexes, we have also printed the first few examples with their indexes.
Below, we have explained with a simple example how text documents will be mapped to indexes.
text = "Hello, How are you? Where are you planning to go?"
tokens = ['hello', ',', 'how', 'are', 'you', '?', 'where',
          'are', 'you', 'planning', 'to', 'go', '?']

vocab = {
    'hello': 0,
    'bye': 1,
    'how': 2,
    'the': 3,
    'welcome': 4,
    'are': 5,
    'you': 6,
    'to': 7,
    '<unk>': 8,
}

vector = [0, 8, 2, 5, 6, 8, 8, 5, 6, 8, 7, 8, 8]
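For reference, below is a small runnable sketch of the same idea using the Keras Tokenizer. Note that, unlike the simplified vocabulary above, the Keras Tokenizer assigns indexes starting from 1 (reserving 0 for padding) and, with its default filters, strips punctuation marks.

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

toy_texts = ["Hello, How are you? Where are you planning to go?"]

toy_tokenizer = Tokenizer()
toy_tokenizer.fit_on_texts(toy_texts)
print(toy_tokenizer.word_index)                     ## Token -> index mapping (starts at 1)

seqs = toy_tokenizer.texts_to_sequences(toy_texts)  ## List of index lists
print(seqs)
print(pad_sequences(seqs, maxlen=15, padding="post", truncating="post", value=0.))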
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from jax import numpy as jnp
max_tokens = 50 ## Hyperparameter to tune for better performance
tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_train_text+X_test_text)
## Vectorizing data to keep 50 words per sample.
X_train_vect = pad_sequences(tokenizer.texts_to_sequences(X_train_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
X_test_vect = pad_sequences(tokenizer.texts_to_sequences(X_test_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
print(X_train_vect[:3])
X_train_vect, X_test_vect = jnp.array(X_train_vect, dtype=jnp.int32), jnp.array(X_test_vect, dtype=jnp.int32)
Y_train, Y_test = jnp.array(Y_train), jnp.array(Y_test)
X_train_vect.shape, X_test_vect.shape
## What is word 21?
print(tokenizer.index_word[21])

## How many times does it appear in the first text document?
print(X_train_text[0]) ## 2 times
Our first approach tries to perform the text classification task with a simple network that has a single convolution layer. The network consists of an embedding layer, a 1D convolution layer, and a dense layer. Below, we have defined the network, trained it, evaluated its performance, and tried to explain the predictions made by the network as well.
Here, we have defined the network that we'll use for the text classification task in this section. The network consists of 3 layers.
The first layer of the network is the embedding layer. We have created it using the Embed() constructor available from the linen sub-module of Flax. We have asked it to use the vocabulary length as the number of tokens (first parameter) and an embedding length of 128. When we create this layer, it internally creates a weight matrix of shape (vocab_len, embed_len). This matrix has embeddings of length 128 for all tokens of our vocabulary. The embedding layer simply takes a list of token indexes as input and returns the embeddings for those indexes from the weight matrix. The input to the embedding layer is of shape (batch_size, max_tokens) = (batch_size, 50) and the output is of shape (batch_size, max_tokens, embed_len) = (batch_size, 50, 128).

The output of the embedding layer is given to the Conv1D layer. We have defined the Conv1D layer with 32 output channels and a kernel size of 7. We'll be treating the embedding length dimension as the channels dimension in our case. Hence, the input to the convolution layer will be of shape (batch_size, max_tokens, embed_len) = (batch_size, 50, 128) and the output shape will be (batch_size, max_tokens, conv_output_channels) = (batch_size, 50, 32). We have applied relu activation to the output of the convolution layer.

After applying relu activation, we have applied the max() function (we could also try average instead of max) along the max_tokens dimension, which returns an output of shape (batch_size, 32). This output is then given to a Dense layer which has 4 output units (same as the number of target classes). The output of the Dense layer is returned from our network as the prediction.
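To make these shape transformations concrete, below is a small standalone sketch (with a toy vocabulary size of 100 and dummy inputs, not our actual network) that verifies the shapes described above.

from flax import linen
import jax
from jax import numpy as jnp

tokens = jnp.ones((4, 50), dtype=jnp.int32)      ## (batch_size, max_tokens)

embed = linen.Embed(100, 128)                    ## Toy vocabulary of 100 tokens
emb_out = embed.apply(embed.init(jax.random.PRNGKey(0), tokens), tokens)
print(emb_out.shape)                             ## (4, 50, 128)

conv = linen.Conv(32, kernel_size=(7,))
conv_out = linen.relu(conv.apply(conv.init(jax.random.PRNGKey(0), emb_out), emb_out))
print(conv_out.shape)                            ## (4, 50, 32) - default "SAME" padding keeps 50 steps
print(conv_out.max(axis=1).shape)                ## (4, 32) after taking max over the tokens dimension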
After defining the network, we initialized it, printed the shape of weights/biases of layers, and performed a forward pass through it using random data for verification purposes.
If you are someone new to Flax and want to learn how to create a neural network using it, then we recommend that you go through the below tutorials, which cover it in detail. They'll help you better understand Flax.
from flax import linen
embed_len = 128
class Conv1DTextClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, embed_len, name="Word Embeddings")
        self.conv1 = linen.Conv(32, kernel_size=(7,), name="Conv1")
        self.linear1 = linen.Dense(len(unique_classes), name="Dense1")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)    ## (batch_size, max_tokens, embed_len)
        x = linen.relu(self.conv1(x))  ## (batch_size, max_tokens, 32)
        x = x.max(axis=1)              ## Max over tokens dimension -> (batch_size, 32)
        logits = self.linear1(x)       ## (batch_size, n_classes)
        return logits
from jax import numpy as jnp
seed = jax.random.PRNGKey(0)
text_classif = Conv1DTextClassifier()
params = text_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = text_classif.apply(params, X_train_vect[:5])
preds
Below, we have defined the loss function that we'll be using for our task. We'll be using cross entropy loss for our case. The function takes network parameters, input data, and target labels as input. It then makes predictions on the data using the network parameters and one-hot encodes the actual target labels. At last, it calculates the cross entropy loss by giving the predictions and one-hot encoded target labels to the softmax_cross_entropy() function available from the Optax library.
def CrossEntropyLoss(params, input_data, actual):
    logits_preds = model.apply(params, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(unique_classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()
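As a quick sanity check of the underlying Optax call, the snippet below uses dummy logits and a dummy label (unrelated to our data) to show how softmax_cross_entropy() combines predictions with one-hot encoded labels.

import jax
import optax
from jax import numpy as jnp

dummy_logits = jnp.array([[2.0, 0.5, 0.1, 0.1]])              ## One example, 4 classes
dummy_labels = jax.nn.one_hot(jnp.array([0]), num_classes=4)  ## True class is 0
print(optax.softmax_cross_entropy(logits=dummy_logits, labels=dummy_labels))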
In this section, we are training the neural network that we defined earlier. We have defined a simple training function that will perform training for us. The function takes train data (X, Y), validation data (X_val, Y_val), the number of epochs, network parameters, optimizer state, and batch size as input. It then executes the training loop for the given number of epochs. For each epoch, it loops through the training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss, calculates gradients, and updates network parameters using the gradients. It also records the loss for each batch and prints the average loss of all batches at the end of the epoch. The function also calculates validation accuracy at the end of each epoch.
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score
def TrainModelInBatches(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices
        losses = [] ## Record loss of each batch

        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)

            ## Update Network Parameters
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            params = optax.apply_updates(params, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(params, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation Accuracy : {:.3f}".format(val_acc))

    return params
Below, we have actually trained our network using the training routine defined in the previous cell. First, we have initialized the batch size to 1024, the number of epochs to 10, and the learning rate to 0.001. Then, we have initialized our text classification network and the Adam optimizer. At last, we have called our training routine with the necessary parameters to perform training. We can notice from the loss and accuracy values getting printed after each epoch that our model is doing a good job at the text classification task.
from jax import random
seed = random.PRNGKey(0)
batch_size=1024
epochs=10
learning_rate = jnp.array(1e-3)
model = Conv1DTextClassifier()
params = model.init(seed, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of our trained network by calculating the accuracy score, classification report, and confusion matrix metrics on test predictions. We can notice from the accuracy score that the network is doing quite a good job at the classification task.
We have calculated ML metrics using functions available from scikit-learn. If you want to learn about various ML metrics available from sklearn then please feel free to check the below link.
We have also created a plot of the confusion matrix using the Python library scikit-plot. We can notice from the plot that our model is doing a better job at classifying text documents of the categories Sports and World compared to the categories Business and Sci/Tech.
If you are new to the library scikit-plot and want to learn about the various ML metric plots it provides, then please do check the below link.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt
skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
normalize=True,
title="Confusion Matrix",
cmap="Blues",
hide_zeros=True,
figsize=(5,5)
);
plt.xticks(rotation=90);
In this section, we have tried to explain the predictions made by our network using the LIME algorithm. We'll be using an implementation of the algorithm available through the lime Python library. It lets us generate a visualization showing which words from our text example contributed to predicting a particular target label.
Please feel free to check the below links if you are new to LIME and want to learn about it in-depth.
In order to explain predictions using LIME, we first need to create an instance of LimeTextExplainer which we have done below.
from lime import lime_text
explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)
Now, we have randomly selected a text example from the test dataset and made predictions on it using our trained network. Our network correctly predicts the target label as Sci/Tech for the selected text example.
Apart from this, we have also defined the prediction function in the below cell. The function takes a list of text examples as input and returns their prediction probabilities using our model. The function tokenizes the data and then gives it to the network to make predictions. At last, softmax activation is applied to the output of the network to convert it to probabilities, which are returned from the function.
import numpy as np
def make_predictions(X_batch_text):
    X_batch = pad_sequences(tokenizer.texts_to_sequences(X_batch_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
    logits = model.apply(final_weights, jnp.array(X_batch))
    preds = linen.softmax(logits)
    return preds.to_py()
rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))
print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual : ", target_classes[Y_test[idx]])
Below, we have first generated an Explanation object by calling the explain_instance() method on the LimeTextExplainer object. We have given the text example, the prediction function, and the target label to the method. This explanation object has details about the words contributing to the prediction.

Then, we have called the show_in_notebook() method on the Explanation object to generate a visualization showing words contributing to predicting the target label as Sci/Tech.
We can notice from the visualization that words like 'RFID', 'privacy', 'frequency', 'identification', etc are used for predicting the target label as Sci/Tech which makes sense as they are commonly used in the field.
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()
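If you prefer to inspect the explanation programmatically rather than visually, the Explanation object also exposes an as_list() method that returns (word, weight) pairs for the explained label, as sketched below.

## (word, weight) pairs for the label explained above
print(explanation.as_list(label=int(Y_test[idx])))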
Our approach in this section uses multiple convolution layers. The majority of the code in this section is the same as the code from the previous section, except for the definition of the neural network, which uses two 1D convolution layers this time instead of one.
In this section, we have defined the network that we'll use for the text classification task. The network consists of one embedding layer, two convolution layers, and one dense layer. We have again used an embedding length of 128. Both convolution layers have 32 output channels and a kernel size of 7. The output of the embedding layer is given to the first convolution layer. The relu activation is applied to the output of the first convolution layer and then a max-pooling operation is performed on it. The output of max-pooling is given to the second convolution layer. Then, we have applied relu activation to the output of the second convolution layer, taken the max along the tokens dimension, and given the result to the dense layer. The output of the dense layer is returned as the prediction of the network.
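One detail worth noting, illustrated by the small standalone sketch below: with Flax's default stride of 1 and "VALID" padding, linen.max_pool() with a window of 5 acts as a sliding-window max that only shortens the sequence from 50 to 46 steps, rather than downsampling it by a factor of 5.

from flax import linen
from jax import numpy as jnp

x = jnp.ones((4, 50, 32))                          ## (batch_size, max_tokens, channels)
print(linen.max_pool(x, window_shape=(5,)).shape)  ## (4, 46, 32)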
After defining the network, we initialized it, printed the shape of weights/biases of layers, and performed a forward pass through it using random data for verification purposes.
from flax import linen
embed_len = 128 ## Hyperparameter
class Conv1DTextClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, embed_len, name="Word Embeddings")
        self.conv1 = linen.Conv(32, kernel_size=(7,), name="Conv1")
        self.conv2 = linen.Conv(32, kernel_size=(7,), name="Conv2")
        self.linear1 = linen.Dense(len(unique_classes), name="Dense1")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)               ## (batch_size, max_tokens, embed_len)
        x = linen.relu(self.conv1(x))             ## (batch_size, max_tokens, 32)
        x = linen.max_pool(x, window_shape=(5,))  ## Sliding max over tokens dimension
        x = linen.relu(self.conv2(x))
        x = x.max(axis=1)                         ## Max over tokens dimension
        logits = self.linear1(x)
        return logits
from jax import numpy as jnp
seed = jax.random.PRNGKey(0)
text_classif = Conv1DTextClassifier()
params = text_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = text_classif.apply(params, X_train_vect[:5])
preds
Here, we have trained our network using exactly the same settings that we have used in the previous section. We can notice from the loss and accuracy values getting printed after each epoch that our model is doing a good job at the text classification task.
from jax import random
seed = random.PRNGKey(0)
batch_size=1024
epochs=10
learning_rate = jnp.array(1e-3)
model = Conv1DTextClassifier()
params = model.init(seed, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of our trained network by calculating accuracy score, classification report and confusion matrix metrics on test predictions. We can notice from the test accuracy that it's a little low compared to the previous approach which is surprising as we had expected that trying more convolution layers might improve accuracy. We have also plotted the confusion matrix for reference.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt
skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
normalize=True,
title="Confusion Matrix",
cmap="Blues",
hide_zeros=True,
figsize=(5,5)
);
plt.xticks(rotation=90);
In this section, we have again tried to explain predictions made by our network using LIME algorithm. Our network correctly predicts the target label as Sci/Tech for a selected random test example. From the visualization created using LIME, we can notice that words like 'RFID', 'technology', 'frequency', 'identification', 'privacy', etc are contributing to predicting the target label as Sci/Tech.
from lime import lime_text
explainer = lime_text.LimeTextExplainer(class_names=target_classes)
rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))
print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual : ", target_classes[Y_test[idx]])
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()
Below, we have listed a summary of various approaches we tried above.
Approach | Max Tokens | Embedding Length | Conv Output Channels | Test Accuracy (%) |
---|---|---|---|---|
Single Conv1D Layer Network | 50 | 128 | 32 | 91.2 |
Multiple Conv1D Layers Network | 50 | 128 | 32,32 | 89.9 |
This ends our small tutorial explaining how we can design CNNs with 1D convolution layers using the Flax (JAX) framework for solving text classification tasks. Please feel free to let us know your views in the comments section.
If you are more comfortable learning through video tutorials then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to have doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in the direction where you can find a solution to your problem.
You can even send us a mail if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.
If you want to