Fastreid使用

这里我的示例代码结构如下所示,个人习惯为了方便调试和后续接口使用,和官方仓库不一样。

 1  ├── configs配置文件路径
 2    ├── Market1501
 3      ├── bagtricks_R50.yml
 4    ├── Base-bagtricks.yml
 5  ├── datasets数据集目录
 6      ├── Market-1501-v15.09.15 这个数据集名不要改
 7        ├── bounding_box_test (750人的19732张图像用于测试)
 8        ├── bounding_box_train (751人的12936张图像用于训练)
 9        ├── query (750人的3368张图像用于查询)
10  ├── fastreid
11  ├── model预训练模型目录),下载好的预训练模型存放在这
12  ├── demo.py提取图像的特征并保存),来自原来的demo目录
13  ├── predictor.py 模型加载文件),来自原来的demo目录
14  ├── train_net.py 模型训练与测试封装版代码),来自原来的tools目录
15  ├── visualize_result.py 可视化特征提取结果),来自原来的demo目录

重点关注几个py文件,我直接挪到根目录下了。还有模型文件的保存路径,config预训练模型地址,数据集的名字也要注意的。各个文件具体使用可以看看下面介绍,都有代码注释。

特别注意,py文件为了方便调试,我直接在代码里面设置了*args的参数,实际使用要特别注意。

demo.py

这个代码就是加载模型(调用predictor.py),提取查询图像的特征,并保存为npy文件。保存在demo_output文件夹下,一张图像对一个npy文件。这些包含特征向量的npy文件可供后续向量检索使用。

1t6ha

  1# encoding: utf-8
  2"""
  3@author:  liaoxingyu
  4@contact: sherlockliao01@gmail.com
  5提取图像的特征,并保存
  6"""
  7
  8import argparse
  9import glob
 10import os
 11import sys
 12
 13import torch.nn.functional as F
 14import cv2
 15import numpy as np
 16import tqdm
 17from torch.backends import cudnn
 18
 19sys.path.append('.')
 20
 21from fastreid.config import get_cfg
 22from fastreid.utils.logger import setup_logger
 23from fastreid.utils.file_io import PathManager
 24
 25from predictor import FeatureExtractionDemo
 26
 27# import some modules added in project like this below
 28# sys.path.append("projects/PartialReID")
 29# from partialreid import *
 30
 31cudnn.benchmark = True
 32setup_logger(name="fastreid")
 33
 34
 35# 读取配置文件
 36def setup_cfg(args):
 37    # load config from file and command-line arguments
 38    cfg = get_cfg()
 39    # add_partialreid_config(cfg)
 40    cfg.merge_from_file(args.config_file)
 41    cfg.merge_from_list(args.opts)
 42    cfg.freeze()
 43    return cfg
 44
 45
 46def get_parser():
 47    parser = argparse.ArgumentParser(description="Feature extraction with reid models")
 48    parser.add_argument(
 49        "--config-file",  # config路径,通常包含模型配置文件
 50        metavar="FILE",
 51        help="path to config file",
 52    )
 53    parser.add_argument(
 54        "--parallel",  # 是否并行
 55        action='store_true',
 56        help='If use multiprocess for feature extraction.'
 57    )
 58    parser.add_argument(
 59        "--input",  # 输入图像路径
 60        nargs="+",
 61        help="A list of space separated input images; "
 62             "or a single glob pattern such as 'directory/*.webp'",
 63    )
 64    parser.add_argument(
 65        "--output",  # 输出结果路径
 66        default='demo_output',
 67        help='path to save features'
 68    )
 69    parser.add_argument(
 70        "--opts",
 71        help="Modify config options using the command-line 'KEY VALUE' pairs",
 72        default=[],
 73        nargs=argparse.REMAINDER,
 74    )
 75    return parser
 76
 77
 78def postprocess(features):
 79    # Normalize feature to compute cosine distance
 80    features = F.normalize(features)  # 特征归一化
 81    features = features.cpu().data.numpy()
 82    return features
 83
 84
 85if __name__ == '__main__':
 86    args = get_parser().parse_args()  # 解析输入参数
 87    # 调试使用,使用的时候删除下面代码
 88    # ---
 89    args.config_file = "./configs/Market1501/bagtricks_R50.yml"  # config路径
 90    args.input = "./datasets/Market-1501-v15.09.15/query/*.webp"  # 图像路径
 91    # ---
 92
 93    cfg = setup_cfg(args)  # 读取cfg文件
 94    demo = FeatureExtractionDemo(cfg, parallel=args.parallel)  # 加载特征提取器,也就是加载模型
 95
 96    PathManager.mkdirs(args.output)  # 创建输出路径
 97    if args.input:
 98        if PathManager.isdir(args.input[0]):  # 判断输入的是否为路径
 99            # args.input = glob.glob(os.path.expanduser(args.input[0])) # 原来的代码有问题
100            args.input = glob.glob(os.path.expanduser(args.input))  # 获取输入路径下所有的文件路径
101            assert args.input, "The input path(s) was not found"
102        for path in tqdm.tqdm(args.input):  # 逐张处理
103            img = cv2.imread(path)
104            feat = demo.run_on_image(img)  # 提取图像特征
105            feat = postprocess(feat)  # 后处理主要是特征归一化
106            np.save(os.path.join(args.output, os.path.basename(path).split('.')[0] + '.npy'), feat)  # 保存图像对应的特征,以便下次使用

visualize_result.py

这个代码就是加载模型(调用predictor.py),提取查询图像的特征,计算模型的各个精度指标。输出模型的ROC结果图,以及某张图像的匹配结果图像。输出目录为vis_rank_list。

ROC结果图如下图所示,ROC曲线下的面积AUC越大,表示模型效果越好。top1精度93.37左右。

gbkkm

某张图像的匹配结果图像如下所示。每张图有1张查询图和5张查询结果图,左1为查询图像,其他为查询结果图。蓝色框表示查询结果错误,红色框表示查询结果正确。在查询结果图上有标题,比如0.976/false/cam1,表示当前查询结果图像和查询图像特征距离为0.976,查询结果为false(查询错误),该查询结果来自cam1摄像头。查询图像上的标题,如0.9967/cam2,这里0.9967表示查询图像的查询结果精度指标,cam2表示查询图像来自cam2摄像头。

h2666

tg288

  1# encoding: utf-8
  2"""
  3@author:  xingyu liao
  4@contact: sherlockliao01@gmail.com
  5可视化特征提取结果
  6"""
  7
  8import argparse
  9import logging
 10import sys
 11
 12import numpy as np
 13import torch
 14import tqdm
 15from torch.backends import cudnn
 16
 17sys.path.append('.')
 18
 19import torch.nn.functional as F
 20from fastreid.evaluation.rank import evaluate_rank
 21from fastreid.config import get_cfg
 22from fastreid.utils.logger import setup_logger
 23from fastreid.data import build_reid_test_loader
 24from predictor import FeatureExtractionDemo
 25from fastreid.utils.visualizer import Visualizer
 26
 27# import some modules added in project
 28# for example, add partial reid like this below
 29# sys.path.append("projects/PartialReID")
 30# from partialreid import *
 31
 32cudnn.benchmark = True
 33setup_logger(name="fastreid")
 34
 35logger = logging.getLogger('fastreid.visualize_result')
 36
 37
 38# 读取配置文件
 39def setup_cfg(args):
 40    # load config from file and command-line arguments
 41    cfg = get_cfg()
 42    # add_partialreid_config(cfg)
 43    cfg.merge_from_file(args.config_file)
 44    cfg.merge_from_list(args.opts)
 45    cfg.freeze()
 46    return cfg
 47
 48
 49def get_parser():
 50    parser = argparse.ArgumentParser(description="Feature extraction with reid models")
 51    parser.add_argument(
 52        "--config-file",  # config路径,通常包含模型配置文件
 53        metavar="FILE",
 54        help="path to config file",
 55    )
 56    parser.add_argument(
 57        '--parallel',  # 是否并行
 58        action='store_true',
 59        help='if use multiprocess for feature extraction.'
 60    )
 61    parser.add_argument(
 62        "--dataset-name",  # 数据集名字
 63        help="a test dataset name for visualizing ranking list."
 64    )
 65    parser.add_argument(
 66        "--output",  # 输出结果路径
 67        default="./vis_rank_list",
 68        help="a file or directory to save rankling list result.",
 69
 70    )
 71    parser.add_argument(
 72        "--vis-label",  # 输出结果是否查看
 73        action='store_true',
 74        help="if visualize label of query instance"
 75    )
 76    parser.add_argument(
 77        "--num-vis",  # 挑选多少张图像用于结果展示
 78        default=1000,
 79        help="number of query images to be visualized",
 80    )
 81    parser.add_argument(
 82        "--rank-sort",  # 结果展示是相似度排序方式,默认从小到大排序
 83        default="ascending",
 84        help="rank order of visualization images by AP metric",
 85    )
 86    parser.add_argument(
 87        "--label-sort",  # label结果展示是相似度排序方式,默认从小到大排序
 88        default="ascending",
 89        help="label order of visualization images by cosine similarity metric",
 90    )
 91    parser.add_argument(
 92        "--max-rank",  # 显示topk的结果,默认显示前10个结果
 93        default=5,
 94        help="maximum number of rank list to be visualized",
 95    )
 96    parser.add_argument(
 97        "--opts",
 98        help="Modify config options using the command-line 'KEY VALUE' pairs",
 99        default=[],
100        nargs=argparse.REMAINDER,
101    )
102    return parser
103
104
105if __name__ == '__main__':
106    args = get_parser().parse_args()
107    # 调试使用,使用的时候删除下面代码
108    # ---
109    args.config_file = "./configs/Market1501/bagtricks_R50.yml"  # config路径
110    args.dataset_name = 'Market1501'  # 数据集名字
111    args.vis_label = False  # 是否显示正确label结果
112    args.rank_sort = 'descending'  # 从大到小显示关联结果
113    args.label_sort = 'descending'  # 从大到小显示关联结果
114    # ---
115
116    cfg = setup_cfg(args)
117    # 可以直接在代码中设置cfg中设置模型路径
118    # cfg["MODEL"]["WEIGHTS"] = './configs/Market1501/bagtricks_R50.yml'
119    test_loader, num_query = build_reid_test_loader(cfg, dataset_name=args.dataset_name)  # 创建测试数据集
120    demo = FeatureExtractionDemo(cfg, parallel=args.parallel)  # 加载特征提取器,也就是加载模型
121
122    logger.info("Start extracting image features")
123    feats = []  # 图像特征,用于保存每个行人的图像特征
124    pids = []  # 行人id,用于保存每个行人的id
125    camids = []  # 拍摄的摄像头,行人出现的摄像头id
126    # 逐张保存读入行人图像,并保存相关信息
127    for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
128        feats.append(feat)
129        pids.extend(pid)
130        camids.extend(camid)
131
132    feats = torch.cat(feats, dim=0)  # 将feats转换为tensor的二维向量,向量维度为[图像数,特征维度]
133    # 这里把query和gallery数据放在一起了,需要切分query和gallery的数据
134    q_feat = feats[:num_query]
135    g_feat = feats[num_query:]
136    q_pids = np.asarray(pids[:num_query])
137    g_pids = np.asarray(pids[num_query:])
138    q_camids = np.asarray(camids[:num_query])
139    g_camids = np.asarray(camids[num_query:])
140
141    # compute cosine distance 计算余弦距离
142    q_feat = F.normalize(q_feat, p=2, dim=1)
143    g_feat = F.normalize(g_feat, p=2, dim=1)
144    distmat = 1 - torch.mm(q_feat, g_feat.t())  # 这里distmat表示两张图像的距离,越小越接近
145    distmat = distmat.numpy()
146
147    # 计算各种评价指标 cmc[0]就是top1精度,应该是93%左右,这里精度会有波动
148    logger.info("Computing APs for all query images ...")
149    cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
150    logger.info("Finish computing APs for all query images!")
151
152    visualizer = Visualizer(test_loader.dataset)  # 创建Visualizer类
153    visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)  # 保存结果
154
155    logger.info("Start saving ROC curve ...")  # 保存ROC曲线
156    fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
157    visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)
158    logger.info("Finish saving ROC curve!")
159
160    logger.info("Saving rank list result ...")  # 保存部分查询图像的关联结果,按照顺序排列
161    query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
162                                             args.rank_sort, args.label_sort, args.max_rank)
163    logger.info("Finish saving rank list results!")

train_net.py

这段代码调用config文件,训练或者测试模型。训练模型设置args.eval_only = False,反之为测试模型。测试模型结果如下图所示。代码封装的很不错,把该有的测试指标都贴上去了。

gspfq

另外这是封装过多的代码,如果想知道清晰的训练代码查看fast-reid/tools/plain_train_net.py,这个文件提供了详细没有封装过多的训练代码。

 1#!/usr/bin/env python
 2# encoding: utf-8
 3"""
 4@author:  sherlock
 5@contact: sherlockliao01@gmail.com
 6模型训练与测试封装版代码
 7"""
 8
 9import sys
10
11sys.path.append('.')
12
13from fastreid.config import get_cfg
14from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
15from fastreid.utils.checkpoint import Checkpointer
16
17
18# 读取配置文件
19def setup(args):
20    """
21    Create configs and perform basic setups.
22    """
23    cfg = get_cfg()
24    cfg.merge_from_file(args.config_file)
25    cfg.merge_from_list(args.opts)
26    cfg.freeze()
27    default_setup(cfg, args)
28    return cfg
29
30
31def main(args):
32    cfg = setup(args)
33    # 模型测试
34    if args.eval_only:
35        cfg.defrost()
36        cfg.MODEL.BACKBONE.PRETRAIN = False
37        model = DefaultTrainer.build_model(cfg)
38        # 加载预训练模型
39        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model
40
41        res = DefaultTrainer.test(cfg, model)
42        return res
43    # 模型训练
44    trainer = DefaultTrainer(cfg)
45
46    trainer.resume_or_load(resume=args.resume)
47    return trainer.train()
48
49
50if __name__ == "__main__":
51    args = default_argument_parser().parse_args()
52    # 调试使用,使用的时候删除下面代码
53    # ---
54    args.config_file = "./configs/Market1501/bagtricks_R50.yml"  # config路径
55    args.eval_only = True  # 是否测试模型,False表示训练模型,True表示测试模型
56    # ---
57
58    print("Command Line Args:", args)
59    launch(
60        main,
61        args.num_gpus,
62        num_machines=args.num_machines,
63        machine_rank=args.machine_rank,
64        dist_url=args.dist_url,
65        args=(args,),
66    )