**Flax** is a high-level framework built on top of **JAX** that makes developing neural networks easier and faster, so researchers and developers can run more experiments in less time. We have already covered how to create neural networks using **Flax** in an earlier tutorial. Please feel free to check it if you are looking for fully connected networks.

As a part of this tutorial, we'll explain how to create convolutional neural networks (CNNs) using **Flax** and how to use convolution layers to build simple CNNs. **Flax** used to ship its own sub-module for optimizers, but it has been deprecated in favor of the **optax** package, which implements the majority of commonly used optimizers. The **Flax** team recommends using optimizers from the **optax** library.

The tutorial assumes that the reader has a background in **JAX** and is familiar with neural network terms like optimization, loss function, activations, etc. If you want to learn about **JAX** or want to create neural networks using the high-level API of **JAX**, please feel free to check the links below. They'll help you with this tutorial as well.

- JAX - (Numpy + Automatic Gradients) on Accelerators (GPUs/TPUs)
- Guide to Create Neural Networks using High-level JAX API
- JAX: Guide to Create CNNs

Below, we have highlighted important sections of the tutorial to give an overview of the material covered.

- Simple Convolutional Neural Network (CNN)
  - Load Dataset
  - Create Convolutional Neural Network (CNN)
  - Define Loss Function
  - Train CNN (SGD)
    - Make Predictions
    - Evaluate Model Performance
  - Train CNN (Adam Optimizer)
    - Make Predictions
    - Evaluate Model Performance
- Channels First vs Channels Last

```
import jax
print("JAX Version : {}".format(jax.__version__))
```

```
import jax.numpy as jnp
```

```
import flax
print("Flax Version : {}".format(flax.__version__))
```

```
import optax
print("Optax Version : {}".format(optax.__version__))
```

In this section, we'll explain how to create a simple CNN using convolution layers to solve a classification task. We'll be using the Fashion MNIST dataset available from keras.

In this section, we have loaded the **Fashion MNIST** dataset available from keras. The dataset has 60k train images and 10k test images of 10 different types of fashion items. It is already divided into train and test sets when we load it from keras. After loading it, we convert the datasets from numpy arrays to **JAX** arrays, as all **Flax** models work on **JAX** arrays. Then, we have reshaped the datasets to add one extra dimension at the end. This is the channel dimension required by convolution layers, and it is the dimension that the convolution layers transform. Color (RGB) images already have a channel dimension with 3 channels. Our dataset has grayscale images with only one channel, so we added the extra dimension to represent that channel. Finally, we have divided the datasets by the float value 255 to bring all values into the range **[0,1]**. By default, the arrays have values in the range **[0,255]**.

```
from tensorflow import keras

# Fashion MNIST ships pre-split into train and test sets.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

# Convert NumPy arrays to JAX arrays.
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

# Add a trailing channel dimension and scale pixel values to [0, 1].
X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)
X_train, X_test = X_train/255.0, X_test/255.0

classes = jnp.unique(Y_train)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```

In this section, we have defined our CNN. We have created it as a class that extends the **linen.Module** class of **Flax**. The layers are available through the **flax.linen** module. We have implemented two methods in our **CNN** class.

- **setup()** - In this method, we have initialized the layers of our CNN.
- **\_\_call\_\_()** - In this method, we have performed a forward pass through the data using the layers defined in the **setup()** method. This method returns predictions.

Our CNN uses two convolution layers. The first layer has **32** output channels and a kernel size of **(3,3)**. The second convolution layer has **16** output channels and a kernel size of **(3,3)**. Both convolution layers have padding specified as **'SAME'**, which (with the default stride of 1) keeps the height and width of the image unchanged after the convolution kernels are applied. We apply the convolution layers to the input data one by one in our **\_\_call\_\_()** method. We have applied the **ReLU** activation to the output of each convolution layer. The output of the second convolution layer is then flattened and passed through a dense layer with one output unit per class, followed by **softmax**.

Our input data has shape **(n_samples,28,28,1)**. The first convolution layer will transform the shape from **(n_samples,28,28,1)** to **(n_samples,28,28,32)**. The second convolution layer will transform the shape from **(n_samples,28,28,32)** to **(n_samples,28,28,16)**. Then the flatten operation will transform the shape from **(n_samples,28,28,16)** to **(n_samples,28 x 28 x 16) = (n_samples,12544)**. The linear layer will transform the shape from **(n_samples,12544)** to **(n_samples,10)**, which will be our output shape. Later on, we'll include logic to determine the predicted class per sample from these 10 values by taking the class with the highest probability.
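
The shape walkthrough above can be checked with a few lines of plain Python. This is only a sketch of the shape arithmetic, not of the convolution itself. The helpers `conv_same_shape` and `flatten_shape` are hypothetical names introduced for illustration, and the sketch assumes input in (n_samples, height, width, channels) layout with `'SAME'` padding.

```python
import math

def conv_same_shape(shape, features, strides=(1, 1)):
    """Output shape of a 2D convolution with 'SAME' padding on NHWC input."""
    n, h, w, _ = shape
    # With 'SAME' padding each spatial size becomes ceil(size / stride).
    return (n, math.ceil(h / strides[0]), math.ceil(w / strides[1]), features)

def flatten_shape(shape):
    """Collapse every dimension except the batch dimension."""
    return (shape[0], math.prod(shape[1:]))

n_samples = 5
input_shape = (n_samples, 28, 28, 1)
after_conv1 = conv_same_shape(input_shape, features=32)
after_conv2 = conv_same_shape(after_conv1, features=16)
after_flatten = flatten_shape(after_conv2)
after_dense = (after_flatten[0], 10)  # dense layer with 10 output units

print(after_conv1, after_conv2, after_flatten, after_dense)
```

With a stride of 1, only the channel dimension changes across the convolution layers, which is why the flattened size is 28 x 28 x 16.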

After defining CNN, we have initialized it in the next cell. The initialized model has two important methods.

- **init(seed, sample_input_data)** - This method takes a **PRNG** seed and sample data as input to initialize model weights. It returns model weights as a dictionary-like object.
- **apply(params, input_data)** - This method performs a forward pass through the network on the given input data using the given parameters.

We have initialized the weights of our CNN by calling **init()** method. We have then printed the shape of weights of various layers as well for information purposes.

Then, in the next cell, we have performed a forward pass through CNN using **apply()** method for verification purposes. We have given a few data samples as input to make predictions.

```
from flax import linen
from jax import random

class CNN(linen.Module):
    def setup(self):
        # Layers are created once here and reused in __call__().
        self.conv1 = linen.Conv(features=32, kernel_size=(3,3), padding="SAME", name="CONV1")
        self.conv2 = linen.Conv(features=16, kernel_size=(3,3), padding="SAME", name="CONV2")
        self.linear1 = linen.Dense(len(classes), name="DENSE")

    def __call__(self, inputs):
        x = linen.relu(self.conv1(inputs))
        x = linen.relu(self.conv2(x))

        x = x.reshape((x.shape[0], -1))  # Flatten all but the batch dimension
        x = self.linear1(x)

        return linen.softmax(x)
```

```
seed = jax.random.PRNGKey(0)

model = CNN()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
```

```
preds = model.apply(params, X_train[:5])
preds
```

In this section, we have defined the loss function for our multi-class classification task. We'll be using cross entropy loss. The function takes model parameters, input data, and actual target values as input. It makes predictions using the **apply()** method by giving model parameters and input data to it. We then convert the actual target values to one-hot encoded values and take the log of the model predictions. At last, we multiply the one-hot encoded targets with the log of predictions element-wise and return the negative sum of all values of the resulting array.

```
def CrossEntropyLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds)
```
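
As a small worked example of the same computation, the sketch below applies the one-hot and log steps to two hand-written probability rows instead of model output. The values are made up for illustration. Only the log probability of the true class survives the multiplication, so the loss equals -(log 0.7 + log 0.8).

```python
import math
import jax
import jax.numpy as jnp

# Made-up probabilities for 2 samples and 3 classes (each row sums to 1).
preds = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1]])
actual = jnp.array([0, 1])  # true class of each sample

one_hot_actual = jax.nn.one_hot(actual, num_classes=3)

# The one-hot mask zeroes out every log probability except the true class.
loss = -jnp.sum(one_hot_actual * jnp.log(preds))

expected = -(math.log(0.7) + math.log(0.8))
print(float(loss), expected)
```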

In this section, we are training our CNN. We'll be using a gradient descent optimizer available from **optax** library. We have designed a small function that has logic to train our CNN. We'll be calling this function to perform training in the next cell.

The function takes data features, actual target values, number of epochs, model parameters, optimizer state object, and batch size as input. It then runs the training loop for the given number of epochs.

For each epoch, it calculates the start and end indexes of batches of data. It then loops through the data in batches using these batch indexes. For each batch, it performs a forward pass through the network to make predictions and calculates the loss using these predictions and the actual target values. Then, it calculates the gradients of the loss with respect to model parameters. It does all these steps using the **value_and_grad()** function of **JAX**. This function takes as input any function that operates on **JAX** arrays and returns another function. We can call the returned function with the same arguments as the original function. It returns two values: the first is the value of the original function and the second is the gradient of that value with respect to the first argument of the original function.
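
A minimal sketch of **value_and_grad()** on a toy function may make this concrete. The function `squared_sum` and its inputs are made up for illustration. The gradient is taken with respect to the first argument `w` only.

```python
import jax.numpy as jnp
from jax import value_and_grad

def squared_sum(w, x):
    """Toy function: sum of (w * x)^2 over all elements."""
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Returns the function value and the gradient w.r.t. the first argument (w).
value, grads = value_and_grad(squared_sum)(w, x)

# value = 3^2 + 8^2 = 73, gradient = 2 * w * x^2 = [18, 64]
print(value, grads)
```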

In our case, we have given our loss function to **value_and_grad()**. When called with input values, it returns the loss value and the gradients of the loss with respect to model parameters. Then, we call the **update()** method on the optimizer object using the gradients and the optimizer state. It returns the updates to be made to model parameters and a new optimizer state. We update model weights using the **apply_updates()** function of **optax** by giving model weights and updates to it. It returns updated model parameters/weights. We also record the loss of each batch and print the mean loss at the end of every epoch. Once all epochs are completed, we return the final updated weights from the function.

```
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, weights, optimizer_state, batch_size=32):
    # Note: relies on the global `optimizer` object created before this function is called.
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch,Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

    return weights
```
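
The batch boundaries can also be computed by stepping through the sample indexes directly. The sketch below uses a hypothetical helper, `batch_bounds`, that is not part of the tutorial code. One difference worth noting: a batch count of `n // batch_size + 1` leaves the last slice empty when the number of samples is an exact multiple of the batch size, while stepping with `range` never produces an empty batch.

```python
def batch_bounds(n_samples, batch_size):
    """Return (start, end) index pairs that cover range(n_samples) in order."""
    return [(start, min(start + batch_size, n_samples))
            for start in range(0, n_samples, batch_size)]

uneven = batch_bounds(10, 4)  # last batch is smaller than the others
even = batch_bounds(8, 4)     # exact multiple: no empty trailing batch

print(uneven)
print(even)
```

Each pair can be used as `X[start:end]` and `Y[start:end]` to slice out one batch.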

In the below cell, we have first initialized necessary things to train our CNN and then called function from the previous cell to train our CNN.

We have set the batch size to **256**, the number of epochs to **15**, and the learning rate to **0.0001**. Then, we have initialized our model and its weights/parameters. After initializing the model, we have initialized the optimizer. We have created an SGD optimizer using **sgd()** available from **optax** by giving a learning rate to it. The optimizer object has two important methods.

- **init(params)** - This method takes model parameters and returns an **OptimizerState** object that holds whatever state the optimizer tracks for those parameters.
- **update(gradients, optimizer_state)** - This method takes gradients and the optimizer state as input. It then returns the updates to be applied to model parameters and the new optimizer state.

We have initialized **OptimizerState** by calling **init()** method of the optimizer object by giving model parameters to it. At last, we have called our function from the previous cell to train our CNN. We have provided all the necessary parameters for the method. We can notice from the loss value getting printed after every epoch that our model seems to be doing a good job.

```
seed = random.PRNGKey(0)
batch_size=256
epochs=15
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
optimizer = optax.sgd(learning_rate=learning_rate) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, epochs, weights, optimizer_state, batch_size=batch_size)
```

In this section, we are making predictions using the trained model on train and test datasets. We have designed a small function that loops through input data in batches and makes predictions. Then, we have combined the predictions of all batches. The function takes final updated model parameters, input data, and batch size as input. It then returns combined predictions.

The predictions from our CNN have 10 values per sample, as discussed earlier. Since we applied the **softmax** activation function to the output of the CNN, these 10 values are probabilities that sum to 1 for each sample. To convert these probabilities to the actual target class, we have retrieved the index of the highest probability per sample, and that index value will be our target class prediction.

```
def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        if X_batch.shape[0] != 0:  # Skip an empty final batch
            preds.append(model.apply(weights, X_batch))

    return preds
```

```
test_preds = MakePredictions(final_weights, X_test, batch_size=256)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(final_weights, X_train, batch_size=256)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
```
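
The conversion from probabilities to class labels can be seen on a tiny example. The probability rows below are made up for illustration and stand in for the softmax output of the CNN.

```python
import jax.numpy as jnp

# Made-up softmax output for 2 samples and 3 classes.
probs = jnp.array([[0.1, 0.7, 0.2],
                   [0.6, 0.3, 0.1]])

row_sums = probs.sum(axis=1)              # each row of probabilities sums to 1
pred_classes = jnp.argmax(probs, axis=1)  # index of the highest probability per row

print(row_sums, pred_classes)
```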

In this section, we have evaluated the performance of our CNN by calculating the accuracy of train and test predictions. We have also calculated a classification report on test predictions which has information like precision, recall, and f1-score per target class. We have calculated accuracy and classification report using functions available through scikit-learn.

If you want to learn about various machine learning metrics calculation functions available through scikit-learn then please feel free to check our tutorial that covers the majority of them in detail.

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
```
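
Accuracy is simply the fraction of predictions that match the actual labels. As a rough sketch of what that means, it can also be computed directly with **JAX** on small made-up arrays, which is handy for a quick check without scikit-learn.

```python
import jax.numpy as jnp

# Made-up labels and predictions for illustration.
actual = jnp.array([0, 1, 1, 2])
preds = jnp.array([0, 1, 2, 2])

# Element-wise comparison gives booleans; their mean is the accuracy.
accuracy = float(jnp.mean(preds == actual))

print("Accuracy : {:.3f}".format(accuracy))
```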

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```

In this section, we have trained our CNN again but this time using **Adam** optimizer. It'll be useful to compare performance with **SGD** optimizer. All other parameter settings are the same as our **SGD** training.

```
seed = random.PRNGKey(0)
batch_size=256
epochs=15
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, epochs, weights, optimizer_state, batch_size=batch_size)
```

In this section, we have made predictions on train and test sets using our CNN trained with **Adam** optimizer.

```
test_preds = MakePredictions(final_weights, X_test, batch_size=256)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(final_weights, X_train, batch_size=256)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
```

In this section, we have evaluated the performance of our CNN by calculating the accuracy of train and test predictions. We have also calculated the classification report for test predictions.

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
```

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```

In our example, we used grayscale images. Grayscale images are generally stored without an explicit channel dimension as there is only one channel. But as the channel dimension is required by the convolution layer, we introduced that dimension in our images. Color (RGB) images generally have 3 channels. We can represent the channel dimension in our multi-dimensional array of images in two different ways.

- **Channels First** - Here, we represent a color image of **(28,28)** pixels as a **(3,28,28)** array.
- **Channels Last** - Here, we represent a color image of **(28,28)** pixels as a **(28,28,3)** array.

By default, the convolution layers available through **Flax** expect the channel dimension to come last (**Channels Last** format). They do not currently handle the **Channels First** format.

Below, we have first shown an example where we created 2 convolution layers and applied them to input data in **Channels Last** format. We have printed the output shape from both layers as well as the shape of their weights to give an idea of how convolution operations are applied.

Then, in the next cell, we have again created two convolution layers and applied them to input data that has the channel dimension at the beginning (**Channels First**). Because the layers treat the last axis as channels, the first layer interprets the width of 28 as 28 input channels, and we can notice from the output shapes how the calculations go wrong with **Channels First** data.

```
conv_layer1 = flax.linen.Conv(16, (3,3))
conv_layer2 = flax.linen.Conv(32, (3,3))
seed = jax.random.PRNGKey(123)
params1 = conv_layer1.init(seed, jax.random.uniform(seed, (50,28,28,1)))
preds1 = conv_layer1.apply(params1, jax.random.uniform(seed,(50,28,28,1)))
params2 = conv_layer2.init(seed, jax.random.uniform(seed, preds1.shape))
preds2 = conv_layer2.apply(params2, jax.random.uniform(seed,preds1.shape))
print("Weights of First Conv Layer : {}".format(params1["params"]["kernel"].shape))
print("Weights of Second Conv Layer : {}".format(params2["params"]["kernel"].shape))
print("\nInput Shape : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```

```
conv_layer1 = flax.linen.Conv(16, (3,3))
conv_layer2 = flax.linen.Conv(32, (3,3))
seed = jax.random.PRNGKey(123)
params1 = conv_layer1.init(seed, jax.random.uniform(seed, (50,1,28,28)))
preds1 = conv_layer1.apply(params1, jax.random.uniform(seed,(50,1,28,28)))
params2 = conv_layer2.init(seed, jax.random.uniform(seed, preds1.shape))
preds2 = conv_layer2.apply(params2, jax.random.uniform(seed,preds1.shape))
print("Weights of First Conv Layer : {}".format(params1["params"]["kernel"].shape))
print("Weights of Second Conv Layer : {}".format(params2["params"]["kernel"].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```
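
If data arrives in **Channels First** layout, one option is to move the channel axis to the end before giving it to the convolution layers. The sketch below does this with `jnp.transpose` on random data. The batch of 50 RGB images is made up for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Made-up batch of 50 RGB images in Channels First layout: (N, C, H, W).
channels_first = jax.random.uniform(key, (50, 3, 28, 28))

# Reorder the axes to Channels Last layout: (N, H, W, C).
channels_last = jnp.transpose(channels_first, (0, 2, 3, 1))

print(channels_first.shape, channels_last.shape)
```

The transpose only reorders the axes, so the pixel at row 5, column 7 of channel 2 is still the same value in both layouts.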

This ends our small tutorial explaining how to design a convolutional neural network (CNN) using the high-level **Flax** framework built on top of **JAX**. We used the **optax** library for optimizers as **Flax** has deprecated its own optimizer module. Please feel free to let us know your views in the comments section.

