How to Train a Custom Keypoint Detection Model

With a keypoint detection model it is possible to detect points of interest in an image, which in turn makes it possible to estimate the pose, angles and dimensions of humans and objects. This article covers the process of training a keypoint detection model on a custom data set.

The most popular applications of artificial intelligence and machine learning in computer vision include image classification, object detection and segmentation. However, if you are interested in detecting keypoints or points of interest in an image or video, keypoint detection models are the right choice. Typically, neural network based models are used because they offer better detection accuracy. At the moment, keypoint detection models based on region based convolutional neural networks (R-CNNs) are usually the most promising approach for detecting points of interest in an image.

Whilst a popular application of keypoint detectors is human pose estimation, the same approach can also be used to detect keypoints of objects and estimate their pose. In a recent article we covered how we are using a keypoint detection model to detect and estimate the "pose" of sailing boats.

Of course there are pretrained keypoint detection models available that can be used out of the box. However, most of the pretrained weights are specific to human pose estimation and do not generalize to pose detection of arbitrary objects. Existing models might also not perform well enough if your data set systematically deviates from the data that was used to train them. Commonly, the COCO data set is used to train keypoint detection models for human pose estimation.

Cost, time and effort to train a keypoint detection model

Before you get into training a keypoint detection model or hiring someone to develop it for you, you might wonder how much time, effort and money it will take. It is very difficult to give a general answer since complexity varies greatly depending on the use case. As a first step, you might want to develop a keypoint detector prototype which could be used in a minimum viable product (MVP). The prototype would also help you to identify potential issues that need to be resolved before training a production grade model and integrating it into your product.

For a typical project the following steps are required to train a first prototype:

Activity | Task | Time Effort
Data Acquisition | Taking 50 - 100 images | 1 - 2 days, possibly spread over a longer period of time to account for variation
Labeling | Annotate the data | 1 - 2 hours
Training | Configure and train the keypoint detector | ~ 1 day using a default approach with a framework

Obviously, this is just a very rough guideline and for specific use cases time and costs might differ considerably. However, the good news is that with limited time and budget it is typically possible to train a custom model that is just good enough to keep you going. If you would like a more precise estimate, please do not hesitate to contact us.

Data to train a keypoint detection model

As for most machine learning and artificial intelligence models, appropriate data is required to train a keypoint detection model. The better and more representative your data, the better your model will perform in production. There are several data strategies that you can follow:

Custom data acquisition

The typical approach is to acquire data specifically for training the keypoint detection model. Both images and annotations are needed to train the machine learning model.

As a first step, you use a camera to take images of the object with the keypoints. It is important to have enough variation in the images by using different angles, environments and lighting conditions. This raises the question of how many images are required. The answer depends very much on your use case and the required accuracy. A good recommendation is to start with around 50 - 100 images to develop a first prototype and to increase the amount of data once other potential issues have been resolved.

The next step is to label the data. Typically you need the bounding box of the object as well as the positions of the individual keypoints. There are various labeling tools available; a good starting point might be the COCO Annotator, which supports annotating keypoints. Be aware that labeling images is a time-consuming and repetitive task which is often outsourced to external companies. However, for a first prototype it might pay off to carry out the annotation process in-house and only outsource it once the process is fully understood and potential issues are resolved. Accurate labels and a well defined annotation process are crucial for the successful training of a keypoint detector.

Using simulated images

In certain situations data acquisition might be challenging, complex and costly. In such cases, simulated images can be used to train a keypoint detection model. The idea is to use computer software to render images that look almost like real world images.

Obviously, this approach requires the development of code that can render images: with a scripting language such as Python it is possible to control 3D computer graphics software such as Blender, which allows objects to be rendered. In addition, data augmentation techniques can be used to introduce sufficient variation into the data set.

A significant advantage of using synthetic data is that you no longer need to annotate the data by hand: the rendering software already knows the positions of the keypoints in the rendered image. This approach also scales very well since thousands of images and labels can be generated almost instantly.
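To give a rough idea of what such a script could look like, here is a minimal sketch using Blender's Python API (bpy, Blender 2.8+). It assumes a scene that already contains a camera and a mesh object named "boat" with keypoints given in object space; the object name, keypoint coordinates and file paths are hypothetical placeholders.

# Minimal sketch: render a synthetic image and generate keypoint labels with Blender (bpy).
# Assumes it runs inside Blender, the scene has a camera, resolution_percentage is 100,
# and there is an object named "boat"; names, coordinates and paths are hypothetical.
import json
import bpy
from bpy_extras.object_utils import world_to_camera_view
from mathutils import Vector

scene = bpy.context.scene
camera = scene.camera
obj = bpy.data.objects["boat"]

# Keypoints defined in the object's local coordinate system (hypothetical values).
local_keypoints = [Vector((0.0, 2.0, 0.5)), Vector((0.0, -2.0, 0.5)), Vector((0.0, 0.0, 3.0))]

# Render the image to disk.
scene.render.filepath = "/tmp/boat_0001.png"
bpy.ops.render.render(write_still=True)

# Project the keypoints into normalized camera coordinates and convert them to pixels.
width = scene.render.resolution_x
height = scene.render.resolution_y
keypoints_px = []
for kp in local_keypoints:
    world_co = obj.matrix_world @ kp                      # object space -> world space
    ndc = world_to_camera_view(scene, camera, world_co)   # (x, y, depth), x/y in [0, 1], origin bottom-left
    keypoints_px.append([ndc.x * width, (1.0 - ndc.y) * height, 2])  # flip y, visibility = 2

with open("/tmp/boat_0001.json", "w") as f:
    json.dump([{"keypoints": keypoints_px}], f)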

Using existing data sets

It is worth the effort to check if there are any publicly available data sets that can be used to train your machine learning model.

Frameworks to train a Keypoint Detector

In this article I will cover some frameworks and pre-defined models that allow you to train or retrain a keypoint detector. I will also walk through the steps required to train a keypoint detector on a custom data set, starting with the preparation of the data, followed by the configuration of the model and the training phase. With this you should be able to address a custom keypoint detection use case.

There are various frameworks and libraries that you can use to train a keypoint detection model. Popular ones are:

Detectron2 Keypoint Detection

Detectron2 is an AI library developed by Facebook that covers functionality such as object detection, image segmentation and keypoint detection. The Detectron2 model zoo also offers pretrained weights for various Keypoint R-CNN architectures trained on the COCO person keypoint data set. If you wish to train a model on a custom data set, detailed instructions are provided on the data format as well as the training process.
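As a rough orientation of what this looks like in practice, the following sketch fine-tunes a Keypoint R-CNN from the Detectron2 model zoo on a custom COCO-format data set. The data set name, file paths, keypoint names and iteration count are hypothetical placeholders, not part of any real project.

# Sketch of fine-tuning a Detectron2 Keypoint R-CNN on a custom COCO-format data set.
# Data set name, paths, keypoint names and solver settings are hypothetical placeholders.
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer

# Register a data set in COCO format (images plus a COCO-style JSON including keypoints).
register_coco_instances("boats_train", {}, "annotations/train.json", "images/train")
# Keypoint metadata is needed for training with flip augmentation; an empty flip map means no swaps.
MetadataCatalog.get("boats_train").set(
    keypoint_names=["bow", "stern", "mast_top", "mast_base"],
    keypoint_flip_map=[],
)

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")
cfg.DATASETS.TRAIN = ("boats_train",)
cfg.DATASETS.TEST = ()
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1            # foreground classes only; Detectron2 handles background internally
cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 4  # hypothetical number of keypoints per object
cfg.SOLVER.MAX_ITER = 1000

trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()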

Pytorch Keypoint Detection

Pytorch is an open source machine learning library that also offers support for keypoint detection models. There is an R-CNN keypoint detection model implementation (Keypoint R-CNN, available via torchvision) that can be trained on a custom data set.

How to train a custom keypoint detector with Pytorch

In the following part I will walk through the steps that are required to train an R-CNN keypoint detector on a custom data set with Pytorch.

Prepare the data

The first step towards your own keypoint detection model is the preparation of the data. The keypoint detection model works with visual data, which means that you need suitable images as input. You could also use videos, but you would then need to extract the relevant frames.

Along with the images, annotations are required that tell the model during the training process where the keypoints are located. Annotations should include the bounding box of the object that you would like to detect along with the positions of the keypoints in the image. There are various annotation tools available that you could use to annotate your images.

Data augmentation techniques can help to add variation to your data set, especially if acquiring more data is impractical. This will help your model to achieve better generalization later on.
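As one possible way to do this, the sketch below uses the albumentations library (an assumption on my part, not necessarily what was used here) so that the image, the bounding boxes and the keypoints are transformed consistently. The concrete transforms and coordinate values are only placeholders.

# Minimal augmentation sketch with albumentations; library choice and values are illustrative.
import albumentations as A
import numpy as np

transform = A.Compose(
    [
        # Note: flips may swap the semantic meaning of left/right keypoints.
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Rotate(limit=15, p=0.5),
    ],
    keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)

image = np.zeros((480, 640, 3), dtype=np.uint8)   # placeholder image
keypoints = [(100.0, 200.0), (300.0, 220.0)]      # (x, y) per keypoint
bboxes = [(80, 180, 320, 260)]                    # (x_min, y_min, x_max, y_max)
labels = [1]

augmented = transform(image=image, keypoints=keypoints, bboxes=bboxes, labels=labels)
aug_image = augmented["image"]
aug_keypoints = augmented["keypoints"]
aug_bboxes = augmented["bboxes"]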

The use of simulated images might help in scenarios where the acquisition of real images is expensive. It can also reduce the labeling effort, as the ground truth data (boxes and keypoint positions) can be generated automatically by the rendering software.

Using the Pytorch Dataset and DataLoader classes

Pytorch provides abstractions for data handling that can be used for the training and inference process of machine learning models.

The first abstraction that I would like to cover is the Pytorch Dataset class, which allows you to iterate over a data set via the __getitem__ method. I used the following implementation, which simply loads an image from disk and converts it to a Pytorch tensor. During the training process the method also returns the annotations as a dictionary.

import json
from pathlib import Path
from typing import Optional

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor


class KeypointsDataset(Dataset):
    def __init__(self, img_dir: Path, annotations_dir: Optional[Path] = None):
        self.img_dir = img_dir
        self.annotations_dir = annotations_dir
        self.images = sorted(list(self.img_dir.iterdir()))
        self.annotations = None
        if annotations_dir:
            # One JSON annotation file per image, matched by file stem.
            self.annotations = []
            for image in self.images:
                self.annotations.append(self.annotations_dir / f"{image.stem}.json")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        image_tensor = ToTensor()(image)
        target = {}
        if self.annotations:
            with open(self.annotations[idx], "r") as f:
                annotations = json.load(f)
            target = {
                # Boxes as [x_min, y_min, x_max, y_max] per instance.
                "boxes": torch.Tensor([a["box"] for a in annotations]),
                # Keypoints with shape [num_instances, num_keypoints, 3] as (x, y, visibility).
                "keypoints": torch.Tensor([a["keypoints"] for a in annotations]),
                # A single foreground class; label 0 is reserved for the background.
                "labels": torch.Tensor([1 for _ in annotations]).type(torch.int64),
            }
        return image_tensor, target

In many cases the Pytorch default DataLoader class can be used to load the samples and iterate over the data set in batches. You might need a custom collate_fn function to properly batch your images and annotations, since their sizes can differ from sample to sample.
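A minimal sketch of how the data set above could be wrapped in a DataLoader with such a collate_fn; the directory paths and batch size are placeholders.

# Minimal sketch: wrapping the data set in a DataLoader; paths and batch size are placeholders.
from pathlib import Path

from torch.utils.data import DataLoader


def collate_fn(batch):
    # Detection models expect a list of images and a list of targets rather than stacked tensors,
    # because image sizes and the number of instances may differ from sample to sample.
    return tuple(zip(*batch))


dataset = KeypointsDataset(Path("data/images"), Path("data/annotations"))
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

for images, targets in loader:
    # images: tuple of [C, H, W] tensors, targets: tuple of annotation dictionaries
    images = list(images)
    targets = list(targets)
    break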

Configure the R-CNN Keypoint detector

The implementation of the R-CNN keypoint detector in Pytorch is rather flexible and allows customization for various scenarios. When creating the KeypointRCNN model, a number of parameters can be adjusted:

First of all, you can configure the backbone architecture that you would like to use. This is the first part of the keypoint detection model and is used to compute features (feature extraction). Whilst you can choose the backbone architecture fairly freely, only the resnet50 option currently comes with weights pretrained on the COCO data. Whilst a more complex backbone might result in better accuracy, it typically costs time at the inference stage, resulting in lower throughput.

Another important parameter is the number of keypoints per detection. This heavily depends on your specific setup and the number of interest points you would like to obtain per detection.

Another important setting is the number of classes. It is important to note that the number of classes should include the background class, meaning that it should be set to 2 if you would like to detect a single class (one for the background and one for your class).

In addition, you can configure the number of detections per image. This can be handy if you already know that the number of detections, and hence keypoints, is limited in your data set as well as during inference. This setting helps to avoid spurious detections at inference time.

Whilst I already obtained reasonable performance by adjusting the parameters mentioned above, please refer to the documentation of the model for additional customization options.
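Putting these parameters together, a model could be configured roughly as follows with torchvision; the concrete numbers of keypoints, classes and detections are placeholders for your own use case.

# Sketch of configuring a Keypoint R-CNN with torchvision; the parameter values are placeholders.
import torchvision

model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    weights=None,              # train from scratch; in older torchvision versions this is pretrained=False
    num_keypoints=4,           # keypoints per detection (use your own number)
    num_classes=2,             # one object class plus the background class
    box_detections_per_img=2,  # cap the number of detections per image if you know the upper bound
)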

Training the Keypoint Detector Model

Once the data, the data handling and the model are configured, you can focus on the training process of the keypoint detector. A GPU with CUDA support will come in handy as it allows you to train the model much faster than on a CPU.

Typically the data is split into a training and a validation set. The training data is used to train the model, whilst performance is evaluated on the validation set. The training runs for a certain number of epochs, which determines how often you iterate over the training data to allow the model to learn.
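A bare-bones training loop could look like the sketch below. It reuses the model and DataLoader from the previous sketches; the optimizer choice, learning rate and number of epochs are placeholders, and validation, logging and learning rate scheduling are left out for brevity.

# Bare-bones training loop sketch; optimizer, learning rate and epoch count are placeholders.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for images, targets in loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In training mode the torchvision detection models return a dictionary of losses.
        loss_dict = model(images, targets)
        loss = sum(l for l in loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()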

In case you encounter an issue with the keypoint loss, for instance the loss always being zero, you should check the configuration of your model (e.g. the number of classes, which should include the background class). Another common source of this error is that the class labels in your data include 0, which is reserved for the background.

Pytorch offers functionality to serialize models and load them from disk. This is helpful for training, as it allows you to save snapshots of your model, run some manual tests if required and resume training from a certain epoch.
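One common way to do this, sketched below, is to store a checkpoint dictionary with the model and optimizer state after each epoch; the file name is a placeholder and the variables continue from the training loop above.

# Saving and restoring a training checkpoint; the file name is a placeholder.
import torch

checkpoint_path = "checkpoint_epoch_10.pt"

# Save a snapshot after an epoch.
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    checkpoint_path,
)

# Later: restore the snapshot and resume training from the stored epoch.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1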

Configure the Inference Pipeline for the Keypoint Detector

Once you have trained a keypoint detection model in Pytorch you can still adjust its behaviour at inference time. The two most important options are probably the following:

  • box_score_thresh
    This is the detection threshold that determines which proposals are returned during inference. In case you see too many detections when trying out your model you might increase the threshold to limit the number of detections that are returned by the model.
  • box_nms_thresh
    The non maximum suppression threshold determines by how much the predicted boxes that contain the keypoints can overlap. It might well be that single instances are returned by the model multiple times with a high overlap. In this case you might reduce this parameter to limit the number of detections.

Once you have successfully trained and configured your keypoint detection model, you can focus on integrating it into your application. This might include exporting the keypoint detection model to other formats such as ONNX or using a dedicated runtime to improve inference performance.
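As an illustration, a torchvision Keypoint R-CNN can typically be exported to ONNX roughly as follows; the input size and file name are placeholders, and the exact requirements may vary between torchvision and opset versions.

# Rough sketch of an ONNX export for a torchvision Keypoint R-CNN; details may vary by version.
import torch

model.eval()
dummy_input = [torch.rand(3, 480, 640)]  # detection models take a list of image tensors
torch.onnx.export(model, dummy_input, "keypoint_detector.onnx", opset_version=11)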

Please do not hesitate to contact us if you would like further information or assistance with training a keypoint detection model.