这里我的示例代码结构如下所示,个人习惯为了方便调试和后续接口使用,和官方仓库不一样。
1 ├── configs(配置文件路径)
2 ├── Market1501
3 ├── bagtricks_R50.yml
4 ├── Base-bagtricks.yml
5 ├── datasets(数据集目录)
6 ├── Market-1501-v15.09.15 (这个数据集名不要改)
7 ├── bounding_box_test (750人的19732张图像用于测试)
8 ├── bounding_box_train (751人的12936张图像用于训练)
9 ├── query (750人的3368张图像用于查询)
10 ├── fastreid
11 ├── model(预训练模型目录),下载好的预训练模型存放在这
12 ├── demo.py(提取图像的特征,并保存),来自原来的demo目录
13 ├── predictor.py (模型加载文件),来自原来的demo目录
14 ├── train_net.py (模型训练与测试封装版代码),来自原来的tools目录
15 ├── visualize_result.py (可视化特征提取结果),来自原来的demo目录
重点关注几个py文件,我直接挪到根目录下了。还有模型文件的保存路径,config预训练模型地址,数据集的名字也要注意的。各个文件具体使用可以看看下面介绍,都有代码注释。
特别注意,py文件为了方便调试,我直接在代码里面设置了*args的参数,实际使用要特别注意。
这个代码就是加载模型(调用predictor.py),提取查询图像的特征,并保存为npy文件。保存在demo_output文件夹下,一张图像对一个npy文件。这些包含特征向量的npy文件可供后续向量检索使用。
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5提取图像的特征,并保存
6"""
7
8import argparse
9import glob
10import os
11import sys
12
13import torch.nn.functional as F
14import cv2
15import numpy as np
16import tqdm
17from torch.backends import cudnn
18
19sys.path.append('.')
20
21from fastreid.config import get_cfg
22from fastreid.utils.logger import setup_logger
23from fastreid.utils.file_io import PathManager
24
25from predictor import FeatureExtractionDemo
26
27# import some modules added in project like this below
28# sys.path.append("projects/PartialReID")
29# from partialreid import *
30
31cudnn.benchmark = True
32setup_logger(name="fastreid")
33
34
35# 读取配置文件
36def setup_cfg(args):
37 # load config from file and command-line arguments
38 cfg = get_cfg()
39 # add_partialreid_config(cfg)
40 cfg.merge_from_file(args.config_file)
41 cfg.merge_from_list(args.opts)
42 cfg.freeze()
43 return cfg
44
45
46def get_parser():
47 parser = argparse.ArgumentParser(description="Feature extraction with reid models")
48 parser.add_argument(
49 "--config-file", # config路径,通常包含模型配置文件
50 metavar="FILE",
51 help="path to config file",
52 )
53 parser.add_argument(
54 "--parallel", # 是否并行
55 action='store_true',
56 help='If use multiprocess for feature extraction.'
57 )
58 parser.add_argument(
59 "--input", # 输入图像路径
60 nargs="+",
61 help="A list of space separated input images; "
62 "or a single glob pattern such as 'directory/*.webp'",
63 )
64 parser.add_argument(
65 "--output", # 输出结果路径
66 default='demo_output',
67 help='path to save features'
68 )
69 parser.add_argument(
70 "--opts",
71 help="Modify config options using the command-line 'KEY VALUE' pairs",
72 default=[],
73 nargs=argparse.REMAINDER,
74 )
75 return parser
76
77
78def postprocess(features):
79 # Normalize feature to compute cosine distance
80 features = F.normalize(features) # 特征归一化
81 features = features.cpu().data.numpy()
82 return features
83
84
85if __name__ == '__main__':
86 args = get_parser().parse_args() # 解析输入参数
87 # 调试使用,使用的时候删除下面代码
88 # ---
89 args.config_file = "./configs/Market1501/bagtricks_R50.yml" # config路径
90 args.input = "./datasets/Market-1501-v15.09.15/query/*.webp" # 图像路径
91 # ---
92
93 cfg = setup_cfg(args) # 读取cfg文件
94 demo = FeatureExtractionDemo(cfg, parallel=args.parallel) # 加载特征提取器,也就是加载模型
95
96 PathManager.mkdirs(args.output) # 创建输出路径
97 if args.input:
98 if PathManager.isdir(args.input[0]): # 判断输入的是否为路径
99 # args.input = glob.glob(os.path.expanduser(args.input[0])) # 原来的代码有问题
100 args.input = glob.glob(os.path.expanduser(args.input)) # 获取输入路径下所有的文件路径
101 assert args.input, "The input path(s) was not found"
102 for path in tqdm.tqdm(args.input): # 逐张处理
103 img = cv2.imread(path)
104 feat = demo.run_on_image(img) # 提取图像特征
105 feat = postprocess(feat) # 后处理主要是特征归一化
106 np.save(os.path.join(args.output, os.path.basename(path).split('.')[0] + '.npy'), feat) # 保存图像对应的特征,以便下次使用
这个代码就是加载模型(调用predictor.py),提取查询图像的特征,计算模型的各个精度指标。输出模型的ROC结果图,以及某张图像的匹配结果图像。输出目录为vis_rank_list。
ROC结果图如下图所示,ROC曲线下的面积AUC越大,表示模型效果越好。top1精度93.37左右。
某张图像的匹配结果图像如下所示。每张图有1张查询图和5张查询结果图,左1为查询图像,其他为查询结果图。蓝色框表示查询结果错误,红色框表示查询结果正确。在查询结果图上有标题,比如0.976/false/cam1,表示当前查询结果图像和查询图像特征距离为0.976,查询结果为false(查询错误),该查询结果来自cam1摄像头。查询图像上的标题,如0.9967/cam2,这里0.9967表示查询图像的查询结果精度指标,cam2表示查询图像来自cam2摄像头。
1# encoding: utf-8
2"""
3@author: xingyu liao
4@contact: sherlockliao01@gmail.com
5可视化特征提取结果
6"""
7
8import argparse
9import logging
10import sys
11
12import numpy as np
13import torch
14import tqdm
15from torch.backends import cudnn
16
17sys.path.append('.')
18
19import torch.nn.functional as F
20from fastreid.evaluation.rank import evaluate_rank
21from fastreid.config import get_cfg
22from fastreid.utils.logger import setup_logger
23from fastreid.data import build_reid_test_loader
24from predictor import FeatureExtractionDemo
25from fastreid.utils.visualizer import Visualizer
26
27# import some modules added in project
28# for example, add partial reid like this below
29# sys.path.append("projects/PartialReID")
30# from partialreid import *
31
32cudnn.benchmark = True
33setup_logger(name="fastreid")
34
35logger = logging.getLogger('fastreid.visualize_result')
36
37
38# 读取配置文件
39def setup_cfg(args):
40 # load config from file and command-line arguments
41 cfg = get_cfg()
42 # add_partialreid_config(cfg)
43 cfg.merge_from_file(args.config_file)
44 cfg.merge_from_list(args.opts)
45 cfg.freeze()
46 return cfg
47
48
49def get_parser():
50 parser = argparse.ArgumentParser(description="Feature extraction with reid models")
51 parser.add_argument(
52 "--config-file", # config路径,通常包含模型配置文件
53 metavar="FILE",
54 help="path to config file",
55 )
56 parser.add_argument(
57 '--parallel', # 是否并行
58 action='store_true',
59 help='if use multiprocess for feature extraction.'
60 )
61 parser.add_argument(
62 "--dataset-name", # 数据集名字
63 help="a test dataset name for visualizing ranking list."
64 )
65 parser.add_argument(
66 "--output", # 输出结果路径
67 default="./vis_rank_list",
68 help="a file or directory to save rankling list result.",
69
70 )
71 parser.add_argument(
72 "--vis-label", # 输出结果是否查看
73 action='store_true',
74 help="if visualize label of query instance"
75 )
76 parser.add_argument(
77 "--num-vis", # 挑选多少张图像用于结果展示
78 default=1000,
79 help="number of query images to be visualized",
80 )
81 parser.add_argument(
82 "--rank-sort", # 结果展示是相似度排序方式,默认从小到大排序
83 default="ascending",
84 help="rank order of visualization images by AP metric",
85 )
86 parser.add_argument(
87 "--label-sort", # label结果展示是相似度排序方式,默认从小到大排序
88 default="ascending",
89 help="label order of visualization images by cosine similarity metric",
90 )
91 parser.add_argument(
92 "--max-rank", # 显示topk的结果,默认显示前10个结果
93 default=5,
94 help="maximum number of rank list to be visualized",
95 )
96 parser.add_argument(
97 "--opts",
98 help="Modify config options using the command-line 'KEY VALUE' pairs",
99 default=[],
100 nargs=argparse.REMAINDER,
101 )
102 return parser
103
104
105if __name__ == '__main__':
106 args = get_parser().parse_args()
107 # 调试使用,使用的时候删除下面代码
108 # ---
109 args.config_file = "./configs/Market1501/bagtricks_R50.yml" # config路径
110 args.dataset_name = 'Market1501' # 数据集名字
111 args.vis_label = False # 是否显示正确label结果
112 args.rank_sort = 'descending' # 从大到小显示关联结果
113 args.label_sort = 'descending' # 从大到小显示关联结果
114 # ---
115
116 cfg = setup_cfg(args)
117 # 可以直接在代码中设置cfg中设置模型路径
118 # cfg["MODEL"]["WEIGHTS"] = './configs/Market1501/bagtricks_R50.yml'
119 test_loader, num_query = build_reid_test_loader(cfg, dataset_name=args.dataset_name) # 创建测试数据集
120 demo = FeatureExtractionDemo(cfg, parallel=args.parallel) # 加载特征提取器,也就是加载模型
121
122 logger.info("Start extracting image features")
123 feats = [] # 图像特征,用于保存每个行人的图像特征
124 pids = [] # 行人id,用于保存每个行人的id
125 camids = [] # 拍摄的摄像头,行人出现的摄像头id
126 # 逐张保存读入行人图像,并保存相关信息
127 for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
128 feats.append(feat)
129 pids.extend(pid)
130 camids.extend(camid)
131
132 feats = torch.cat(feats, dim=0) # 将feats转换为tensor的二维向量,向量维度为[图像数,特征维度]
133 # 这里把query和gallery数据放在一起了,需要切分query和gallery的数据
134 q_feat = feats[:num_query]
135 g_feat = feats[num_query:]
136 q_pids = np.asarray(pids[:num_query])
137 g_pids = np.asarray(pids[num_query:])
138 q_camids = np.asarray(camids[:num_query])
139 g_camids = np.asarray(camids[num_query:])
140
141 # compute cosine distance 计算余弦距离
142 q_feat = F.normalize(q_feat, p=2, dim=1)
143 g_feat = F.normalize(g_feat, p=2, dim=1)
144 distmat = 1 - torch.mm(q_feat, g_feat.t()) # 这里distmat表示两张图像的距离,越小越接近
145 distmat = distmat.numpy()
146
147 # 计算各种评价指标 cmc[0]就是top1精度,应该是93%左右,这里精度会有波动
148 logger.info("Computing APs for all query images ...")
149 cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
150 logger.info("Finish computing APs for all query images!")
151
152 visualizer = Visualizer(test_loader.dataset) # 创建Visualizer类
153 visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids) # 保存结果
154
155 logger.info("Start saving ROC curve ...") # 保存ROC曲线
156 fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
157 visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)
158 logger.info("Finish saving ROC curve!")
159
160 logger.info("Saving rank list result ...") # 保存部分查询图像的关联结果,按照顺序排列
161 query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
162 args.rank_sort, args.label_sort, args.max_rank)
163 logger.info("Finish saving rank list results!")
这段代码调用config文件,训练或者测试模型。训练模型设置args.eval_only = False,反之为测试模型。测试模型结果如下图所示。代码封装的很不错,把该有的测试指标都贴上去了。
另外这是封装过多的代码,如果想知道清晰的训练代码查看fast-reid/tools/plain_train_net.py,这个文件提供了详细没有封装过多的训练代码。
1#!/usr/bin/env python
2# encoding: utf-8
3"""
4@author: sherlock
5@contact: sherlockliao01@gmail.com
6模型训练与测试封装版代码
7"""
8
9import sys
10
11sys.path.append('.')
12
13from fastreid.config import get_cfg
14from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
15from fastreid.utils.checkpoint import Checkpointer
16
17
18# 读取配置文件
19def setup(args):
20 """
21 Create configs and perform basic setups.
22 """
23 cfg = get_cfg()
24 cfg.merge_from_file(args.config_file)
25 cfg.merge_from_list(args.opts)
26 cfg.freeze()
27 default_setup(cfg, args)
28 return cfg
29
30
31def main(args):
32 cfg = setup(args)
33 # 模型测试
34 if args.eval_only:
35 cfg.defrost()
36 cfg.MODEL.BACKBONE.PRETRAIN = False
37 model = DefaultTrainer.build_model(cfg)
38 # 加载预训练模型
39 Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
40
41 res = DefaultTrainer.test(cfg, model)
42 return res
43 # 模型训练
44 trainer = DefaultTrainer(cfg)
45
46 trainer.resume_or_load(resume=args.resume)
47 return trainer.train()
48
49
50if __name__ == "__main__":
51 args = default_argument_parser().parse_args()
52 # 调试使用,使用的时候删除下面代码
53 # ---
54 args.config_file = "./configs/Market1501/bagtricks_R50.yml" # config路径
55 args.eval_only = True # 是否测试模型,False表示训练模型,True表示测试模型
56 # ---
57
58 print("Command Line Args:", args)
59 launch(
60 main,
61 args.num_gpus,
62 num_machines=args.num_machines,
63 machine_rank=args.machine_rank,
64 dist_url=args.dist_url,
65 args=(args,),
66 )