A basic component of point-based 3D object detectors is the set abstraction (SA) layer, which downsamples points for better efficiency and larger receptive fields. However, existing SA layers only take the relative locations among points into account, e.g. via furthest point sampling (FPS), while ignoring point features. Because the points on objects occupy only a small proportion of space, uniform cascaded SA may retain few or no object points in the last layer, degrading 3D object detection performance. We are thus motivated to design a new lightweight and effective SA layer, the Boundary-Aware Set Abstraction layer (BA-Net), to retain important foreground and boundary points during cascaded down-sampling. Technically, we first embed a lightweight point segmentation model (PSM) to compute point-wise foreground scores, then propose a Boundary Prediction Model (BPM) to detect points on object boundaries. Finally, the point scores are used to twist inter-point distances, and furthest point sampling is conducted in the twisted distance space (B-FPS). We experiment on the KITTI dataset, and the results show that BA-Net improves detection performance, especially on harder cases. Additionally, BA-Net is an easy-to-plug-in point-based module that can boost various detectors.
| Method | Easy | Mod. | Hard | mAP |
|---|---|---|---|---|
| PointRCNN | 91.57 | 82.24 | 80.45 | 84.75 |
| PointRCNN + BA-Net | +0.75 | +0.80 | +1.86 | +1.14 |
| 3DSSD | 91.54 | 83.46 | 82.18 | 85.73 |
| 3DSSD + BA-Net | +0.89 | +1.93 | +0.38 | +1.06 |
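To make the sampling idea concrete, here is a minimal NumPy sketch of B-FPS: furthest point sampling where the per-point foreground/boundary scores from the PSM/BPM re-scale ("twist") distances so that high-score points look farther away and survive down-sampling. The weighting form `1 + gamma * score` and the choice to seed from the highest-score point are illustrative assumptions, not the paper's exact formulation.

```python
import numpy as np

def boundary_aware_fps(points, scores, num_samples, gamma=1.0):
    """Sketch of furthest point sampling in a score-twisted distance space.

    points:      (N, 3) xyz coordinates
    scores:      (N,) foreground/boundary scores in [0, 1] (e.g. from PSM/BPM)
    gamma:       hypothetical weighting; larger gamma favors high-score points
    num_samples: number of points to keep (must be <= N)
    """
    n = points.shape[0]
    selected = [int(np.argmax(scores))]   # assumed heuristic: seed at the top-score point
    min_dist = np.full(n, np.inf)
    for _ in range(num_samples - 1):
        last = points[selected[-1]]
        d = np.linalg.norm(points - last, axis=1)
        # Twist: inflate distances of high-score points so FPS is more
        # likely to pick foreground/boundary points.
        d = d * (1.0 + gamma * scores)
        min_dist = np.minimum(min_dist, d)
        selected.append(int(np.argmax(min_dist)))
    return np.array(selected)
```

With `gamma = 0` this reduces to vanilla FPS; increasing `gamma` biases the sample set toward foreground and boundary points.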
All the code is tested in the following environment:
- Linux (tested on 18.04)
- Python 3.6+
- PyTorch 1.3
- CUDA 11.6
- spconv v2.x
NOTE: Please re-install pcdet v0.5 by running `python setup.py develop`.
Install this repository:
```shell
git clone https://github.com/HuangZhe885/Boundary-Aware-SA.git
cd Boundary-Aware-SA
pip install -r requirements.txt
python setup.py develop
```
Install spconv:
```shell
git clone https://github.com/traveller59/spconv.git --recursive
cd spconv
python setup.py bdist_wheel
cd ./dist
pip install *.whl
```
Please download the official KITTI 3D object detection dataset and organize the downloaded files as follows (the road planes can be downloaded from [road plane]; they are optional and used for data augmentation during training):
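A layout following the standard OpenPCDet convention (folder names assume the usual KITTI release; `planes` is only needed if you use the optional road planes):

```
Boundary-Aware-SA
├── data
│   ├── kitti
│   │   ├── ImageSets
│   │   ├── training
│   │   │   ├── calib & velodyne & label_2 & image_2 & (optional) planes
│   │   ├── testing
│   │   │   ├── calib & velodyne & image_2
├── pcdet
├── tools
```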
- Generate the data infos by running the following command:
```shell
python -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_dataset.yaml
```
Train a model:
```shell
python train.py --cfg_file ${CONFIG_FILE}
```
You can optionally add the extra command-line arguments `--batch_size ${BATCH_SIZE}` and `--epochs ${EPOCHS}` to specify your preferred parameters.

Test a model:
```shell
python test.py --cfg_file ${CONFIG_FILE} --batch_size ${BATCH_SIZE} --eval_all
```
Visualization of detection results on the KITTI val split. Ground truth and predictions are labeled in red and green, respectively. Pink points mark the 512 key points sampled in the last SA layer.
Harder instances contain fewer LiDAR points and are less likely to be selected, so it is difficult for them to survive vanilla FPS down-sampling, and the features of remote (or small) instances cannot be fully transmitted to the next layer of the network. In contrast, BA-Net keeps adequate interior and boundary points of different foreground instances, preserving rich information for regression and classification.
Here we present experimental results evaluated on the KITTI validation set.
Snapshots of our 3D detection results (row 1: left is 3DSSD, right is BA-Net) on the KITTI validation set. The predicted bounding boxes are shown in green and are projected back onto the color images in pink (row 2) for visualization.
This project is built on OpenPCDet, a powerful toolbox for LiDAR-based 3D object detection. Please refer to OpenPCDet.md and the official GitHub repository for more information.