Point Cloud Segmentation
The Task
A point cloud is a set of data points in space, usually described by x, y and z coordinates.
Point cloud segmentation is the task of classifying each point individually, so that every point is assigned to one class. The current integration builds on top of Open3D-ML.
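To make point-level classification concrete, here is a minimal plain-Python sketch. The class names, the scores and the helper `segment_points` are made up for illustration and are not part of Flash or Open3D-ML; in practice a trained model would produce the per-point scores.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]  # (x, y, z)

# Hypothetical class names, for illustration only.
CLASS_NAMES = ["road", "car", "vegetation"]


def segment_points(scores: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each point, the index of its highest-scoring class."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# A tiny point cloud: one (x, y, z) tuple per point.
points: List[Point] = [(0.0, 0.0, 0.0), (1.5, 0.2, 0.4), (3.0, -1.0, 2.1)]

# Made-up per-point class scores: one row per point, one column per class.
scores = [
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.3, 0.6],
]

labels = segment_points(scores)
for point, label in zip(points, labels):
    print(point, "->", CLASS_NAMES[label])
```

The output of segmentation has exactly one label per input point, which is what distinguishes it from classifying the whole cloud at once.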
Example
Let’s look at an example using a dataset generated from the KITTI Vision Benchmark.
The data is a tiny subset of the original dataset and contains sequences of point clouds.
It consists of multiple folders, one per sequence, and a meta.yaml file describing the classes and their official color map.
Each sequence should contain one folder for scans and one folder for labels, plus a pose.txt file to re-align the sequence if required.
Here’s the structure:
data
├── meta.yaml
├── 00
│   ├── scans
│   │   ├── 00000.bin
│   │   ├── 00001.bin
│   │   ...
│   ├── labels
│   │   ├── 00000.label
│   │   ├── 00001.label
│   │   ...
│   ├── pose.txt
│   ...
│
└── XX
    ├── scans
    │   ├── 00000.bin
    │   ├── 00001.bin
    │   ...
    ├── labels
    │   ├── 00000.label
    │   ├── 00001.label
    │   ...
    ├── pose.txt
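The layout above can be sanity-checked before training. The following is a minimal sketch using only the standard library; `check_sequence` and `check_dataset` are hypothetical helpers, not part of Flash, and the demo builds a throwaway folder rather than touching real data.

```python
import tempfile
from pathlib import Path
from typing import Dict, List


def check_sequence(sequence_dir: Path) -> List[str]:
    """Return the problems found in one sequence folder (empty if none)."""
    problems = []
    scans = sequence_dir / "scans"
    labels = sequence_dir / "labels"
    if not scans.is_dir():
        problems.append("missing scans folder")
    if not labels.is_dir():
        problems.append("missing labels folder")
    if problems:
        return problems
    # Every scan should have a label file with the same stem, and vice versa.
    scan_ids = {p.stem for p in scans.glob("*.bin")}
    label_ids = {p.stem for p in labels.glob("*.label")}
    for stem in sorted(scan_ids - label_ids):
        problems.append(f"scan {stem}.bin has no label file")
    for stem in sorted(label_ids - scan_ids):
        problems.append(f"label {stem}.label has no scan file")
    return problems


def check_dataset(root: Path) -> Dict[str, List[str]]:
    """Map each sequence folder under root to its list of problems."""
    return {d.name: check_sequence(d) for d in sorted(root.iterdir()) if d.is_dir()}


# Demo on a temporary folder with one deliberately incomplete sequence.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "00" / "scans").mkdir(parents=True)
    (root / "00" / "labels").mkdir(parents=True)
    (root / "00" / "scans" / "00000.bin").touch()
    (root / "00" / "scans" / "00001.bin").touch()
    (root / "00" / "labels" / "00000.label").touch()
    report = check_dataset(root)

print(report)
```

To check real data, pass the folder that directly contains the sequence folders to `check_dataset`. Files such as meta.yaml are ignored because only directories are inspected.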
Learn more: http://www.semantic-kitti.org/dataset.html
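As a rough sketch of what the scan and label files hold: in the original SemanticKITTI format, each .bin scan is a flat sequence of float32 values (x, y, z, remission), and each .label file stores one uint32 per point, with the semantic class in the lower 16 bits and the instance id in the upper 16 bits. The code below assumes the tiny dataset keeps this format and uses little-endian byte order; `read_scan` and `read_labels` are illustrative helpers, not Flash functions, and the demo reads synthetic files it writes itself.

```python
import struct
import tempfile
from pathlib import Path
from typing import List, Tuple


def read_scan(path: Path) -> List[Tuple[float, float, float, float]]:
    """Read consecutive little-endian float32 (x, y, z, remission) records."""
    raw = path.read_bytes()
    if len(raw) % 16:
        raise ValueError(f"{path} is not a whole number of 16-byte points")
    return list(struct.iter_unpack("<4f", raw))


def read_labels(path: Path) -> List[Tuple[int, int]]:
    """Read one little-endian uint32 per point as (semantic_label, instance_id)."""
    raw = path.read_bytes()
    if len(raw) % 4:
        raise ValueError(f"{path} is not a whole number of 4-byte labels")
    return [(value & 0xFFFF, value >> 16) for (value,) in struct.iter_unpack("<I", raw)]


# Demo: write a synthetic two-point scan and its labels, then read them back.
with tempfile.TemporaryDirectory() as tmp:
    scan_path = Path(tmp) / "00000.bin"
    label_path = Path(tmp) / "00000.label"
    scan_path.write_bytes(struct.pack("<8f", 1.0, 2.0, 3.0, 0.5, 4.0, 5.0, 6.0, 0.25))
    label_path.write_bytes(struct.pack("<2I", 10, (7 << 16) | 40))
    scan = read_scan(scan_path)
    point_labels = read_labels(label_path)

# A scan and its label file should describe the same number of points.
assert len(scan) == len(point_labels)
print(scan)
print(point_labels)
```

For real datasets a NumPy-based reader would be much faster; the standard-library version is used here only to keep the sketch dependency-free.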
Once we’ve downloaded the data using download_data(), we create the PointCloudSegmentationData.
We select a pre-trained randlanet_semantic_kitti backbone for our PointCloudSegmentation task.
We then fine-tune the model and use the trained PointCloudSegmentation for inference.
Finally, we save the model.
Here’s the full example:
import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData

# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")

datamodule = PointCloudSegmentationData.from_folders(
    train_folder="data/SemanticKittiTiny/train",
    val_folder="data/SemanticKittiTiny/val",
    batch_size=4,
)

# 2. Build the task
model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
datamodule = PointCloudSegmentationData.from_files(
    predict_files=[
        "data/SemanticKittiTiny/predict/000000.bin",
        "data/SemanticKittiTiny/predict/000001.bin",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("pointcloud_segmentation_model.pt")
Flash Zero
The point cloud segmentation task can be used directly from the command line, without writing any code, using Flash Zero. You can run the above example with:
flash pointcloud_segmentation
To view configuration options and options for running the point cloud segmentation task with your own data, use:
flash pointcloud_segmentation --help