# generate_dataset.py
# Builds a drivable-area segmentation dataset from JSON polygon annotations
# and splits it into train/val/test subsets.
import argparse
import glob
import json
import os
import os.path as ops
import random
import shutil

import cv2
import numpy as np
  8. def process_json_file(json_root,image_root,des_root):
  9. color_map = {"drivable": 1, "alternative": 2}
  10. if not os.path.exists(des_root):
  11. os.makedirs(des_root)
  12. images = des_root + '/images'
  13. labels = des_root + '/labels'
  14. if not os.path.exists(images):
  15. os.makedirs(images)
  16. if not os.path.exists(labels):
  17. os.makedirs(labels)
  18. i =0
  19. for js in glob.glob(json_root+'/**/*.json', recursive=True):
  20. path = js.split(json_root)[1][1:-5] # 4-5000\7\5183.png
  21. img_path = image_root+'/'+path
  22. rename = (path.replace('\\','-')).replace('/', '-') # 4-5000-7-5183.png
  23. with open(js, 'r') as f:
  24. dict = json.load(f)
  25. area_list = dict["annotations"]["area"]
  26. map = np.zeros([1080, 1920], dtype=np.uint8)
  27. for area in area_list:
  28. if area["points"]:
  29. a = [area["points"]]
  30. else:
  31. continue
  32. cv2.fillPoly(map, np.array(a).astype(int), color_map[area["category"]])
  33. label_name = rename.replace('.png', '_drivable.png') # 4-5000-7-5183_drivable.png
  34. label_dir = des_root + '/' + 'labels'
  35. if not os.path.exists(label_dir):
  36. os.makedirs(label_dir)
  37. image_dir = des_root + '/' + 'images'
  38. if not os.path.exists(image_dir):
  39. os.makedirs(image_dir)
  40. label_path = label_dir + '/' + label_name
  41. image_path = image_dir + '/' + rename
  42. #resize_map = cv2.resize(map, (1280, 720), interpolation=cv2.INTER_AREA)
  43. cv2.imwrite(label_path, map)
  44. img = cv2.imread(img_path)
  45. #resize_img = cv2.resize(img, (1280,720), interpolation=cv2.INTER_AREA)
  46. cv2.imwrite(image_path,img)
  47. i+=1
  48. if i%500==0:
  49. print('完成'+str(i)+'张')
  50. import random
  51. def generate_train_val_test(des_root):
  52. images_dir=des_root+'/images'
  53. labels_dir = des_root + '/labels'
  54. images_list = os.listdir(images_dir)
  55. #split source file into 0.9, 0.05 , 0.05 for training, validation and testing
  56. random.shuffle(images_list)
  57. size = len(images_list)
  58. train = images_list[:int(size*0.9)]
  59. val = images_list[int(size * 0.9):int(size * 0.95)]
  60. test = images_list[int(size * 0.95):]
  61. random.shuffle(train)
  62. random.shuffle(val)
  63. random.shuffle(test)
  64. train_img_dir = des_root + '/dataset/images/train'
  65. val_img_dir = des_root + '/dataset/images/val'
  66. test_img_dir = des_root + '/dataset/images/test'
  67. train_lbl_dir = des_root + '/dataset/labels/train'
  68. val_lbl_dir = des_root + '/dataset/labels/val'
  69. test_lbl_dir = des_root + '/dataset/labels/test'
  70. if not os.path.exists(train_img_dir):
  71. os.makedirs(train_img_dir)
  72. if not os.path.exists(val_img_dir):
  73. os.makedirs(val_img_dir)
  74. if not os.path.exists(test_img_dir):
  75. os.makedirs(test_img_dir)
  76. if not os.path.exists(train_lbl_dir):
  77. os.makedirs(train_lbl_dir)
  78. if not os.path.exists(val_lbl_dir):
  79. os.makedirs(val_lbl_dir)
  80. if not os.path.exists(test_lbl_dir):
  81. os.makedirs(test_lbl_dir)
  82. for img in train:
  83. src_path = images_dir+'/'+img
  84. des_path = train_img_dir+'/'+img
  85. shutil.copy(src_path, des_path)
  86. lbl = img.replace('.png','_drivable.png')
  87. src_path = labels_dir+'/'+lbl
  88. des_path = train_lbl_dir+'/'+lbl
  89. shutil.copy(src_path, des_path)
  90. print('train done')
  91. for img in val:
  92. src_path = images_dir+'/'+img
  93. des_path = val_img_dir+'/'+img
  94. shutil.copy(src_path, des_path)
  95. lbl = img.replace('.png', '_drivable.png')
  96. src_path = labels_dir+'/'+lbl
  97. des_path = val_lbl_dir+'/'+lbl
  98. shutil.copy(src_path, des_path)
  99. print('val done')
  100. for img in test:
  101. src_path = images_dir+'/'+img
  102. des_path = test_img_dir+'/'+img
  103. shutil.copy(src_path, des_path)
  104. lbl = img.replace('.png', '_drivable.png')
  105. src_path = labels_dir+'/'+lbl
  106. des_path = test_lbl_dir+'/'+lbl
  107. shutil.copy(src_path, des_path)
  108. print('test done')
  109. if __name__ == '__main__':
  110. json_root = '/media/adc/Elements/drivable/drivable_json'
  111. des_root = '/media/adc/Elements/drivable_dataset'
  112. image_root = '/media/adc/Elements/drivable/images'
  113. #合并并生成数据集图片
  114. process_json_file(json_root,image_root,des_root)
  115. #随机分配train\val\test
  116. generate_train_val_test(des_root)