**Haiku** is a deep learning framework developed by **DeepMind** on top of **JAX** to accelerate its deep learning experiments. **Haiku** provides built-in implementations of common building blocks such as multi-layer perceptrons and convolutional networks. As **Haiku** is built on top of **JAX**, some background in **JAX** will help you understand **Haiku** better.

**Haiku** lets developers design neural networks with an object-oriented programming (**OOP**) approach while still applying **JAX's** function transformations (**grad(), jit(), vmap(), pmap(), etc.**) to that code. **JAX's** transformations are designed to work with pure functions rather than classes, whereas neural networks are usually defined as classes. **Haiku** bridges this gap by making it possible to use OOP-style modules/classes like pure JAX functions.

As a part of this tutorial, we'll explain how to create simple multi-layer perceptrons using **Haiku**. The main aim of the tutorial is to introduce readers to the **Haiku** framework. We have created two small examples that use **Haiku** with small toy datasets to solve regression and classification tasks.

As we have used **JAX** in this tutorial for some tasks, if you want to learn about **JAX** then please feel free to check our tutorial on it.

**Installation**

```
pip install -U dm-haiku
```

Below we have highlighted important sections of the tutorial to give an overview of the material covered.

**Regression**

- Load Data
- Normalize Data
- Define Neural Network
- Define Loss Function
- Train Neural Network
- Make Predictions
- Evaluate Model Performance
- Train Model in Batches of Data
- Make Predictions in Batches
- Evaluate Model Performance

**Classification**

Below we have imported the **Haiku** and **JAX** libraries that we'll use in this tutorial and printed the version of each.

```
import haiku as hk
print("Haiku Version :{}".format(hk.__version__))
```

```
import jax
print("JAX Version : {}".format(jax.__version__))
```

```
from jax import numpy as jnp
```

In this section, we'll explain how we can create a simple multi-layer perceptron using **Haiku** to solve simple regression tasks. We'll be using a small dataset available from scikit-learn for our example.

In this section, we have loaded the Boston housing dataset available from scikit-learn. We have loaded the dataset features in variable **X** and the target values in variable **Y**. The target values are median house prices in thousands of dollars; as they are continuous, our problem is a regression problem. After loading the dataset, we have divided it into train (80%) and test (20%) sets. Note that **load_boston()** was removed in scikit-learn 1.2, so this code requires an earlier scikit-learn version.

Scikit-learn loads datasets as NumPy arrays, so we have also converted all our datasets from NumPy arrays to JAX arrays.

```
from sklearn import datasets
from sklearn.model_selection import train_test_split

# NOTE: load_boston() was removed in scikit-learn 1.2; this requires scikit-learn < 1.2.
X, Y = datasets.load_boston(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

# Convert NumPy arrays to float32 JAX arrays (no trailing comma/backslash after the last one)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```

```
samples, features
```

In this section, we have normalized our datasets. Normalization brings all feature values to a similar scale so that the optimization algorithm used to train our neural network converges faster. If features are on scales that differ by large amounts, the optimization problem becomes harder and the algorithm takes longer to converge.

To normalize the datasets, we first calculate the mean and standard deviation of each feature on the training dataset. We then subtract these means from both the train and test datasets and divide the results by the standard deviations. Using only training statistics for both splits keeps information from the test set out of training.

```
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
```
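The same normalization step can be written as a reusable function. The sketch below is our own variant: the helper name `standardize` and the `eps` guard, which avoids dividing by zero for a constant feature, are assumptions not present in the tutorial's code.

```python
import jax.numpy as jnp

def standardize(train, test, eps=1e-8):
    """Scale both splits with statistics computed on the training split only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std = jnp.where(std < eps, 1.0, std)  # constant feature: leave it centered but unscaled
    return (train - mean) / std, (test - mean) / std

train = jnp.array([[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]])  # second column is constant
test = jnp.array([[3.0, 12.0]])
train_s, test_s = standardize(train, test)
print(train_s.mean(axis=0))  # approximately [0. 0.]
print(test_s)                # [[0. 2.]]
```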

In this section, we have created the neural network that we'll use for our regression task. We'll use the **MLP** class available from **Haiku** (**hk.nets.MLP**), which creates a simple multi-layer perceptron from a list of layer sizes. Below we have included the signature of the **MLP()** constructor.

- **MLP(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None)** - This constructor takes layer sizes as input and creates an instance of the **MLP** class, which we can later train and use to make predictions.
    - The **output_sizes** argument takes a list of numbers specifying the number of units in each layer of the neural network. It should include the number of units of the output layer as well.
    - The **w_init** and **b_init** arguments take callables used to initialize the weights and biases.
    - The **with_bias** argument accepts a boolean value specifying whether to use biases. By default, biases are added.
    - The **activation** argument accepts the activation function applied between the layers of the network. The default is **jax.nn.relu**.
    - The **activate_final** argument accepts a boolean value specifying whether to apply the **activation** function to the final layer as well. By default, the last layer is not activated.
    - The **name** argument sets an optional name for the module.

Neural networks created with **Haiku** are usually built from subclasses of the **hk.Module** class. Such class-based models must be converted to pure functions before JAX transformations can be applied to them, which we do with the **hk.transform()** function. It takes a function that uses Haiku modules as input and transforms it into pure JAX functions.

We have first defined our **MLP** with layer sizes **[5,10,15,1]** inside a function named **FeedForward()**. The function takes input data, calls the **MLP** instance on it, and returns the result; in other words, it creates the neural network, performs a forward pass on the input data, and returns the predictions. We have transformed this function using **hk.transform()**, which returns a transformed object with two methods.

- **init(rng, data)** - This method takes a PRNG key (seed for random numbers) and a few data samples as input. It initializes the weights of the neural network and returns them in a dictionary-like object holding the weights and biases of all layers (a **FlatMap** in the Haiku version used here; the exact container type depends on the Haiku version).
- **apply(weights, rng, data)** - This method takes weights, a PRNG key, and data features as input. It performs a forward pass through the data with those weights and returns predictions.

Below we have first created our neural network and stored it in a variable named **model**. We have then called its **init()** method to retrieve the initial weights and printed the type of the weights along with the shapes of the weights and biases of every layer for verification. In the next cell, we have called the **apply()** method with the weights, the PRNG key, and the training data; it returns predictions, of which we print the first few. These predictions use the initial weights: we have not trained the neural network yet.

```
def FeedForward(x):
    mlp = hk.nets.MLP(output_sizes=[5,10,15,1])
    return mlp(x)
```

```
model = hk.transform(FeedForward)
```

```
rng = jax.random.PRNGKey(42)
params = model.init(rng, X_train[:5])
print("Weights Type : {}\n".format(type(params)))
for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(weights["w"].shape, weights["b"].shape))
```

```
preds = model.apply(params, rng, X_train)
preds[:5]
```

In this section, we have defined the loss function for our task: the **Mean Squared Error (MSE)** loss. The function takes weights, input data, and actual target values as input. It uses the **apply()** method of the model to make predictions on the input data with the given weights and then calculates the **MSE** between the predictions and the actual target values: we subtract the predictions from the actual target values, square the differences, and take the mean of all squared differences.

$$\text{MSE}(\text{actuals}, \text{predictions}) = \frac{1}{n}\sum_{i=1}^{n}\left(\text{actual}_i - \text{prediction}_i\right)^2$$
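A tiny worked example of the formula with made-up numbers, using **jnp** as in the loss function:

```python
import jax.numpy as jnp

actual = jnp.array([3.0, 5.0, 2.0])
pred = jnp.array([2.5, 5.5, 2.0])
errors = actual - pred       # [0.5, -0.5, 0.0]
mse = jnp.mean(errors ** 2)  # (0.25 + 0.25 + 0.0) / 3
print(float(mse))            # about 0.1667
```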

```
def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = model.apply(weights, rng, input_data)  # shape (n, 1)
    preds = preds.squeeze()                        # shape (n,) to match actual
    return jnp.power(actual - preds, 2).mean()
```

In this section, we are training our neural network.

First, we have defined a simple function that takes weights and gradients as input and subtracts the learning rate times the gradients from the weights. We'll use this function to iteratively update the weights of all layers.

```
def UpdateWeights(weights, gradients):
    # Gradient descent step; uses the global learning_rate defined in the training cell
    return weights - learning_rate * gradients
```

Below, we have included the logic to train our neural network. We have first initialized the weights of our neural network using the **init()** method of our model, providing a PRNG key and a few data samples. We have then set the number of epochs (**1000**) and the learning rate (**0.001**).

We then run the training loop for the given number of epochs. The logic inside the loop is simple. We first call the **value_and_grad()** function with our loss function. **value_and_grad()** takes a function as input and returns a new function. When we call the returned function with parameter values, it returns two values: the value of the wrapped function at those parameters, and the gradients of the wrapped function with respect to its first parameter.

In our case, the first output is the **MSE** loss value and the second is the gradient of the **MSE** with respect to the weights (the first parameter of **MeanSquaredErrorLoss()**).
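A toy example of **value_and_grad()** on a hypothetical one-parameter function shows both outputs. For $f(w) = \sum_i (w x_i)^2$ with $w = 2$ and $x = [1, 1, 1]$, the value is $3 \cdot 4 = 12$ and the derivative $2w\sum_i x_i^2$ is also $12$:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy "loss": sum of squares of w * x
    return jnp.sum((w * x) ** 2)

w = jnp.array(2.0)
x = jnp.array([1.0, 1.0, 1.0])
value, grad_w = jax.value_and_grad(loss)(w, x)  # differentiates w.r.t. the first argument
print(value, grad_w)  # 12.0 12.0
```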

On the next line, we update the weights of our neural network using the **jax.tree_map()** function (recent JAX versions deprecate the **jax.tree_map** alias in favor of **jax.tree_util.tree_map()**, which works in all versions). It takes a function followed by one or more tree-like arguments (pytrees) as input and applies the function to each corresponding set of leaves, here subtracting the learning rate times the gradients from the weights. This process of updating weights based on the learning rate and gradients is known as **gradient descent**.

In our case, the weights are the tree-like structure that we printed when we defined the neural network earlier, and the gradients have the same structure. We apply the **UpdateWeights()** function declared in the previous cell to each pair of leaves of the weights and gradients, which returns a new tree-like structure containing the updated weights.
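The sketch below applies the same kind of update to a hypothetical parameter tree (the keys `linear`, `w`, `b` merely mimic Haiku's layout) using **jax.tree_util.tree_map()**:

```python
import jax
import jax.numpy as jnp

learning_rate = 0.1
params = {"linear": {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}}
grads = {"linear": {"w": jnp.array([10.0, -10.0]), "b": jnp.array(5.0)}}

def update(weight, gradient):
    # Same rule as UpdateWeights(): one gradient descent step on a single leaf
    return weight - learning_rate * gradient

new_params = jax.tree_util.tree_map(update, params, grads)
print(new_params)  # w -> [0., 3.], b -> 0.
```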

We also print the **MSE** every 100 epochs; its steady decrease indicates that our model is learning.

```
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
params = model.init(rng, X_train[:5])
epochs = 1000
learning_rate = jnp.array(0.001)
for i in range(1, epochs+1):
    loss, param_grads = value_and_grad(MeanSquaredErrorLoss)(params, X_train, Y_train)
    # jax.tree_util.tree_map works on all JAX versions (jax.tree_map is deprecated)
    params = jax.tree_util.tree_map(UpdateWeights, params, param_grads)
    if i % 100 == 0:
        print("MSE : {:.2f}".format(loss))
```

In this section, we have made predictions on our train and test datasets using **apply()** method of our model.

```
train_preds = model.apply(params, rng, X_train)
train_preds[:5]
```

```
test_preds = model.apply(params, rng, X_test)
test_preds[:5]
```

In this section, we have evaluated the performance of our model by calculating the **MSE** loss and the **R^2 score** of our train and test predictions. The **R^2 score** is at most 1: values close to 1 indicate a good model, a score of 0 corresponds to always predicting the mean of the targets, and the score can be negative for models that do worse than that. Judging by the **R^2 scores** on the train and test predictions, our model is doing a decent job.
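To see why **R^2** is not confined to [0, 1], here is a sketch of the standard definition $R^2 = 1 - SS_{res}/SS_{tot}$ (the helper `r2` is ours), checked against scikit-learn's **r2_score** on a deliberately bad prediction:

```python
import numpy as np
import jax.numpy as jnp
from sklearn.metrics import r2_score

def r2(y_true, y_pred):
    ss_res = jnp.sum((y_true - y_pred) ** 2)            # residual sum of squares
    ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)  # total sum of squares around the mean
    return 1.0 - ss_res / ss_tot

y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([3.0, 2.0, 1.0])  # reversed: worse than always predicting the mean (2.0)
manual = float(r2(y_true, y_pred))
library = r2_score(np.asarray(y_true), np.asarray(y_pred))
print(manual, library)  # -3.0 -3.0
```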

If you are interested in learning in detail about **R^2** score and other metrics available from scikit-learn for different kinds of tasks then please feel free to check our tutorial which covers the majority of metrics in detail with examples.

```
print("Test MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_test, Y_test)))
print("Train MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_train, Y_train)))
```

```
from sklearn.metrics import r2_score
# r2_score expects (y_true, y_pred); the order matters because R^2 is not symmetric
print("Test R^2 Score : {:.2f}".format(r2_score(Y_test, test_preds.squeeze())))
print("Train R^2 Score : {:.2f}".format(r2_score(Y_train, train_preds.squeeze())))
```

In real life, datasets are often large and may not fit into the main memory of the computer. To handle such datasets, we bring a small batch of samples into memory at a time and train the model on it, covering the whole dataset batch by batch and updating the weights of the neural network after each batch. Updating weights based on small batches of data like this is referred to as (mini-batch) **stochastic gradient descent**.

Our current dataset is quite small and easily fits into the main memory of the computer, but we'll treat it as a large dataset that does not. We'll divide the dataset into batches and train the model on one small batch at a time. Below, we have included the logic to train the model in small batches.

We have first initialized the model weights using a PRNG key. We have then set the number of epochs (**500**), the batch size (**32**), and the learning rate (**0.001**). We then run the training loop for the given number of epochs. In each epoch, we loop through the whole dataset in batches; for each batch, we calculate the loss and gradients and then update the model weights using the gradients, so the weights are updated once per batch.

The **MSE** printed every 100 epochs is the average loss over the batches of that epoch; its steady decrease indicates that our model is learning.

```
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
params = model.init(rng, X_train[:5])
epochs = 500
batch_size = 32
learning_rate = jnp.array(0.001)
for i in range(1, epochs+1):
    losses = [] ## Record loss of each batch
    for start in range(0, X_train.shape[0], batch_size): ## Batch start indices; the last batch may be smaller
        X_batch, Y_batch = X_train[start:start+batch_size], Y_train[start:start+batch_size] ## Single batch of data
        loss, param_grads = value_and_grad(MeanSquaredErrorLoss)(params, X_batch, Y_batch)
        params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## Update Params
        losses.append(loss) ## Record Loss
    if i % 100 == 0: ## Print MSE every 100 epochs
        print("MSE : {:.2f}".format(jnp.array(losses).mean()))
```
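The loop above visits the samples in the same order every epoch. Stochastic gradient descent is more commonly run with the data reshuffled each epoch. A minimal sketch of such a batch iterator follows; the generator name `iterate_minibatches` is hypothetical, and the tutorial's results above do not use shuffling.

```python
import jax
import jax.numpy as jnp

def iterate_minibatches(key, X, Y, batch_size):
    """Yield shuffled (X_batch, Y_batch) pairs covering every sample exactly once."""
    n = X.shape[0]
    perm = jax.random.permutation(key, n)     # new sample order for this epoch
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]  # the last batch may be smaller
        yield X[idx], Y[idx]

X = jnp.arange(10.0).reshape(10, 1)
Y = jnp.arange(10.0)
key = jax.random.PRNGKey(0)
batches = list(iterate_minibatches(key, X, Y, batch_size=4))
print([xb.shape[0] for xb, _ in batches])  # [4, 4, 2]
```

Inside a training loop, a fresh key per epoch could be obtained with `rng, epoch_key = jax.random.split(rng)`.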

As we are assuming that the whole dataset does not fit into main memory, we also need to make predictions in batches. Below we have defined a function that takes the updated model weights and input data, loops through the data in batches, makes predictions on one batch at a time, and returns the list of per-batch predictions, which we then combine with **jnp.concatenate()**. It splits the data into batches the same way as the training loop.

We have then used the function to make predictions on train and test datasets.

```
def MakePredictions(weights, input_data, batch_size=32):
    preds = []
    for start in range(0, input_data.shape[0], batch_size): ### Batch start indices
        X_batch = input_data[start:start+batch_size]
        preds.append(model.apply(weights, rng, X_batch))
    return preds ## List of per-batch predictions; combine with jnp.concatenate()
```

```
train_preds = MakePredictions(params, X_train, 32)
train_preds = jnp.concatenate(train_preds).squeeze()
test_preds = MakePredictions(params, X_test, 32)
test_preds = jnp.concatenate(test_preds).squeeze()
```
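The batching pattern in **MakePredictions()** can be expressed as a generic helper that takes the apply function as an argument. In this sketch, `predict_in_batches` and the linear `toy_apply` stand-in for **model.apply** are hypothetical; the check confirms that batching does not change the result.

```python
import jax.numpy as jnp

def predict_in_batches(apply_fn, params, X, batch_size=32):
    """Run apply_fn on consecutive slices of X and stitch the outputs together."""
    outputs = [apply_fn(params, X[start:start + batch_size])
               for start in range(0, X.shape[0], batch_size)]
    return jnp.concatenate(outputs, axis=0)

# Hypothetical stand-in for model.apply: a fixed linear map
def toy_apply(params, x):
    return x @ params["w"] + params["b"]

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
X = jnp.arange(30.0).reshape(10, 3)
batched = predict_in_batches(toy_apply, params, X, batch_size=4)
print(bool(jnp.allclose(batched, toy_apply(params, X))))  # True
```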

In this section, we have evaluated the performance of the model by calculating the **MSE** loss and **R^2 score** of the train and test predictions.

```
print("Test MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_test, Y_test)))
print("Train MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_train, Y_train)))
```

```
from sklearn.metrics import r2_score
# r2_score expects (y_true, y_pred); the order matters because R^2 is not symmetric
print("Test R^2 Score : {:.2f}".format(r2_score(Y_test, test_preds.squeeze())))
print("Train R^2 Score : {:.2f}".format(r2_score(Y_train, train_preds.squeeze())))
```

In this section, we'll explain how to create a simple multi-layer perceptron using the **MLP** class of **Haiku** to solve a classification task. We'll use a small toy dataset available from scikit-learn for explanation purposes. As we reuse most of the code from the regression section, we have not repeated detailed descriptions of the shared code. If you don't find a detailed description of some code section here, please check the corresponding part of the regression section.

In this section, we have loaded the breast cancer dataset available from scikit-learn. We have loaded the data features in variable **X** and the target values in variable **Y**. The target values are either **0 (malignant tumor)** or **1 (benign tumor)**. As our target values have only two classes, this is a binary classification problem.

After loading the dataset, we have divided it into the train (80%) and test (20%) sets. We have also converted datasets held in numpy arrays to jax arrays.

```
from sklearn import datasets
from sklearn.model_selection import train_test_split
X, Y = datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
classes = jnp.unique(Y)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```

```
samples, features, classes
```

In this section, we have normalized our train and test datasets using the mean and standard deviation of features calculated on the training dataset.

```
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
```

In this section, we have designed the neural network that we'll use for our binary classification task. The network design is exactly the same as that of the regression section. As the **Haiku** **MLP()** constructor applies the same **activation** function to every activated layer, it does not let us use a different activation (such as sigmoid) for the last layer only. Hence, we apply the **sigmoid()** function from the **jax.nn** module to the network's output separately, both inside our loss function and when making predictions.

```
def FeedForward(x):
    mlp = hk.nets.MLP(output_sizes=[5,10,15,1])
    return mlp(x)
```

```
model = hk.transform(FeedForward)
```

```
rng = jax.random.PRNGKey(42)
params = model.init(rng, X_train[:5])
for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(weights["w"].shape, weights["b"].shape))
```

In this section, we have defined the loss function that we'll use for our binary classification task: log loss (binary cross-entropy).

$$\text{log\_loss}(\text{actuals}, \text{predictions}) = \frac{1}{n}\sum_{i=1}^{n}\left(-\,\text{actual}_i \log(\text{prediction}_i) - (1-\text{actual}_i)\log(1-\text{prediction}_i)\right)$$
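Plugging made-up probabilities for two samples (one of class 1, one of class 0) into the formula gives a quick sanity check:

```python
import jax.numpy as jnp

actual = jnp.array([1.0, 0.0])
pred = jnp.array([0.9, 0.2])  # predicted probabilities of class 1
per_sample = -actual * jnp.log(pred) - (1 - actual) * jnp.log(1 - pred)
loss = per_sample.mean()      # (-log 0.9 - log 0.8) / 2, about 0.164
print(per_sample, loss)
```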

The function takes weights, feature data, and actual target values as input. It uses the **apply()** method of the model to make predictions, applies the sigmoid activation to the output of the last layer using the **jax.nn.sigmoid()** function, and then calculates the loss from these probabilities and the actual target values.

```
def NegLogLoss(weights, input_data, actual):
    preds = model.apply(weights, rng, input_data)
    preds = preds.squeeze()
    preds = jax.nn.sigmoid(preds)  # convert raw outputs to probabilities of class 1
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()
```
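Computing `jnp.log(jax.nn.sigmoid(...))` can blow up to infinity once the sigmoid saturates to exactly 0 or 1 in float32. A numerically safer variant, sketched below as an alternative rather than the tutorial's code, uses **jax.nn.log_sigmoid()** and the identity $\log(1 - \sigma(z)) = \log\sigma(-z)$; the function names are our own.

```python
import jax
import jax.numpy as jnp

def stable_log_loss(logits, actual):
    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)), without taking log(0)
    return (-actual * jax.nn.log_sigmoid(logits)
            - (1 - actual) * jax.nn.log_sigmoid(-logits)).mean()

def naive_log_loss(logits, actual):
    p = jax.nn.sigmoid(logits)
    return (-actual * jnp.log(p) - (1 - actual) * jnp.log(1 - p)).mean()

logits = jnp.array([0.5, -1.0, 2.0])
actual = jnp.array([1.0, 0.0, 1.0])
print(stable_log_loss(logits, actual), naive_log_loss(logits, actual))  # nearly identical

big = jnp.array([40.0])  # a very confident, wrong prediction for a class-0 sample
naive_big = naive_log_loss(big, jnp.array([0.0]))    # inf: sigmoid(40) rounds to 1.0
stable_big = stable_log_loss(big, jnp.array([0.0]))  # about 40.0
print(naive_big, stable_big)
```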

In this section, we have included the logic to train the neural network. It is the same as the logic used in the regression section; only the parameter settings and the loss function differ. We have set the number of epochs to **1500** and the learning rate to **0.001**. The decreasing loss printed every 100 epochs indicates that our model is learning.

```
def UpdateWeights(weights, gradients):
    # Gradient descent step; uses the global learning_rate defined in the training cell
    return weights - learning_rate * gradients
```

```
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
params = model.init(rng, X_train[:5])
epochs = 1500
learning_rate = jnp.array(0.001)
for i in range(1, epochs+1):
    loss, param_grads = value_and_grad(NegLogLoss)(params, X_train, Y_train)
    # jax.tree_util.tree_map works on all JAX versions (jax.tree_map is deprecated)
    params = jax.tree_util.tree_map(UpdateWeights, params, param_grads)
    if i % 100 == 0:
        print("NegLogLoss : {:.2f}".format(loss))
```

In this section, we have made predictions on the train and test datasets using the latest weights. After making predictions with the **apply()** method, we have applied the **jax.nn.sigmoid()** function to the outputs to bring all values into the range **[0, 1]**. We then use a threshold of 0.5 to predict the target class: values greater than 0.5 are predicted as 1 (benign tumor) and all other values as 0 (malignant tumor).

```
train_preds = model.apply(params, rng, X_train)
train_preds = jax.nn.sigmoid(train_preds.squeeze())
train_preds = (train_preds > 0.5).astype(jnp.float32)
test_preds = model.apply(params, rng, X_test)
test_preds = jax.nn.sigmoid(test_preds.squeeze())
test_preds = (test_preds > 0.5).astype(jnp.float32)
```
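The thresholding step in isolation, on made-up logits: because sigmoid is monotonic with $\sigma(0) = 0.5$, thresholding probabilities at 0.5 is equivalent to thresholding the raw outputs at 0, and the strict `>` sends an output of exactly 0.5 to class 0.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([-2.0, 0.0, 3.0])
probs = jax.nn.sigmoid(logits)              # approx [0.119, 0.5, 0.953]
labels = (probs > 0.5).astype(jnp.float32)  # strict ">" sends exactly 0.5 to class 0
print(labels)                               # [0. 0. 1.]
```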

In this section, we have evaluated the performance of our model by calculating the log loss and the accuracy of the train and test predictions. The results show that our model is doing a good job.

```
print("Test NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_test, Y_test)))
print("Train NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_train, Y_train)))
```

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
```

In this section, we have included the logic to train the neural network in batches. The code is the same as the batched training code from the regression section; only the loss function and parameter values differ. We have set the number of epochs to **500**, the batch size to **32**, and the learning rate to **0.001**. The decreasing loss printed every 100 epochs indicates that our model is learning.

```
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
params = model.init(rng, X_train[:5])
epochs = 500
batch_size = 32
learning_rate = jnp.array(0.001)
for i in range(1, epochs+1):
    losses = [] ## Record loss of each batch
    for start in range(0, X_train.shape[0], batch_size): ## Batch start indices; the last batch may be smaller
        X_batch, Y_batch = X_train[start:start+batch_size], Y_train[start:start+batch_size] ## Single batch of data
        loss, param_grads = value_and_grad(NegLogLoss)(params, X_batch, Y_batch)
        params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## Update Params
        losses.append(loss) ## Record Loss
    if i % 100 == 0: ## Print NegLogLoss every 100 epochs
        print("NegLogLoss : {:.2f}".format(jnp.array(losses).mean()))
```

In this section, we have made predictions on train and test datasets in batches. We have used the same function we had defined in the regression section to make predictions on data in batches. The main difference here is that after making a prediction, we have applied the sigmoid function to the output and then predicted target classes by setting the threshold at 0.5.

```
def MakePredictions(weights, input_data, batch_size=32):
    preds = []
    for start in range(0, input_data.shape[0], batch_size): ### Batch start indices
        X_batch = input_data[start:start+batch_size]
        preds.append(model.apply(weights, rng, X_batch))
    return preds ## List of per-batch predictions; combine with jnp.concatenate()
```

```
train_preds = MakePredictions(params, X_train, 32)
train_preds = jnp.concatenate(train_preds).squeeze()
train_preds = jax.nn.sigmoid(train_preds)
train_preds = (train_preds > 0.5).astype(jnp.float32)
train_preds[:5], Y_train[:5]
```

```
test_preds = MakePredictions(params, X_test, 32)
test_preds = jnp.concatenate(test_preds).squeeze()
test_preds = jax.nn.sigmoid(test_preds)
test_preds = (test_preds > 0.5).astype(jnp.float32)
test_preds[:5], Y_test[:5]
```

In this section, we have evaluated the performance of our model by calculating the log loss and the accuracy of the train and test predictions.

```
print("Test NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_test, Y_test)))
print("Train NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_train, Y_train)))
```

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
```

This ends our small tutorial explaining how we can use **Haiku** to create simple multi-layer perceptrons.