How to Train a Custom Keypoint Detection Model

A keypoint detection model detects interest points in an image, which makes it possible to estimate the pose, angles and dimensions of humans and objects. This article covers the process of training a keypoint detection model on a custom data set.

The most popular applications of artificial intelligence and machine learning in computer vision include image classification, object detection and segmentation. However, if you are interested in detecting keypoints or points of interest in an image or video, keypoint detection models are the right choice. They are typically based on neural networks because these achieve better detection accuracy. At the moment, keypoint detection models based on R-CNN (region-based convolutional neural networks) are among the most promising approaches to detect points of interest in an image.

Whilst a popular application of keypoint detectors is human pose estimation, the same approach can also be used to detect keypoints of objects and estimate their pose. In a recent article we covered how we use a keypoint detection model to detect and estimate the "pose" of sailing boats.

Of course, there are pretrained keypoint detection models available. However, most of the pretrained weights are specific to human pose estimation and do not generalize to pose detection of objects. Existing models might also not perform well enough if your data set deviates systematically from the data that was used to train them. Commonly, the COCO data set is used to train keypoint detection models for human pose estimation.

In this article I will cover some frameworks and pre-defined models that allow you to train or retrain a keypoint detector. I will also walk through the steps required to train a keypoint detector on a custom data set: preparing the data, configuring the model and running the training. With this you should be able to address a custom use case for keypoint detection.

Frameworks to train a Keypoint Detector

There are various frameworks and libraries that you can use to train a keypoint detection model. Popular ones are:

Detectron2 Keypoint Detection

Detectron2 is an AI library developed by Facebook that covers functionality such as object detection, image segmentation and keypoint detection. The Detectron2 model zoo also offers pretrained weights for various keypoint R-CNN architectures trained on the COCO person data set. If you wish to train your model on a custom data set, detailed instructions are provided on the data format as well as the training process.

PyTorch Keypoint Detection

PyTorch is an open source machine learning library that also offers support for keypoint detection models. Its torchvision package implements an R-CNN keypoint detection model that can be trained on a custom data set.

How to train a custom keypoint detector with PyTorch

In the following part I will walk through the steps that are required to train an R-CNN keypoint detector on a custom data set with PyTorch.

Prepare the data

The first step towards your own keypoint detection model is the preparation of the data. The keypoint detection model works with visual data which means that you would need suitable images as input. You could also use videos but would need to extract the relevant frames.

Along with the images, annotations are required that tell the model during the training process where the keypoints are located. Annotations should include the bounding box of each object that you would like to detect along with the positions of its keypoints in the image. There are various annotation tools available that you could use to annotate your images.
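
As a rough illustration, the annotations for one image could be stored in a JSON file like the one written below. The layout (a list of objects with "box" and "keypoints" entries) is an assumption chosen to match the keys read by the dataset class in this article. Boxes are given as [x1, y1, x2, y2] in pixels and each keypoint as [x, y, visibility], which is the target layout the torchvision Keypoint R-CNN expects. The file name, the keypoint names and all coordinates are made up.

```python
import json
import tempfile
from pathlib import Path

# One entry per annotated object in the image (hypothetical values).
annotations = [
    {
        "box": [120.0, 80.0, 420.0, 360.0],  # x1, y1, x2, y2 in pixels
        "keypoints": [                        # one [x, y, visibility] triple per keypoint
            [150.0, 340.0, 1],                # e.g. bow
            [400.0, 330.0, 1],                # e.g. stern
            [270.0, 95.0, 0],                 # e.g. mast top, not visible in this image
        ],
    }
]


def check_annotation(entry, num_keypoints):
    """Raise ValueError if an entry does not follow the assumed layout."""
    x1, y1, x2, y2 = entry["box"]
    if not (x2 > x1 and y2 > y1):
        raise ValueError("box must have positive width and height")
    if len(entry["keypoints"]) != num_keypoints:
        raise ValueError("unexpected number of keypoints")
    for kp in entry["keypoints"]:
        if len(kp) != 3:
            raise ValueError("each keypoint needs x, y and a visibility flag")


# Write the annotation for a hypothetical image boat_0001.jpg and read it back.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "boat_0001.json"
    path.write_text(json.dumps(annotations, indent=2))
    loaded = json.loads(path.read_text())

for entry in loaded:
    check_annotation(entry, num_keypoints=3)
```

Running such a check over all annotation files before training can catch swapped box corners or missing keypoints early.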

Data augmentation techniques can help to add variation to your data set, especially if data acquisition is impractical. This will later help your model to achieve better generalization.
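
Geometric augmentations have to be applied to the annotations as well as to the image. The following sketch mirrors a box and its keypoints horizontally on plain Python lists. The function name and the swap_pairs argument are assumptions for illustration. Keypoints that come in left/right pairs usually need to trade places after a flip, which is what swap_pairs is for.

```python
def hflip_annotation(box, keypoints, image_width, swap_pairs=()):
    """Mirror one box and its keypoints around the vertical image axis.

    box:        [x1, y1, x2, y2] in pixels
    keypoints:  list of [x, y, visibility]
    swap_pairs: index pairs of keypoints that trade places when mirrored
                (e.g. a left and a right keypoint)
    """
    x1, y1, x2, y2 = box
    # After mirroring, the old right edge becomes the new left edge.
    flipped_box = [image_width - x2, y1, image_width - x1, y2]

    flipped_kps = [[image_width - x, y, v] for x, y, v in keypoints]
    for i, j in swap_pairs:
        flipped_kps[i], flipped_kps[j] = flipped_kps[j], flipped_kps[i]
    return flipped_box, flipped_kps


box = [100, 50, 300, 200]
keypoints = [[120, 60, 1], [280, 190, 1]]
new_box, new_kps = hflip_annotation(box, keypoints, image_width=640)
print(new_box)  # [340, 50, 540, 200]
print(new_kps)  # [[520, 60, 1], [360, 190, 1]]
```

The image itself needs the same flip, so that pixels and annotations stay aligned.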

The use of simulated images in machine learning might help in a scenario where the acquisition of real images is expensive. It might also reduce the effort for labeling as the ground truth data (boxes and position of keypoints) could be automatically generated by the rendering software.

Using the PyTorch Dataset and DataLoader classes

PyTorch provides abstractions for data handling that can be used for the training and inference process of machine learning models.

The first abstraction that I would like to cover is the PyTorch Dataset class, which gives indexed access to the samples of a data set via the __getitem__ method. I used the following implementation, which simply loads an image from disk and converts it to a PyTorch tensor. During the training process the method also returns the annotations as a dictionary.

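import json
from pathlib import Path
from typing import Optional

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor


class KeypointsDataset(Dataset):
    def __init__(self, img_dir: Path, annotations_dir: Optional[Path] = None):
        self.img_dir = img_dir
        self.annotations_dir = annotations_dir
        self.images = sorted(self.img_dir.iterdir())
        # Annotations are optional so the same class can be used for inference.
        self.annotations = None
        if annotations_dir:
            # One JSON file per image, named after the image file.
            self.annotations = [
                self.annotations_dir / f"{image.stem}.json" for image in self.images
            ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load the image from disk and convert it to a float tensor in [0, 1].
        image = Image.open(self.images[idx]).convert("RGB")
        image_tensor = ToTensor()(image)
        target = {}
        if self.annotations:
            with open(self.annotations[idx], "r") as f:
                annotations = json.load(f)
            target = {
                # Boxes as [x1, y1, x2, y2], keypoints as [x, y, visibility].
                "boxes": torch.tensor([a["box"] for a in annotations], dtype=torch.float32),
                "keypoints": torch.tensor([a["keypoints"] for a in annotations], dtype=torch.float32),
                # Single object class; label 0 is reserved for the background.
                "labels": torch.ones(len(annotations), dtype=torch.int64),
            }
        return image_tensor, target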

In many cases the default PyTorch DataLoader class can be used to load the samples and iterate over the data set. You might need a custom collate_fn function to properly batch your data and annotations.
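
The default collate function tries to stack all samples of a batch into one tensor, which fails when images differ in size or contain a different number of objects. A common workaround for detection models, sketched here on placeholder data, is to keep the batch as tuples of images and targets. The strings and dictionaries below merely stand in for image tensors and target dictionaries.

```python
def collate_fn(batch):
    """Turn [(image, target), ...] into ((image, ...), (target, ...)).

    The detection models in torchvision accept a list of images and a list
    of target dictionaries, so nothing has to be stacked.
    """
    return tuple(zip(*batch))


# Placeholder samples; in practice these come from the dataset's __getitem__.
batch = [("image_0", {"labels": [1]}), ("image_1", {"labels": [1, 1]})]
images, targets = collate_fn(batch)
print(images)   # ('image_0', 'image_1')
print(targets)  # ({'labels': [1]}, {'labels': [1, 1]})
```

The function would then be handed to the loader, for example as DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn).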

Configure the R-CNN Keypoint detector

The implementation of the R-CNN keypoint detector in PyTorch is rather flexible and allows customization for various scenarios. When creating the KeypointRCNN model, a number of parameters can be adjusted:

First of all, you can configure the backbone architecture that you would like to use. The backbone is the first part of the keypoint detection model and computes the features (feature extraction). Whilst you can choose the backbone architecture rather freely, only the resnet50 option currently comes with weights pretrained on the COCO data set. A more complex backbone might result in better accuracy, but it typically costs time at the inference stage, resulting in slightly lower throughput.

Another important parameter is the number of keypoints per detection. This depends heavily on your setup and the number of interest points you would like to obtain per detection.

Another important setting is the number of classes. It is important to note that the number of classes should include the background class, meaning that it should be set to 2 if you would like to detect a single class (one for the background and one for your class).

In addition you can also configure the number of detections per image. This can be handy if you already know that the number of detections and hence keypoints is limited in your data set as well as during inference. In the end, this setting will help with avoiding too many detections at inference time.

Whilst I already obtained reasonable performance by adjusting the parameters mentioned above, please refer to the documentation of the model for additional customization options.

Training the Keypoint Detector Model

Once the data, the data handling and the model are configured, you can focus on the training process of the keypoint detector. A GPU with CUDA support will come in handy as it allows you to train the model faster than on a CPU.

Typically the data is split into a training and validation set. The training data is used to train the model whilst performance is evaluated on the validation set. During the training process, a certain number of epochs is used which determines how often you iterate over the training data to allow the model to learn.
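
A single training epoch could look roughly like the sketch below. It relies on the fact that the torchvision detection models return a dictionary of losses when they are called in training mode with images and targets. The function name and the tiny stand-in model are assumptions that keep the example small. In practice the keypoint detector and a DataLoader over the training set would take their place.

```python
import torch
from torch import nn


def train_one_epoch(model, loader, optimizer, device):
    """Run one pass over the loader and return the mean total loss."""
    model.train()
    total = 0.0
    batches = 0
    for images, targets in loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)  # e.g. loss_keypoint, loss_box_reg, ...
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
        batches += 1
    return total / max(batches, 1)


class StandInModel(nn.Module):
    """Hypothetical stand-in that mimics the loss-dictionary interface."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, images, targets):
        pixel_mean = torch.stack([img.mean() for img in images]).mean()
        return {"loss_keypoint": (self.scale * pixel_mean) ** 2}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = StandInModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# One batch of two images with their (here unused) targets.
loader = [([torch.ones(3, 8, 8), torch.ones(3, 8, 8)],
           [{"labels": torch.tensor([1])}, {"labels": torch.tensor([1])}])]

for epoch in range(3):
    print(epoch, train_one_epoch(model, loader, optimizer, device))
```

The same loop run with gradients disabled on the validation loader gives a validation loss that can be compared from epoch to epoch.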

If you encounter an issue with the keypoint loss, for instance a loss that is always zero, check the configuration of your model (e.g. the number of classes, which should include the background class). Another common source of this error is that the class labels of your data include 0, which is reserved for the background.
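
Such problems are easier to spot with a quick check of the targets before training. The helper below is a hypothetical sketch: it verifies that no label is 0 and that the keypoints have the expected shape. It also checks that at least one keypoint is marked as visible, based on the assumption that keypoints with visibility 0 do not contribute to the keypoint loss.

```python
def find_target_problems(target, num_keypoints):
    """Return a list of human-readable problems found in one target dict."""
    def as_list(value):
        # Accept tensors (which offer .tolist()) as well as plain lists.
        return value.tolist() if hasattr(value, "tolist") else value

    problems = []
    labels = as_list(target["labels"])
    keypoints = as_list(target["keypoints"])

    if any(label == 0 for label in labels):
        problems.append("label 0 is reserved for the background")
    if len(keypoints) != len(labels):
        problems.append("number of keypoint sets and labels differ")
    for kps in keypoints:
        if len(kps) != num_keypoints or any(len(kp) != 3 for kp in kps):
            problems.append("keypoints must have shape [num_keypoints, 3]")
            break
    visible = any(len(kp) == 3 and kp[2] > 0 for kps in keypoints for kp in kps)
    if keypoints and not visible:
        problems.append("no keypoint is marked as visible")
    return problems


good = {"labels": [1], "keypoints": [[[10, 20, 1], [30, 40, 1]]]}
bad = {"labels": [0], "keypoints": [[[10, 20, 0], [30, 40, 0]]]}
print(find_target_problems(good, num_keypoints=2))  # []
print(find_target_problems(bad, num_keypoints=2))
```

Running it over a few samples of the training set is usually enough to reveal a systematic labeling mistake.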

PyTorch offers functionality to serialize models and load them from disk. This is helpful for training as it allows you to save snapshots of your model, run some manual tests if required and resume training from a certain epoch.
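
A minimal checkpointing sketch using torch.save and torch.load is shown below. The small linear model, the dictionary keys and the file name are assumptions for illustration. The same pattern applies to the keypoint detector and its optimizer.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the keypoint detector
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epoch = 5

with tempfile.TemporaryDirectory() as tmp:
    checkpoint_path = Path(tmp) / "checkpoint_epoch_5.pt"

    # Save everything needed to resume training later.
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        checkpoint_path,
    )

    # Restore into freshly created objects.
    restored_model = nn.Linear(4, 2)
    restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    restored_model.load_state_dict(checkpoint["model_state"])
    restored_optimizer.load_state_dict(checkpoint["optimizer_state"])
    start_epoch = checkpoint["epoch"] + 1

print(start_epoch)  # 6
```

Storing the optimizer state alongside the model weights matters when the optimizer keeps internal statistics such as momentum.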

Configure the Inference Pipeline for the Keypoint Detector

Once you have trained a keypoint detection model in PyTorch, you can still configure the model at inference time. The following two options are probably the most important:

  • box_score_thresh
    This is the detection threshold that determines which proposals are returned during inference. In case you see too many detections when trying out your model you might increase the threshold to limit the number of detections that are returned by the model.
  • box_nms_thresh
    The non maximum suppression threshold determines by how much the predicted boxes that contain the keypoints can overlap. It might well be that single instances are returned by the model multiple times with a high overlap. In this case you might reduce this parameter to limit the number of detections.
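
To make the effect of the two thresholds concrete, the following pure-Python sketch filters detections by score and then applies a greedy non-maximum suppression. It is an illustration only and not the torchvision implementation; the function names and example values are made up. In torchvision both values can be passed as box_score_thresh and box_nms_thresh when the model is created.

```python
def iou(a, b):
    """Intersection over union of two boxes given as [x1, y1, x2, y2]."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def filter_detections(boxes, scores, score_thresh, nms_thresh):
    """Return indices of detections that survive both thresholds."""
    # 1) Drop low-confidence detections, then sort the rest by score.
    order = sorted(
        (i for i, s in enumerate(scores) if s > score_thresh),
        key=lambda i: scores[i],
        reverse=True,
    )
    # 2) Greedy NMS: keep a box only if it does not overlap a kept box too much.
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= nms_thresh for j in keep):
            keep.append(i)
    return keep


boxes = [[0, 0, 100, 100], [10, 10, 110, 110], [200, 200, 300, 300]]
scores = [0.9, 0.8, 0.3]
print(filter_detections(boxes, scores, score_thresh=0.5, nms_thresh=0.5))   # [0]
print(filter_detections(boxes, scores, score_thresh=0.05, nms_thresh=0.7))  # [0, 1, 2]
```

In the first call the low-scoring third box is dropped by the score threshold and the second box is suppressed because it overlaps the first one by more than 0.5. With the looser settings of the second call all three detections are returned.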

Once you have successfully trained and configured your keypoint detection model, you can focus on integrating the model into your application. This might include exporting the keypoint detection model to other formats such as ONNX, or using a dedicated runtime to improve inference performance.

Please do not hesitate to contact us in case you wish further information or assistance with the training process of a keypoint detection model.