What is Image Segmentation?
Image segmentation is the process of partitioning an image into meaningful regions, called segments (also referred to as objects). We detect the objects present in an image and color them to separate them from one another. The task mainly concentrates on detecting object boundaries so that objects can be easily separated. Often, we also label each detected segment/object.
Applications of Image Segmentation
Types of Image Segmentation
How to Perform Image Segmentation on an Image?
Over the years, many approaches have been developed for solving image segmentation tasks. Some of them use machine learning (deep learning), whereas others use non-ML solutions. The Python library scikit-image implements the majority of the non-ML methods. A few well-known non-ML approaches are thresholding, watershed, region growing, clustering-based methods, active contours, and graph-based methods such as Felzenszwalb's algorithm; a small sketch of one of them follows below.
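Below is a minimal sketch of one such classical approach, SLIC superpixel segmentation, using scikit-image. The sample image data.astronaut() ships with scikit-image; the parameter values here are illustrative defaults, not tuned choices.

from skimage import data, segmentation

# Load a sample RGB image bundled with scikit-image.
img = data.astronaut()

# SLIC clusters pixels into ~100 superpixels based on color and proximity.
segments = segmentation.slic(img, n_segments=100, compactness=10)

# 'segments' is an integer label map with the same spatial size as the image.
print(segments.shape, segments.max())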
The majority of ML approaches involve deep neural networks consisting of layers like convolution, dense, etc. Some famous neural network architectures for image segmentation are FCN, U-Net, SegNet, DeepLab, and Mask R-CNN.
What Can You Learn From This Tutorial?
As a part of this tutorial, we have explained how to use pre-trained PyTorch models available from the torchvision module for image segmentation tasks. Torchvision is the computer vision toolkit of PyTorch and provides pre-trained models for many computer vision tasks like image classification, object detection, image segmentation, etc.
We have downloaded a few images from the internet and tried the pre-trained models on them. We have explained the usage of both instance and semantic segmentation models. Torchvision provides models trained on the COCO and Pascal VOC datasets, and it implements the majority of the deep learning models mentioned above.
Below, we have listed the essential sections of the tutorial to give an overview of the material covered.
Below, we have imported the necessary Python libraries used in this tutorial and printed the versions we used.
import torch
print("PyTorch Version : {}".format(torch.__version__))
import torchvision
print("TorchVision Version : {}".format(torchvision.__version__))
import gc
In this section, we have downloaded three images from the internet that we'll use throughout the tutorial. We'll try the image segmentation algorithms on these images.
The images contain objects like people, dogs, toys, etc., which we'll try to detect using the image segmentation algorithms.
!wget https://www.luxurytravelmagazine.com/files/593/2/80152/luxury-travel-instagram_bu.jpg
!wget https://www.akc.org/wp-content/uploads/2020/12/training-behavior.jpg
!wget https://images.squarespace-cdn.com/content/v1/519bd105e4b0c8ea540e7b36/1555002210238-V3YQS9DEYD2QLV6UODKL/The-Benefits-Of-Playing-Outside-For-Children.jpg
%ls
After downloading the images, we loaded them into memory using the Python library Pillow.
from PIL import Image
holiday = Image.open("luxury-travel-instagram_bu.jpg")
holiday
kids_playing = Image.open("The-Benefits-Of-Playing-Outside-For-Children.jpg")
dog_kid_playing = Image.open("training-behavior.jpg")
Below, we have converted all our Pillow images to PyTorch tensors using the pil_to_tensor() function available from the torchvision module. All PyTorch models require their inputs to be tensors.
from torchvision.transforms.functional import pil_to_tensor
holiday_tensor_int = pil_to_tensor(holiday)
kids_playing_tensor_int = pil_to_tensor(kids_playing)
dog_kid_playing_tensor_int = pil_to_tensor(dog_kid_playing)
holiday_tensor_int.shape, kids_playing_tensor_int.shape, dog_kid_playing_tensor_int.shape
holiday_tensor_int.dtype
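As a quick sanity check, note that the loaded tensors are uint8 with pixel values in [0, 255]. The drawing utilities used later expect this format, while the model expects normalized float inputs, which the preprocessing transforms below take care of.

# uint8 image tensor with pixel values in [0, 255]; the preprocessing
# transforms convert it to normalized float before it reaches the model.
print(holiday_tensor_int.min().item(), holiday_tensor_int.max().item())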
In this section, we'll load into memory the image segmentation models that we'll use on our images. We have loaded one model for explaining semantic segmentation and one for instance segmentation.
Below, we have loaded the FCN (Fully Convolutional Network, with a ResNet-50 backbone) deep neural network model. The model is available through the fcn_resnet50() method from the segmentation sub-module of the torchvision module.
We need to provide the weights parameter to load the model with pre-trained weights. The segmentation module has an enum named FCN_ResNet50_Weights that lets us specify which weights to use. We have asked to load the model with COCO_WITH_VOC_LABELS_V1 weights, which come from a model trained on a subset of the COCO dataset using the Pascal VOC label categories.
Currently, the only weight option available for this model is COCO_WITH_VOC_LABELS_V1, which we have used in our tutorial. Different weight options can be available when a model is trained on different datasets.
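The available weight options can be inspected directly, since the weights object is a Python enum (this assumes torchvision >= 0.13, where the weights API was introduced).

from torchvision.models.segmentation import FCN_ResNet50_Weights

# List every weight option available for FCN-ResNet50.
print(list(FCN_ResNet50_Weights))

# DEFAULT is an alias for the best available weights.
print(FCN_ResNet50_Weights.DEFAULT)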
After loading the model, we set it in evaluation mode by calling the eval() method on it.
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
fcn_resnet = fcn_resnet50(weights=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)
fcn_resnet.eval();
Below, we have loaded the Mask R-CNN model, which is available from the detection sub-module of the torchvision module. The model is loaded with COCO_V1 weights, which come from a model trained on the COCO dataset.
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
maskrcnn_resnet = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1)
maskrcnn_resnet.eval();
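As a hypothetical aside (the rest of this tutorial runs on CPU): inference can be sped up by moving a model to a GPU, in which case every input tensor must be moved to the same device. A minimal sketch, with the actual move left commented out:

# Check whether a CUDA-capable GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device available : {}".format(device))

# fcn_resnet.to(device)                  # move the model ...
# preds = fcn_resnet(batch.to(device))   # ... and every input batch to the same device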
In this section, we'll make predictions on our images using the pre-trained models loaded in the previous section. We'll need to preprocess the images before making predictions on them.
The weights objects have a method named transforms() which can be used to prepare images for the network. We have saved a reference to that method.
By default, the method resizes the image to size 520, rescales pixel values to the range [0.0, 1.0], and normalizes them using mean [0.485, 0.456, 0.406] and standard deviation [0.229, 0.224, 0.225].
In our case, we have prevented resizing by setting the resize_size parameter to None. This is done because we want to overlay the segmented image on the original image, so both must have the same dimensions.
preprocess_img = FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.transforms(resize_size=None)
preprocess_img(holiday_tensor_int).unsqueeze(dim=0).shape
Below, we have made predictions using our pre-trained model on the images that we loaded earlier. We feed the network preprocessed images. We have also introduced a batch dimension using unsqueeze() because the model works on batches of images.
The output of the model is a dictionary with two keys: out and aux.
The out key holds the segmentation output, which we'll visualize in the next section.
holiday_preds1 = fcn_resnet(preprocess_img(holiday_tensor_int).unsqueeze(dim=0))
gc.collect();
holiday_preds1.keys()
kids_playing_preds1 = fcn_resnet(preprocess_img(kids_playing_tensor_int).unsqueeze(dim=0))
gc.collect();
dog_kid_playing_preds1 = fcn_resnet(preprocess_img(dog_kid_playing_tensor_int).unsqueeze(dim=0))
gc.collect();
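The cells above run inference with gradient tracking enabled. A minimal alternative sketch: wrapping inference in torch.no_grad() skips building the autograd graph, which reduces memory usage for large images.

# Run inference without gradient tracking to save memory.
with torch.no_grad():
    preds = fcn_resnet(preprocess_img(holiday_tensor_int).unsqueeze(dim=0))

print(preds['out'].shape)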
Below, we have first retrieved the image preprocessing function from the weights object.
Then, we made predictions on our images using the Mask R-CNN model that we loaded earlier. We have given the network preprocessed images for making predictions.
The prediction for each image is a dictionary with four keys: boxes, labels, scores, and masks.
As this network is part of the detection module, it also returns bounding boxes around the objects detected in images.
preprocess_img = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.transforms()
preprocess_img(holiday_tensor_int).shape
holiday_preds2 = maskrcnn_resnet(preprocess_img(holiday_tensor_int).unsqueeze(dim=0))
gc.collect();
holiday_preds2[0].keys()
kids_playing_preds2 = maskrcnn_resnet(preprocess_img(kids_playing_tensor_int).unsqueeze(dim=0))
gc.collect();
dog_kid_playing_preds2 = maskrcnn_resnet(preprocess_img(dog_kid_playing_tensor_int).unsqueeze(dim=0))
gc.collect();
In this section, we'll visualize the predictions made by our image segmentation models.
First, we'll visualize the predictions made by the semantic segmentation model.
In order to do that, we have first created a dictionary that maps object names to their indexes. We'll use this mapping to retrieve the segmentation result for a particular object.
class_to_idx = {cls: idx for (idx, cls) in enumerate(FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.meta["categories"])}
class_to_idx
Below, we have first retrieved the predictions made on our holiday image. The segmentation output is present under the out key of the prediction, as mentioned earlier.
The tensor under the out key has shape (batch, #classes, height, width), where '#classes' is the number of categories the model is trained to detect. In our case, it is 21 (20 Pascal VOC object classes plus the background class), which we can verify from the dictionary loaded in the previous cell.
After retrieving the prediction, we have converted the raw class scores to per-pixel probabilities by applying softmax() along the class dimension.
Then, we have retrieved the mask present at the person index and visualized it. We can notice from the image that the model is able to detect both persons present in the image.
from torchvision.transforms.functional import to_pil_image
prediction = holiday_preds1['out']
normalized_masks = prediction.softmax(dim=1)[0]
to_pil_image(normalized_masks[class_to_idx['person']])
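A quick sanity check: after applying softmax over the class dimension, the class probabilities at every pixel should sum to (approximately) 1.

print(prediction.shape)                    # (batch, #classes, height, width)
print(normalized_masks.shape)              # (#classes, height, width)
print(normalized_masks.sum(dim=0).mean())  # ~1.0 at every pixel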
The torchvision module provides a method named draw_segmentation_masks() that lets us overlay detected objects on the original image.
First, we have created a boolean mask from the normalized masks by thresholding: probabilities above 0.7 become True.
Then, we have called the draw_segmentation_masks() method with the original image and the person mask to overlay the detected person objects on the original image.
from torchvision.utils import draw_segmentation_masks
masks = normalized_masks > 0.7
out = draw_segmentation_masks(holiday_tensor_int, masks[class_to_idx['person']])
to_pil_image(out)
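The draw_segmentation_masks() method also accepts alpha and colors parameters to control the transparency and color of the overlay. A small illustrative variation (the alpha and color values here are arbitrary choices):

# Redraw the person mask with a lower blend strength and an explicit color.
out = draw_segmentation_masks(
    holiday_tensor_int,
    masks[class_to_idx['person']],
    alpha=0.5,      # overlay transparency (default is 0.8)
    colors="red",   # color used for the mask
)
to_pil_image(out)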
Below, we have overlaid the background mask on the original image using the same process as in the previous cell.
from torchvision.utils import draw_segmentation_masks
masks = normalized_masks > 0.7
background = draw_segmentation_masks(holiday_tensor_int, masks[class_to_idx['__background__']])
to_pil_image(background)
Below, we have retrieved the person mask for our second image and visualized it. We can notice that the model correctly identifies all the kids present in the image.
from torchvision.transforms.functional import to_pil_image
prediction = kids_playing_preds1['out']
normalized_masks = prediction.softmax(dim=1)[0]
to_pil_image(normalized_masks[class_to_idx['person']])
Below, we have overlaid the detected kids on the original image.
from torchvision.utils import draw_segmentation_masks
masks = normalized_masks > 0.7
out = draw_segmentation_masks(kids_playing_tensor_int, masks[class_to_idx['person']])
to_pil_image(out)
Below, we have retrieved the person mask for the third image and visualized it.
In the next cell, we have retrieved the dog mask from the same image and visualized it as well.
In the later cells, we have overlaid the kid and the dog masks on the original image. Note that we use a lower probability threshold (0.1) for the dog mask.
from torchvision.transforms.functional import to_pil_image
prediction = dog_kid_playing_preds1['out']
normalized_masks = prediction.softmax(dim=1)[0]
to_pil_image(normalized_masks[class_to_idx['person']])
to_pil_image(normalized_masks[class_to_idx['dog']])
from torchvision.utils import draw_segmentation_masks
masks = normalized_masks > 0.7
person = draw_segmentation_masks(dog_kid_playing_tensor_int, masks[class_to_idx['person']])
to_pil_image(person)
from torchvision.utils import draw_segmentation_masks
masks = normalized_masks > 0.1
dog = draw_segmentation_masks(dog_kid_playing_tensor_int, masks=masks[class_to_idx['dog']])
to_pil_image(dog)
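Both masks can also be drawn in a single overlay by stacking them into one boolean tensor and providing one color per mask. A minimal sketch, reusing the thresholds from the cells above:

# Stack the person and dog masks into a (2, H, W) boolean tensor and
# overlay both on the original image with distinct colors.
combined = torch.stack([
    normalized_masks[class_to_idx['person']] > 0.7,
    normalized_masks[class_to_idx['dog']] > 0.1,
])
both = draw_segmentation_masks(dog_kid_playing_tensor_int, combined, colors=["tomato", "orange"])
to_pil_image(both)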
In this section, we'll visualize the predictions made by our Mask R-CNN instance segmentation model.
Below, we have first retrieved the list of category names available through the MaskRCNN_ResNet50_FPN_Weights.COCO_V1 object and built a dictionary that maps each category name to its index. The categories list itself maps label indexes to category names; we'll use it to convert the labels predicted by the model to category names.
class_to_idx = {cls: idx for (idx, cls) in enumerate(MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"])}
len(class_to_idx)
Below, we have retrieved the predicted masks and their labels from the prediction dictionary. We have then converted the predicted label indexes to category names and printed both the detected objects and the unique object categories.
from torchvision.utils import draw_segmentation_masks
mapping = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"]
masks = holiday_preds2[0]['masks'].squeeze()
labels = holiday_preds2[0]['labels']
categories = [mapping[label] for label in labels[:15]]
print("Detected Objects : {}".format(categories))
print("Unique Objects : {}".format(list(set(categories))))
Below, we have overlaid the predicted objects on the original image using the draw_segmentation_masks() method. We have provided the method with the original image and the predicted masks. We have not given all masks to the method, only the first 15, as we want to highlight a few important objects.
We have also built a list of colors from the color mapping and passed it to the colors parameter so that each object category gets its own color.
We can notice from the results that the model detects objects like persons, backpacks, etc.
from torchvision.utils import draw_segmentation_masks
color_mapping = {"person": "tomato", "kite": "dodgerblue", "backpack": "yellow", "sports ball": "green", "dog": "orange"}
colors = [color_mapping[mapping[label]] for label in labels[:15]]
output = draw_segmentation_masks(holiday_tensor_int, masks=masks[:15].to(torch.bool), colors=colors)
to_pil_image(output)
Below, we have performed the same process to detect and visualize the objects present in our second image involving kids.
from torchvision.utils import draw_segmentation_masks
mapping = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"]
masks = kids_playing_preds2[0]['masks'].squeeze()
labels = kids_playing_preds2[0]['labels']
categories = [mapping[label] for label in labels[:5]]
print("Detected Objects : {}".format(categories))
print("Unique Objects : {}".format(list(set(categories))))
from torchvision.utils import draw_segmentation_masks
color_mapping = {"person": "tomato", "kite": "dodgerblue", "backpack": "yellow", "sports ball": "green", "dog": "orange", "frisbee": "pink"}
colors = [color_mapping[mapping[label]] for label in labels[:5]]
output = draw_segmentation_masks(kids_playing_tensor_int, masks=masks[:5].to(torch.bool), colors=colors)
to_pil_image(output)
Below, we have performed the same process to detect and visualize the objects present in our third image, where a kid and a dog are playing.
from torchvision.utils import draw_segmentation_masks
mapping = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"]
masks = dog_kid_playing_preds2[0]['masks'].squeeze()
labels = dog_kid_playing_preds2[0]['labels']
categories = [mapping[label] for label in labels]
print("Detected Objects : {}".format(categories))
print("Unique Objects : {}".format(list(set(categories))))
from torchvision.utils import draw_segmentation_masks
color_mapping = {"person": "tomato", "kite": "dodgerblue", "backpack": "yellow", "sports ball": "green", "dog": "orange", "frisbee": "pink", "baseball glove": "grey"}
colors = [color_mapping[mapping[label]] for label in labels[:3]]
output = draw_segmentation_masks(dog_kid_playing_tensor_int, masks=masks[:3].to(torch.bool), colors=colors)
to_pil_image(output)
The torchvision module has other pre-trained models available for image segmentation tasks that can be tried to check how they perform. These include DeepLabV3 (with ResNet-50, ResNet-101, and MobileNetV3-Large backbones) and LR-ASPP (with a MobileNetV3-Large backbone).
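A minimal sketch of loading one of them, DeepLabV3 with a ResNet-50 backbone, following the same pattern as FCN above (assumes torchvision >= 0.13):

from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

# DeepLabV3 is a semantic segmentation model; it returns the same
# dictionary structure (with an 'out' key) as FCN.
deeplab = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)
deeplab.eval();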
This ends our small tutorial explaining how we can use pre-trained PyTorch models for image segmentation tasks.
If you are more comfortable learning through video tutorials, then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to run into doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in a direction where you can find a solution to your problem.
You can even send us an email if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.