# stack.py
  1. '''
  2. Author: Sungguk Cha
  3. eMail : navinad@naver.com
  4. It loads several models and stacks the prediction results.
  5. '''
  6. import argparse
  7. import os
  8. from tqdm import tqdm
  9. from mypath import Path
  10. from dataloaders import make_data_loader
  11. from modeling.sync_batchnorm.replicate import patch_replication_callback
  12. from modeling.deeplab import *
  13. from utils.lr_scheduler import LR_Scheduler
  14. from utils.saver import Saver
  15. from utils.summaries import TensorboardSummary
  16. from utils.metrics import Evaluator
  17. from utils.visualize import Visualize as vs
  18. from torchsummary import summary
  19. import numpy as np
  20. import torch.nn as nn
  21. blows = 0
  22. def gn(planes):
  23. return nn.GroupNorm(16, planes)
  24. def blow(image, _class):
  25. '''
  26. post process subfunction
  27. blows '_class' class from an image
  28. '''
  29. global blows
  30. blows += 1
  31. image[image == _class] = 0
  32. return image
  33. def post1(inputs):
  34. '''
  35. Post processing 1.
  36. It blows up classes less than {(0, 1): 1337, (0, 1, 2): [1597, 1304], (0, 2): 2836}
  37. '''
  38. results = []
  39. blowed = False
  40. for result in inputs:
  41. unique, counts = np.unique(result, return_counts=True)
  42. dic = dict(zip(unique, counts))
  43. unique = tuple(unique)
  44. if unique == (0, 1):
  45. if dic[1] < 1337:
  46. result = blow(result, 1)
  47. blowed = True
  48. elif unique == (0, 2):
  49. if dic[2] < 2836:
  50. result = blow(result, 2)
  51. blowed = True
  52. elif unique == (0, 1, 2):
  53. if dic[1] < 1597:
  54. result = blow(result, 1)
  55. blowed = True
  56. if dic[2] < 1304:
  57. result = blow(result, 2)
  58. blowed = True
  59. results.append(result)
  60. return results, blowed
  61. class Stack(object):
  62. def __init__(self, args):
  63. self.args = args
  64. self.vs = vs(args.nice)
  65. #Dataloader
  66. kwargs = {"num_workers": args.workers, 'pin_memory': True}
  67. if self.args.dataset == 'bdd':
  68. _, _, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
  69. else: #self.args.dataset == 'nice':
  70. self.test_loader, self.nclass = make_data_loader(args, **kwargs)
  71. #else:
  72. # raise NotImplementedError
  73. ### Load models
  74. #backs = ["resnet", "resnet152"]
  75. backs = ["resnet", "ibn", "resnet152"]
  76. check = './ckpt'
  77. checks = ["herbrand.pth.tar", "ign85.12.pth.tar", "r152_85.20.pth.tar"]
  78. self.models = []
  79. self.M = len(backs)
  80. # define models
  81. for i in range(self.M):
  82. model = DeepLab(num_classes = self.nclass,
  83. backbone=backs[i],
  84. output_stride=16,
  85. Norm=gn,
  86. freeze_bn=False)
  87. self.models.append(model)
  88. self.models[i] = torch.nn.DataParallel(self.models[i], device_ids=self.args.gpu_ids)
  89. patch_replication_callback(self.models[i])
  90. self.models[i] = self.models[i].cuda()
  91. # load checkpoints
  92. for i in range(self.M):
  93. resume = os.path.join(check, checks[i])
  94. if not os.path.isfile( resume ):
  95. raise RuntimeError("=> no checkpoint found at '{}'".format(resume))
  96. checkpoint = torch.load( resume )
  97. dicts = checkpoint['state_dict']
  98. model_dict = {}
  99. state_dict = self.models[i].module.state_dict()
  100. for k, v in dicts.items():
  101. if k in state_dict:
  102. model_dict[k] = v
  103. state_dict.update(model_dict)
  104. self.models[i].module.load_state_dict(state_dict)
  105. print( "{} loaded successfully".format(checks[i]) )
  106. def predict(self, mode):
  107. for i in range(self.M):
  108. self.models[i].eval()
  109. tbar = tqdm(self.test_loader, desc='\r')
  110. for i, sample in enumerate(tbar):
  111. images = sample['image']
  112. names = sample['name']
  113. images.cuda()
  114. outputs = []
  115. with torch.no_grad():
  116. for i in range(self.M):
  117. output = self.models[i](images)
  118. output = output.data.cpu().numpy()
  119. outputs.append( output )
  120. if mode == "stack":
  121. results = outputs[0]
  122. for output in outputs[1:]:
  123. results += output
  124. results = np.argmax( results, axis=1 )
  125. if self.args.post:
  126. posts, blowed = post1(np.array(results))
  127. if blowed:
  128. images = images.cpu().numpy()
  129. self.vs.predict_color( results, images, names, self.args.savedir )
  130. _names = []
  131. for name in names:
  132. _name = name.split('.')[0] + "-blow.png"
  133. _names.append(_name)
  134. self.vs.predict_color( posts, images, _names, self.args.savedir )
  135. continue # saving blows
  136. if self.args.color:
  137. images = images.cpu().numpy()
  138. self.vs.predict_color( results, images, names, self.args.savedir )
  139. else:
  140. self.vs.predict_id( results, names, self.args.savedir )
  141. if self.args.post:
  142. global blows
  143. print(blows, "blows happened")
  144. def get_args():
  145. parser = argparse.ArgumentParser()
  146. # Dataloader
  147. parser.add_argument('--dataset', default='bdd')
  148. parser.add_argument('--workers', type=int, default=0, metavar='N', help='dataloader threads')
  149. parser.add_argument("--img_list", default=None)
  150. parser.add_argument("--batch_size")
  151. # Model load
  152. parser.add_argument('--gpu_ids', type=str, default='0')
  153. # Prediction save
  154. parser.add_argument('--savedir', type=str, default='./prd')
  155. parser.add_argument('--color', default=False, action='store_true')
  156. parser.add_argument('--nice', default=False, action='store_true', help="Use nice RGB mean & std")
  157. parser.add_argument('--post', default=False, action='store_true', help="Activate post process")
  158. return parser.parse_args()
  159. if __name__ == "__main__":
  160. args = get_args()
  161. args.batch_size = int(args.batch_size)
  162. args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
  163. stack = Stack(args)
  164. stack.predict(mode="stack")