// cnn_segmentation.h
  1. #ifndef CNN_SEGMENTATION_H
  2. #define CNN_SEGMENTATION_H
  #include <chrono>
  #include <memory>
  #include <string>
  #include <vector>

  #include "caffe/caffe.hpp"
  #include <pcl/point_types.h>
  #include <pcl/point_cloud.h>
  #include <pcl/PointIndices.h>

  // ROS-related includes, currently disabled:
  //#include <ros/ros.h>
  // #include <pcl_conversions/pcl_conversions.h>
  //#include <pcl_ros/point_cloud.h>
  //#include <rockauto_msgs/DetectedObjectArray.h>
  //#include <visualization_msgs/MarkerArray.h>
  //#include <jsk_recognition_msgs/BoundingBoxArray.h>
  //#include <std_msgs/Header.h>

  #include "cluster2d.h"
  #include "feature_generator.h"
  // #include "pcl_types.h"
  // #include "modules/perception/obstacle/lidar/segmentation/cnnseg/cnn_segmentation.h"
  // Application tag for this module. It expands to a bare string literal, so it
  // can be concatenated with adjacent literals (e.g. "[" __APP_NAME__ "]").
  // NOTE(review): presumably used as a log-message prefix by the .cpp, which is
  // not visible here. Identifiers containing a double underscore are reserved
  // for the implementation; the name is kept only for compatibility with
  // existing users.
  #define __APP_NAME__ "lidar_cnn_seg_detect"
  /// A pair of integer counters that both start at zero.
  /// NOTE(review): the name suggests per-grid-cell point statistics; the users
  /// of this struct are not visible here, so confirm the exact meaning there.
  struct CellStat
  {
    CellStat() = default;

    int point_num{0};        ///< presumably: number of points counted in the cell
    int valid_point_num{0};  ///< presumably: how many of those points are valid
  };
  28. class CNNSegmentation
  29. {
  30. public:
  31. CNNSegmentation();
  32. void run();
  33. void test_run();
  34. private:
  35. double range_, score_threshold_;
  36. int width_;
  37. int height_;
  38. bool use_constant_feature_;
  39. std::string topic_src_;
  40. // std_msgs::Header message_header_;
  41. int gpu_device_id_;
  42. bool use_gpu_;
  43. std::shared_ptr<caffe::Net<float>> caffe_net_;
  44. // center offset prediction
  45. boost::shared_ptr<caffe::Blob<float>> instance_pt_blob_;
  46. // objectness prediction
  47. boost::shared_ptr<caffe::Blob<float>> category_pt_blob_;
  48. // fg probability prediction
  49. boost::shared_ptr<caffe::Blob<float>> confidence_pt_blob_;
  50. // object height prediction
  51. boost::shared_ptr<caffe::Blob<float>> height_pt_blob_;
  52. // raw features to be input into network
  53. boost::shared_ptr<caffe::Blob<float>> feature_blob_;
  54. // class prediction
  55. boost::shared_ptr<caffe::Blob<float>> class_pt_blob_;
  56. // clustering model for post-processing
  57. std::shared_ptr<Cluster2D> cluster2d_;
  58. // bird-view raw feature generator
  59. std::shared_ptr<FeatureGenerator> feature_generator_;
  60. public:
  61. bool init( std::string proto_file,std::string weight_file,double rangemax,double score,
  62. int width,int height,bool use_const,bool usegpu,int gpudevid);
  63. bool segment(const pcl::PointCloud<pcl::PointXYZI>::Ptr &pc_ptr,
  64. const pcl::PointIndices &valid_idx,
  65. std::vector<Obstacle> & objvec);
  66. // bool segment(const pcl::PointCloud<pcl::PointXYZI>::Ptr &pc_ptr,
  67. // const pcl::PointIndices &valid_idx,
  68. // rockauto_msgs::DetectedObjectArray &objects);
  69. // bool segment(const pcl::PointCloud<pcl::PointXYZI>::Ptr &pc_ptr,
  70. // const pcl::PointIndices &valid_idx);
  71. // void pointsCallback(const sensor_msgs::PointCloud2 &msg);
  72. // void pubColoredPoints(const rockauto_msgs::DetectedObjectArray &objects);
  73. };
  74. #endif //CNN_SEGMENTATION_H