Updated On : Jan-03,2022 Time Investment : ~30 mins

Flax: Framework to Create Neural Networks using JAX

Flax is a deep learning framework built on top of JAX. Its high-level API lets us create neural networks easily. JAX is a library that provides NumPy-like arrays (and functions to work on them) on CPUs/GPUs/TPUs, along with automatic differentiation of functions that work with those arrays. Neural networks created with Flax can run fast because they can take advantage of JAX transformations such as jit (just-in-time compilation), vmap (vectorization), and pmap (parallelization). As Flax is built on top of JAX, some JAX background is required. If you want to learn about JAX, please feel free to check our short tutorial on it. We recommend going through our JAX tutorial and the other JAX tutorials mentioned in the reference section, as a better understanding of JAX will help with this tutorial.
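For readers new to JAX, below is a minimal sketch of jax.grad, jax.jit, and jax.vmap, the transformations that Flax training code relies on. The function squared_norm and the input values are illustrative only. pmap is left out because it needs several devices to be meaningful.

```python
import jax
import jax.numpy as jnp

# A plain Python function operating on JAX arrays (illustrative).
def squared_norm(w):
    return jnp.sum(w ** 2)

w = jnp.array([1.0, 2.0, 3.0])

# jax.grad: automatic differentiation; the gradient of sum(w**2) is 2*w.
grad_fn = jax.grad(squared_norm)
print(grad_fn(w))  # [2. 4. 6.]

# jax.jit: compile the function with XLA; same result, faster on repeated calls.
fast_norm = jax.jit(squared_norm)
print(fast_norm(w))  # 14.0

# jax.vmap: map a per-example function over a leading batch axis.
batch = jnp.stack([w, 2 * w])         # shape (2, 3)
print(jax.vmap(squared_norm)(batch))  # [14. 56.]
```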

As a part of this tutorial, we'll explain how to create simple neural networks using Flax to solve simple regression and classification tasks, using small toy datasets available from scikit-learn. Flax used to ship its own optimizer implementations, but that sub-module has been deprecated in favor of Optax, a separate library that implements most common optimizers. Hence, we'll use Optax optimizers in this tutorial to minimize our loss functions.

The tutorial assumes that the reader has some background in neural networks and their parts, such as layers, activation functions, loss functions, and optimizers. The main aim of the tutorial is to get you started designing neural networks with the Flax API, not to explain in depth how neural networks work.

Below we have highlighted important sections of the tutorial to give an overview of the material covered.

Important Sections of Tutorial

  1. Regression
    • Load Dataset
    • Normalize Data
    • Define Neural Network
    • Define Loss Function
    • Train Model
    • Make Predictions
    • Evaluate Model Performance
  2. Classification

Installation

  • pip install --upgrade jax jaxlib
  • pip install flax

Below, we have imported the libraries that we'll use in this tutorial and printed their versions.

import flax

print("Flax Version : {}".format(flax.__version__))
Flax Version : 0.3.6
import jax

print("Jax Version : {}".format(jax.__version__))
Jax Version : 0.2.26
import optax

print("Optax Version : {}".format(optax.__version__))
Optax Version : 0.0.9

1. Regression

In this section, we'll explain how we can create simple neural networks to solve regression tasks. We'll be using the Boston housing dataset available from scikit-learn for our purposes.

Load Dataset

In this section, we have loaded the Boston housing dataset available from scikit-learn. We have loaded the data features (independent variables) into variable X and the target values into variable Y. The target values are median house prices in thousands of dollars, and the features describe the house and its surrounding area. Note that load_boston() was deprecated in scikit-learn 1.0 and removed in 1.2, so this code requires an older scikit-learn release.

After loading the dataset, we have divided it into train (80%) and test (20%) sets. We have then converted the datasets from NumPy arrays to JAX arrays.

from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp

X, Y = datasets.load_boston(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)


samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
((404, 13), (102, 13), (404,), (102,))
samples, features
(404, 13)

Normalize Data

In this section, we have normalized our train and test datasets. We normalize the data so that features measured on very different scales, with widely varying values, are brought to a common scale. This helps optimization algorithms like gradient descent converge faster.

To normalize the datasets, we have first calculated the per-feature mean and standard deviation of the training data. We have then subtracted the mean from both the train and test datasets and divided the result by the standard deviation. Only training statistics are used, so no information from the test set leaks into preprocessing.

mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
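The same steps can be wrapped in a small helper that makes the "statistics from the training split only" rule explicit. This is a minimal sketch: the helper name standardize is hypothetical, and the data is synthetic, with features deliberately placed on very different scales.

```python
import jax
import jax.numpy as jnp

def standardize(train, test):
    """Scale both splits with statistics computed on the training split only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    return (train - mean) / std, (test - mean) / std

# Synthetic data with features on very different scales (illustrative only).
key_train, key_test = jax.random.split(jax.random.PRNGKey(0))
scales = jnp.array([1.0, 100.0, 10000.0])
train = jax.random.normal(key_train, (200, 3)) * scales + 5.0
test = jax.random.normal(key_test, (50, 3)) * scales + 5.0

train_s, test_s = standardize(train, test)
print(train_s.mean(axis=0))  # approximately 0 for each feature
print(train_s.std(axis=0))   # approximately 1 for each feature
```

The test split is only approximately centered and scaled, which is expected: it is transformed with the training statistics, exactly as a new, unseen sample would be.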

Define Neural Network

In this section, we have created a neural network using Flax. Flax provides a module named linen (flax.linen) that contains the layers and activation functions needed to create neural networks.

To create a neural network using Flax, we need to create a class that extends the linen.Module class and define its setup() and __call__() methods. Inside setup(), we declare the layers (submodules) of the network. Inside __call__(), we implement the forward pass: it takes the input data and applies the layers declared in setup() to it.

Below, we have first created a class representing our neural network by extending linen.Module. We have declared a class attribute named features that holds the output size of each layer, and inside setup() we have created a list of dense (linear) layers from it. Inside __call__(), we loop through those layers and apply them to the input one by one, applying the ReLU (Rectified Linear Unit) activation function via linen.relu() to the output of every layer except the last. Finally, __call__() returns the predictions.

After defining our neural network, we have created an instance of it. We initialize the network's weights by calling its init() method with a pseudo-random number generator key and sample input data; the sample data is used to infer the shapes of the weights. init() returns a dictionary-like object that holds the parameters/weights of the network under the 'params' key. We have printed the shapes of the weights by looping through this parameters dictionary.

Then, in the next cell, we have performed a forward pass through the network by calling its apply() method with the parameters and a few samples of data. We have printed the predictions to verify that the network works as expected.

from typing import Sequence, Tuple
from jax import random
import jax.numpy as jnp
from flax import linen

class MultiLayerPerceptronRegressor(linen.Module):
    features: Sequence[int] = (5,10,15,1)

    def setup(self):
        self.layers = [linen.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = linen.relu(x)
        return x

seed = random.PRNGKey(0)

model = MultiLayerPerceptronRegressor()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : layers_0
	Layer Weights : (13, 5), Biases : (5,)
Layer Name : layers_1
	Layer Weights : (5, 10), Biases : (10,)
Layer Name : layers_2
	Layer Weights : (10, 15), Biases : (15,)
Layer Name : layers_3
	Layer Weights : (15, 1), Biases : (1,)
preds = model.apply(params, X_train[:5])

preds
DeviceArray([[-0.78734845],
             [-0.28225327],
             [-0.21475722],
             [-0.15480304],
             [-0.118424  ]], dtype=float32)

Define Loss Function

In this section, we have defined the loss function for our neural network, which is the mean squared error. We'll calculate gradients with respect to this function. The function takes weights, input data, and actual target values as input. It first makes predictions using the apply() method of the neural network and then calculates the loss from the predictions and actual target values.

The mean squared error loss is simply the average of the squared difference between actual target values and predictions. We return the scalar loss value from the function.

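$$\text{MSE}(\hat{y}, y) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2$$

where $y_i$ are the actual target values, $\hat{y}_i$ the predictions, and $n$ the number of samples.
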
def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    return jnp.power(actual - preds.squeeze(), 2).mean()
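The squeeze() call in the loss function matters: the network returns predictions of shape (n, 1) while the targets have shape (n,). The small sketch below, with made-up numbers, shows that without squeeze() NumPy-style broadcasting silently produces an (n, n) matrix of all pairwise differences, so the "loss" is still a scalar but is the wrong quantity.

```python
import jax.numpy as jnp

actual = jnp.array([1.0, 2.0, 5.0])        # shape (3,), like Y_train
preds = jnp.array([[1.0], [2.0], [3.0]])   # shape (3, 1), like the output of model.apply

# Without squeeze, (3,) - (3, 1) broadcasts to a (3, 3) matrix of pairwise differences.
wrong = jnp.power(actual - preds, 2)
print(wrong.shape)   # (3, 3)
print(wrong.mean())  # 4.0, a silently wrong loss

# With squeeze, the shapes line up element-wise: (3,) - (3,).
mse = jnp.power(actual - preds.squeeze(), 2).mean()
print(mse)  # (0 + 0 + 4) / 3, about 1.333
```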

Train Model

In this section, we have trained our neural network. We have first set the number of epochs to 1000 and created a pseudo-random number generator key that will be used to initialize the model weights. We have then created an instance of the neural network and initialized its weights by calling init() with the key and randomly generated sample data.

Then, we have created a gradient descent optimizer using Optax's sgd() function with a learning rate of 0.001 (we use the full training set at every step, so this is plain gradient descent). We have initialized the optimizer state by calling the optimizer's init() method with the network weights. The optimizer state holds whatever internal statistics the optimizer needs between steps: plain SGD without momentum keeps essentially none, while optimizers such as Adam keep running moment estimates. A new state is returned at every step of the training loop.

Then, in the next line, we have created another function by passing our loss function to jax.value_and_grad(). The returned function takes the same parameters as the original loss function but returns two values when called. The first is the loss value obtained by evaluating the loss function on the weights, input data, and actual target values. The second is the gradient of the loss with respect to the first parameter of the loss function, which is the neural network's weights. We'll use this function in the training loop to calculate gradients.
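A minimal sketch of jax.value_and_grad() on a tiny linear model, with a hypothetical loss linear_mse that uses the same (weights, inputs, targets) argument layout as MeanSquaredErrorLoss. The returned gradient has the same shape as the first argument and can be checked by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss with the same argument layout: (weights, inputs, targets).
def linear_mse(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

loss_and_grad = jax.value_and_grad(linear_mse)  # differentiates w.r.t. argument 0 (w) by default
loss, grad_w = loss_and_grad(w, x, y)
print(loss)    # 2.5, i.e. ((0 - 1)^2 + (0 - 2)^2) / 2
print(grad_w)  # [-1. -2.], i.e. (2/n) * x^T (x w - y)
```

Passing argnums to jax.value_and_grad() selects a different argument (or several) to differentiate with respect to.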

After all initializations, we run the training loop for the given number of epochs. In each epoch, we first calculate the loss value and gradients using the function created earlier; this also performs a forward pass through the network, since the loss function makes predictions. Then, we call the optimizer's update() method with the gradients and optimizer state, which returns the updates and a new optimizer state. Finally, we apply the updates to the model weights by calling optax.apply_updates().

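We print the loss value every 100 epochs. Each printed loss is computed before that epoch's update, so the value printed at epoch 0 is the loss of the untrained model. The steadily decreasing MSE shows that our model is doing a decent job.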

We recommend that readers go through our other tutorials on creating neural networks with JAX, as they will help them better understand JAX and the frameworks built on it.

seed = random.PRNGKey(0)
epochs=1000

model = MultiLayerPerceptronRegressor() ## Define Model
random_arr = jax.random.normal(key=seed, shape=(5, features))
params = model.init(seed, random_arr) ## Initialize Model Parameters

optimizer = optax.sgd(learning_rate=1/1e3) ## Initialize SGD Optimizer using OPTAX

optimizer_state = optimizer.init(params)
loss_grad = jax.value_and_grad(MeanSquaredErrorLoss)

for i in range(epochs):
    loss_val, gradients = loss_grad(params, X_train, Y_train) ## Calculate Loss and Gradients
    updates, optimizer_state = optimizer.update(gradients, optimizer_state)
    params = optax.apply_updates(params, updates) ## Update weights
    if i % 100 == 0: ## loss_val was computed before this epoch's update, so epoch 0 shows the initial loss
        print('MSE After {} Epochs : {:.2f}'.format(i, loss_val))
MSE After 0 Epochs : 594.80
MSE After 100 Epochs : 15.83
MSE After 200 Epochs : 12.64
MSE After 300 Epochs : 11.74
MSE After 400 Epochs : 11.14
MSE After 500 Epochs : 10.62
MSE After 600 Epochs : 10.14
MSE After 700 Epochs : 9.78
MSE After 800 Epochs : 9.46
MSE After 900 Epochs : 9.15

Make Predictions

In this section, we are making predictions on the train and test sets by calling the network's apply() method with the trained model weights and the input data.

test_preds = model.apply(params, X_test) ## Make Predictions on test dataset

test_preds = test_preds.ravel()

train_preds = model.apply(params, X_train) ## Make Predictions on train dataset

train_preds = train_preds.ravel()

Evaluate Model Performance

In this section, we have evaluated the performance of our network by calculating the R^2 score of the train and test predictions, using the r2_score() function available from scikit-learn. The R^2 score is at most 1: values close to 1 indicate a good fit, a value of 0 corresponds to always predicting the mean of the targets, and the score can even be negative for models that do worse than that. The scores calculated on our train and test predictions suggest that our model does a reasonably good job, with some gap between train and test performance.

If you want to learn about the R^2 score and the other metrics that scikit-learn provides for ML tasks, please feel free to check our tutorial that covers most of them in detail.

from sklearn.metrics import r2_score

print("Train R^2 Score : {:.2f}".format(r2_score(train_preds.to_py(), Y_train.to_py())))
print("Test  R^2 Score : {:.2f}".format(r2_score(test_preds.to_py(), Y_test.to_py())))
Train R^2 Score : 0.88
Test  R^2 Score : 0.71
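scikit-learn's r2_score() expects the ground-truth values as its first argument and the predictions as its second, i.e. r2_score(y_true, y_pred). The cell above passes the predictions first. Because R^2 normalizes by the variance of the first argument, swapping the arguments generally changes the score, so the conventional call is r2_score(Y_train.to_py(), train_preds.to_py()) (and likewise for the test set), and its output may differ from the numbers printed above. The hand-written r2 below is a hypothetical helper following the same definition, shown on toy numbers to make the asymmetry concrete.

```python
import numpy as np

def r2(y_true, y_pred):
    """Coefficient of determination, same definition as sklearn.metrics.r2_score."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total variation of the first argument
    return 1.0 - ss_res / ss_tot

y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.0, 2.0, 3.0, 5.0]

print(r2(y_true, y_pred))  # 0.8, i.e. 1 - 1/5
print(r2(y_pred, y_true))  # about 0.886, i.e. 1 - 1/8.75: a different value
```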

2. Classification

In this section, we explain how to create simple neural networks for classification tasks, using the breast cancer dataset available from scikit-learn. We'll reuse most of the code from the regression section, so we haven't repeated detailed descriptions of those parts. If you have started directly from the classification section, please check their descriptions in the regression section.

Load Dataset

In this section, we have loaded the breast cancer dataset available from scikit-learn. The features of the dataset are various measurements of the tumor, and the target value is binary (0 - malignant tumor, 1 - benign tumor). As our target values are binary, this is a binary classification task.

After loading the dataset, we have divided it into train (80%) and test (20%) sets, stratifying by the target so that both sets keep the same class proportions.

from sklearn import datasets
from sklearn.model_selection import train_test_split

X, Y = datasets.load_breast_cancer(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape
classes = jnp.unique(Y_test)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
((455, 30), (114, 30), (455,), (114,))
samples, features, classes
(455, 30, DeviceArray([0., 1.], dtype=float32))

Normalize Datasets

In this section, we have normalized our train and test sets using the mean and standard deviation calculated on the train set. As explained earlier, this helps with faster convergence.

mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

Design Neural Network

In this section, we have defined a neural network for the binary classification task. The code is almost the same as in the regression section, with the same layer sizes (5, 10, 15, 1). The only difference is that we apply the sigmoid activation to the output of the last layer. The sigmoid function maps its input to floats in the range (0, 1), so the output of our network is a float in that range. We'll later convert these floats to actual predicted classes (0 - malignant, 1 - benign).

from typing import Sequence, Tuple
from jax import random
import jax.numpy as jnp
from flax import linen

class MultiLayerPerceptronClassifier(linen.Module):
    features: Sequence[int] = (5,10,15,1)

    def setup(self):
        self.layers = [linen.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = linen.relu(x)
        return linen.sigmoid(x)

seed = random.PRNGKey(0)

model = MultiLayerPerceptronClassifier()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : layers_0
	Layer Weights : (30, 5), Biases : (5,)
Layer Name : layers_1
	Layer Weights : (5, 10), Biases : (10,)
Layer Name : layers_2
	Layer Weights : (10, 15), Biases : (15,)
Layer Name : layers_3
	Layer Weights : (15, 1), Biases : (1,)
preds = model.apply(params, X_train[:5])

preds
DeviceArray([[0.48131582],
             [0.276364  ],
             [0.50851697],
             [0.48755074],
             [0.3897525 ]], dtype=float32)

Define Loss Function

In this section, we have defined the loss function for our neural network: the negative log-likelihood, also called log loss or binary cross-entropy. The function takes weights, input data, and actual target values as input. It makes predictions using the weights and input data and then calculates the loss from the predictions and actual target values.

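$$\text{LogLoss}(\hat{y}, y) = -\frac{1}{n}\sum_{i=1}^{n}\left[y_i \log \hat{y}_i + (1 - y_i)\log\left(1 - \hat{y}_i\right)\right]$$

where $y_i \in \{0, 1\}$ are the actual target values and $\hat{y}_i \in (0, 1)$ the predicted probabilities.
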
def NegLogLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    preds = preds.squeeze()
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()
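Computing log(sigmoid(z)) in two separate steps can overflow: in float32, the sigmoid of a large logit rounds to exactly 1.0, so log(1 - p) becomes -inf and the loss becomes inf. A common alternative, sketched below, is to let the network return raw logits (dropping the final linen.sigmoid) and compute the same loss from logits with jax.nn.log_sigmoid, using the identity log(1 - sigmoid(z)) = log(sigmoid(-z)). The function names bce_from_probs and bce_from_logits are hypothetical; recent Optax releases also provide a ready-made optax.sigmoid_binary_cross_entropy(logits, labels).

```python
import jax
import jax.numpy as jnp

def bce_from_probs(probs, labels):
    # Same formula as NegLogLoss: expects sigmoid outputs.
    return jnp.mean(-labels * jnp.log(probs) - (1 - labels) * jnp.log(1 - probs))

def bce_from_logits(logits, labels):
    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)), computed stably.
    return jnp.mean(-labels * jax.nn.log_sigmoid(logits)
                    - (1 - labels) * jax.nn.log_sigmoid(-logits))

labels = jnp.array([1.0, 0.0])
logits = jnp.array([2.0, -1.0])
# Moderate logits: both versions agree up to float error.
print(bce_from_probs(jax.nn.sigmoid(logits), labels), bce_from_logits(logits, labels))

# A confidently wrong prediction: sigmoid(100) rounds to exactly 1.0 in float32.
big = jnp.array([100.0])
print(bce_from_probs(jax.nn.sigmoid(big), jnp.array([0.0])))  # inf
print(bce_from_logits(big, jnp.array([0.0])))                 # 100.0
```

With logits as the model output, predicted probabilities are still available for thresholding by applying jax.nn.sigmoid() after apply().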

Train Neural Network

In this section, we have included the code to train the neural network. It is almost the same as the code from the regression section, except that we use the log loss as our loss function and a learning rate of 0.01. We train the network for 1000 epochs. The loss value printed every 100 epochs decreases steadily, which suggests that our model is learning well.

seed = random.PRNGKey(0)
epochs=1000

model = MultiLayerPerceptronClassifier() ## Define Model
random_arr = jax.random.normal(key=seed, shape=(5, features))
params = model.init(seed, random_arr) ## Initialize Model Parameters

optimizer = optax.sgd(learning_rate=1/1e2) ## Initialize SGD Optimizer

optimizer_state = optimizer.init(params)
loss_grad = jax.value_and_grad(NegLogLoss)

for i in range(1, epochs+1):
    loss_val, gradients = loss_grad(params, X_train, Y_train)
    updates, optimizer_state = optimizer.update(gradients, optimizer_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print('NegLogLoss After {} Epochs : {:.2f}'.format(i, loss_val))
NegLogLoss After 100.00 Epochs :  0.6340824
NegLogLoss After 200.00 Epochs :  0.48894864
NegLogLoss After 300.00 Epochs :  0.32399088
NegLogLoss After 400.00 Epochs :  0.2098891
NegLogLoss After 500.00 Epochs :  0.15224986
NegLogLoss After 600.00 Epochs :  0.12405949
NegLogLoss After 700.00 Epochs :  0.10811356
NegLogLoss After 800.00 Epochs :  0.097638115
NegLogLoss After 900.00 Epochs :  0.08902239
NegLogLoss After 1000.00 Epochs :  0.08253476

Make Predictions

In this section, we have made predictions on the train and test sets by calling the model's apply() method with the trained weights and input data. Because of the sigmoid activation, the model outputs floats in the range (0, 1), which we need to convert to actual predicted classes. To do this, we use a threshold of 0.5: values greater than 0.5 are predicted as class 1 and all other values as class 0.

test_preds = model.apply(params, X_test) ## Make Predictions on test dataset

test_preds = test_preds.ravel()

test_preds = (test_preds > 0.5).astype(jnp.float32)

test_preds[:5], Y_test[:5]
(DeviceArray([0., 0., 1., 1., 1.], dtype=float32),
 DeviceArray([0., 0., 1., 1., 1.], dtype=float32))
train_preds = model.apply(params, X_train) ## Make Predictions on train dataset

train_preds = train_preds.ravel()

train_preds = (train_preds > 0.5).astype(jnp.float32)

train_preds[:5], Y_train[:5]
(DeviceArray([1., 1., 0., 1., 1.], dtype=float32),
 DeviceArray([1., 1., 0., 0., 1.], dtype=float32))

Evaluate Model Performance

In this section, we have evaluated the performance of our model by calculating the accuracy of the train and test predictions. We have also printed a classification report for the test predictions, which includes per-class precision, recall, and F1-score. The metrics suggest that our model does a decent job.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.98
Test  Accuracy : 0.97
from sklearn.metrics import classification_report

print("Test Data Classification Report : ")
print(classification_report(Y_test, test_preds))
Test Data Classification Report :
              precision    recall  f1-score   support

         0.0       0.93      1.00      0.97        42
         1.0       1.00      0.96      0.98        72

    accuracy                           0.97       114
   macro avg       0.97      0.98      0.97       114
weighted avg       0.98      0.97      0.97       114
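The accuracy and per-class precision/recall in the report can also be cross-checked directly in JAX. In the minimal sketch below, the helper binary_metrics is hypothetical and the labels are toy values; passing positive=1.0 corresponds to the report's 1.0 row and positive=0.0 to its 0.0 row. Unlike scikit-learn, which by default reports 0 with a warning when a denominator is zero, this helper returns nan in that case.

```python
import jax.numpy as jnp

def binary_metrics(y_true, y_pred, positive=1.0):
    """Accuracy plus precision/recall for the given positive class (hypothetical helper)."""
    tp = jnp.sum((y_pred == positive) & (y_true == positive))  # true positives
    fp = jnp.sum((y_pred == positive) & (y_true != positive))  # false positives
    fn = jnp.sum((y_pred != positive) & (y_true == positive))  # false negatives
    accuracy = jnp.mean(y_pred == y_true)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return accuracy, precision, recall

y_true = jnp.array([0.0, 0.0, 1.0, 1.0, 1.0])
y_pred = jnp.array([0.0, 1.0, 1.0, 1.0, 0.0])

acc, prec, rec = binary_metrics(y_true, y_pred)
print(acc, prec, rec)  # 0.6, then 2/3 for both precision and recall
```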

Sunny Solanki
