pointpillars.h

/******************************************************************************
 * Copyright 2020 The Apollo Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *****************************************************************************/
/*
 * Copyright 2018-2019 Autoware Foundation. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/**
 * @author Kosuke Murakami
 * @date 2019/02/26
 */
/**
 * @author Yan haixu
 * Contact: just github.com/hova88
 * @date 2021/04/30
 */
#pragma once

// headers in STL
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <iostream>
#include <sstream>
#include <fstream>

// headers in TensorRT
#include "NvInfer.h"
#include "NvOnnxParser.h"

// headers in local files
// #include "params.h"
#include "common.h"
#include <yaml-cpp/yaml.h>
#include "preprocess.h"
#include "scatter.h"
#include "postprocess.h"

// Logger for TensorRT info/warning/errors
class Logger : public nvinfer1::ILogger {
 public:
  explicit Logger(Severity severity = Severity::kWARNING)
      : reportable_severity(severity) {}

  void log(Severity severity, const char* msg) override {
    // suppress messages with severity enum value greater than the reportable
    if (severity > reportable_severity) return;
    switch (severity) {
      case Severity::kINTERNAL_ERROR:
        std::cerr << "INTERNAL_ERROR: ";
        break;
      case Severity::kERROR:
        std::cerr << "ERROR: ";
        break;
      case Severity::kWARNING:
        std::cerr << "WARNING: ";
        break;
      case Severity::kINFO:
        std::cerr << "INFO: ";
        break;
      default:
        std::cerr << "UNKNOWN: ";
        break;
    }
    std::cerr << msg << std::endl;
  }

  Severity reportable_severity;
};
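
// Minimal usage sketch (illustrative only; ExampleCreateRuntime is a
// hypothetical helper, not used anywhere in this project). TensorRT's factory
// entry points take an ILogger reference, so an instance of the Logger above
// is what gets passed when creating the builder or runtime.
inline nvinfer1::IRuntime* ExampleCreateRuntime() {
  // Report everything at kWARNING severity or above.
  static Logger logger(nvinfer1::ILogger::Severity::kWARNING);
  return nvinfer1::createInferRuntime(logger);
}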

class PointPillars {
 private:
  // initialize in initializer list
  const float score_threshold_;
  const float nms_overlap_threshold_;
  const bool use_onnx_;
  const std::string pfe_file_;
  const std::string backbone_file_;
  const std::string pp_config_;
  // end initializer list

  // voxel size
  float kPillarXSize;
  float kPillarYSize;
  float kPillarZSize;
  // point cloud range
  float kMinXRange;
  float kMinYRange;
  float kMinZRange;
  float kMaxXRange;
  float kMaxYRange;
  float kMaxZRange;
  // hyper parameters
  int kNumClass;
  int kMaxNumPillars;
  int kMaxNumPointsPerPillar;
  int kNumPointFeature;
  int kNumGatherPointFeature = 11;
  int kGridXSize;
  int kGridYSize;
  int kGridZSize;
  int kNumAnchorXinds;
  int kNumAnchorYinds;
  int kRpnInputSize;
  int kNumAnchor;
  int kNumInputBoxFeature;
  int kNumOutputBoxFeature;
  int kRpnBoxOutputSize;
  int kRpnClsOutputSize;
  int kRpnDirOutputSize;
  int kBatchSize;
  int kNumIndsForScan;
  int kNumThreads;
  // if you change kNumThreads, you must also change NUM_THREADS_MACRO in
  // common.h
  int kNumBoxCorners;
  int kNmsPreMaxsize;
  int kNmsPostMaxsize;
  // parameters for initializing anchors, adapted to OpenPCDet
  int kAnchorStrides;
  std::vector<std::string> kAnchorNames;
  std::vector<float> kAnchorDxSizes;
  std::vector<float> kAnchorDySizes;
  std::vector<float> kAnchorDzSizes;
  std::vector<float> kAnchorBottom;
  std::vector<std::vector<int>> kMultiheadLabelMapping;
  int kNumAnchorPerCls;

  int host_pillar_count_[1];

  int* dev_x_coors_;
  int* dev_y_coors_;
  float* dev_num_points_per_pillar_;
  int* dev_sparse_pillar_map_;
  int* dev_cumsum_along_x_;
  int* dev_cumsum_along_y_;

  float* dev_pillar_point_feature_;
  float* dev_pillar_coors_;
  float* dev_points_mean_;
  float* dev_pfe_gather_feature_;

  void* pfe_buffers_[2];
  // variables for doPostprocessCudaMultiHead
  void* rpn_buffers_[8];
  std::vector<float*> rpn_box_output_;
  std::vector<float*> rpn_cls_output_;

  float* dev_scattered_feature_;

  float* dev_filtered_box_;
  float* dev_filtered_score_;
  int* dev_filtered_label_;
  int* dev_filtered_dir_;
  float* dev_box_for_nms_;
  int* dev_filter_count_;

  std::unique_ptr<PreprocessPointsCuda> preprocess_points_cuda_ptr_;
  std::unique_ptr<ScatterCuda> scatter_cuda_ptr_;
  std::unique_ptr<PostprocessCuda> postprocess_cuda_ptr_;

  Logger g_logger_;
  nvinfer1::ICudaEngine* pfe_engine_;
  nvinfer1::ICudaEngine* backbone_engine_;
  nvinfer1::IExecutionContext* pfe_context_;
  nvinfer1::IExecutionContext* backbone_context_;

  /**
   * @brief Allocate device memory
   * @details Called in the constructor
   */
  void DeviceMemoryMalloc();

  /**
   * @brief Set device memory to zero
   * @details Called in DoInference()
   */
  void SetDeviceMemoryToZero();
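
  /*
   * Both helpers above are thin wrappers around the CUDA runtime; an
   * illustrative sketch of the pattern, using one of the device buffers
   * declared above (the authoritative version lives in the .cc file):
   * @code
   * // in DeviceMemoryMalloc():
   * cudaMalloc(reinterpret_cast<void**>(&dev_x_coors_),
   *            kMaxNumPillars * sizeof(int));
   * // in SetDeviceMemoryToZero():
   * cudaMemset(dev_x_coors_, 0, kMaxNumPillars * sizeof(int));
   * @endcode
   */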

  /**
   * @brief Initialize parameters from pointpillars.yaml
   * @details Called in the constructor
   */
  void InitParams();
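
  /*
   * Sketch of the yaml-cpp pattern InitParams() relies on (illustrative; the
   * key names below are placeholders, the real ones are whatever
   * pointpillars.yaml defines):
   * @code
   * YAML::Node cfg = YAML::LoadFile(pp_config_);
   * kNumClass = cfg["CLASS_NUM"].as<int>();               // hypothetical key
   * kPillarXSize = cfg["VOXEL_SIZE"][0].as<float>();      // hypothetical key
   * kMinXRange = cfg["POINT_CLOUD_RANGE"][0].as<float>(); // hypothetical key
   * @endcode
   */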

  /**
   * @brief Initialize TensorRT instances
   * @param[in] use_onnx if true, build the engines by parsing the ONNX files;
   *            otherwise deserialize prebuilt engine files
   * @details Called in the constructor
   */
  void InitTRT(const bool use_onnx);

  /**
   * @brief Serialize a built engine and write it to disk
   * @param[in] engine the engine to serialize
   * @param[in] engine_filepath output file path
   */
  void SaveEngine(const nvinfer1::ICudaEngine* engine,
                  const std::string& engine_filepath);
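
  /*
   * Saving an engine is standard TensorRT boilerplate; an illustrative sketch
   * (TensorRT 7-era API):
   * @code
   * nvinfer1::IHostMemory* serialized = engine->serialize();
   * std::ofstream out(engine_filepath, std::ios::binary);
   * out.write(static_cast<const char*>(serialized->data()),
   *           serialized->size());
   * serialized->destroy();
   * @endcode
   */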

  /**
   * @brief Convert an ONNX model to a TensorRT engine
   * @param[in] model_file ONNX model file path
   * @param[out] engine_ptr TensorRT engine built from the ONNX model
   * @details Loads the ONNX model and builds a TensorRT engine from it
   */
  void OnnxToTRTModel(const std::string& model_file,
                      nvinfer1::ICudaEngine** engine_ptr);
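
  /*
   * The conversion presumably follows the standard TensorRT 7-era ONNX path;
   * an illustrative sketch, not necessarily this project's exact code:
   * @code
   * auto builder = nvinfer1::createInferBuilder(g_logger_);
   * const auto flags = 1U << static_cast<uint32_t>(
   *     nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
   * auto network = builder->createNetworkV2(flags);
   * auto parser = nvonnxparser::createParser(*network, g_logger_);
   * parser->parseFromFile(model_file.c_str(),
   *     static_cast<int>(nvinfer1::ILogger::Severity::kWARNING));
   * auto config = builder->createBuilderConfig();
   * config->setMaxWorkspaceSize(1 << 30);  // 1 GiB of build scratch space
   * *engine_ptr = builder->buildEngineWithConfig(*network, *config);
   * @endcode
   */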

  /**
   * @brief Load a serialized TensorRT engine from disk
   * @param[in] engine_file serialized engine file path
   * @param[out] engine_ptr deserialized TensorRT engine
   * @details Reads the engine file and deserializes it into a runtime engine
   */
  void EngineToTRTModel(const std::string& engine_file,
                        nvinfer1::ICudaEngine** engine_ptr);
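
  /*
   * Deserialization is the usual runtime sequence (illustrative sketch,
   * TensorRT 7-era three-argument deserializeCudaEngine):
   * @code
   * std::ifstream file(engine_file, std::ios::binary);
   * std::stringstream buffer;
   * buffer << file.rdbuf();
   * std::string blob = buffer.str();
   * nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(g_logger_);
   * *engine_ptr = runtime->deserializeCudaEngine(blob.data(), blob.size(),
   *                                              nullptr);
   * @endcode
   */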

  /**
   * @brief Preprocess points
   * @param[in] in_points_array point cloud array
   * @param[in] in_num_points number of points
   * @details Calls the CPU or GPU preprocessing path
   */
  void Preprocess(const float* in_points_array, const int in_num_points);

 public:
  /**
   * @brief Constructor
   * @param[in] score_threshold score threshold for filtering output boxes
   * @param[in] nms_overlap_threshold IoU threshold for NMS
   * @param[in] use_onnx if true, build from the ONNX files; otherwise load
   *            the serialized engine files
   * @param[in] pfe_file Pillar Feature Extractor ONNX file path
   * @param[in] rpn_file Region Proposal Network ONNX file path
   * @param[in] pp_config pointpillars.yaml configuration file path
   * @details Variables can be changed through point_pillars_detection
   */
  PointPillars(const float score_threshold,
               const float nms_overlap_threshold,
               const bool use_onnx,
               const std::string pfe_file,
               const std::string rpn_file,
               const std::string pp_config);
  ~PointPillars();

  /**
   * @brief Run PointPillars inference
   * @param[in] in_points_array point cloud array
   * @param[in] in_num_points number of points
   * @param[out] out_detections network output bounding boxes
   * @param[out] out_labels network output class label of each box
   * @param[out] out_scores network output confidence score of each box
   * @details This is the public interface to the algorithm
   */
  void DoInference(const float* in_points_array,
                   const int in_num_points,
                   std::vector<float>* out_detections,
                   std::vector<int>* out_labels,
                   std::vector<float>* out_scores);
};
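
// End-to-end usage sketch (illustrative only: the file names, thresholds, and
// the 4-float x/y/z/intensity point layout are assumptions, not defined by
// this header; ExamplePointPillarsInference is a hypothetical helper).
inline void ExamplePointPillarsInference(const std::vector<float>& points) {
  PointPillars pp(/*score_threshold=*/0.1f,
                  /*nms_overlap_threshold=*/0.2f,
                  /*use_onnx=*/true,
                  "pfe.onnx", "backbone.onnx", "pointpillars.yaml");
  std::vector<float> detections;  // kNumOutputBoxFeature floats per box
  std::vector<int> labels;        // one class label per box
  std::vector<float> scores;      // one confidence score per box
  pp.DoInference(points.data(), static_cast<int>(points.size() / 4),
                 &detections, &labels, &scores);
}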