在fast-reid/datasets/目录提供了不同数据集的信息。可以自行下载。这里介绍其中最常用的Market-1501数据集。
Market-1501是用于行人重识别的大规模公共基准数据集。它包含由6个不同的摄像机捕获的1501个行人,以及32,668个行人图像边界框。数据集分为两部分:其中750人的图像用于训练,其余751人的图像用于测试。在官方测试协议中,选择3,368个查询图像作为查询集query,以在包含19,732张参考图像的gallery图像集中找到正确匹配。
1Market-1501
2 ├── bounding_box_test (750人的19732张图像用于测试)
3 ├── -1_c1s1_000401_03.jpg
4 ├── 0071_c6s2_072893_01.jpg
5 ├── 0071_c6s2_072918_02.jpg
6 ├── bounding_box_train (751人的12936张图像用于训练)
7 ├── 0002_c1s1_000451_03.jpg
8 ├── 0002_c1s1_000801_01.jpg
9 ├── 0430_c5s1_109673_01.jpg
10 ├── gt_bbox (25259张图像手动标注)
11 ├── 0001_c1s1_001051_00.jpg
12 ├── 0001_c1s2_041171_00.jpg
13 ├── 0933_c6s2_110943_00.jpg
14 ├── gt_query (matlab格式,用于判断一个query的哪些图片是好的匹配和不好的匹配)
15 ├── 0001_c1s1_001051_00_good.mat
16 ├── 0794_c2s2_086182_00_good.mat
17 ├── 0001_c1s1_001051_00_junk.mat
18 ├── query (750人的3368张图像用于查询)
19 ├── 0001_c1s1_001051_00.jpg
20 ├── 0001_c2s1_000301_00.jpg
21 ├── 0001_c3s1_000551_00.jpg
22 └── readme.txt
图像命名规则
以0071_c6s2_072893_01.jpg 为例
数据集使用
通常都是用度量学习的方式来使用Market-1501数据集。一般使用bounding_box_train,bounding_box_tes和query数据集中的图像进行模型训练和测试。
在fast-reid/MODEL_ZOO.md文件下提供了不同数据集下不同方法得到的sota模型。以最简单的Bot在Market1501中训练ResNet50模型为例。点击Method下的链接会转到模型配置文件路径,点击download会下载对应的预训练模型(大概300MB)。
对于对应的config路径位于fast-reid/configs目录下,所用到的文件有两个:
1configs
2 ├── Market1501
3 ├── bagtricks_R50.yml
4 ├── Base-bagtricks.yml
代码运行时会把Base-bagtricks.yml和bagtricks_R50.yml合并在一起。模型训练测试推理就是靠这两个文件,当然你可以手动把这两个文件并在一起。具体文件修改可以后续看看不同的config文件和官方代码,自己摸索摸索就可以入手。
Base-bagtricks.yml
1MODEL:
2 META_ARCHITECTURE: Baseline
3
4 BACKBONE: # 模型骨干结构
5 NAME: build_resnet_backbone
6 NORM: BN
7 DEPTH: 50x
8 LAST_STRIDE: 1
9 FEAT_DIM: 2048
10 WITH_IBN: False
11 PRETRAIN: True
12
13 HEADS: # 模型头
14 NAME: EmbeddingHead
15 NORM: BN
16 WITH_BNNECK: True
17 POOL_LAYER: GlobalAvgPool
18 NECK_FEAT: before
19 CLS_LAYER: Linear
20
21 LOSSES: # 训练loss
22 NAME: ("CrossEntropyLoss", "TripletLoss",)
23
24 CE:
25 EPSILON: 0.1
26 SCALE: 1.
27
28 TRI:
29 MARGIN: 0.3
30 HARD_MINING: True
31 NORM_FEAT: False
32 SCALE: 1.
33
34INPUT: # 模型输入图像处理方式
35 SIZE_TRAIN: [ 256, 128 ]
36 SIZE_TEST: [ 256, 128 ]
37
38 REA:
39 ENABLED: True
40 PROB: 0.5
41
42 FLIP:
43 ENABLED: True
44
45 PADDING:
46 ENABLED: True
47
48DATALOADER: # 模型读取图像方式
49 SAMPLER_TRAIN: NaiveIdentitySampler
50 NUM_INSTANCE: 4
51 NUM_WORKERS: 8
52
53SOLVER: # 模型训练配置文件
54 AMP:
55 ENABLED: True
56 OPT: Adam
57 MAX_EPOCH: 120
58 BASE_LR: 0.00035
59 WEIGHT_DECAY: 0.0005
60 WEIGHT_DECAY_NORM: 0.0005
61 IMS_PER_BATCH: 64
62
63 SCHED: MultiStepLR
64 STEPS: [ 40, 90 ]
65 GAMMA: 0.1
66
67 WARMUP_FACTOR: 0.1
68 WARMUP_ITERS: 2000
69
70 CHECKPOINT_PERIOD: 30
71
72TEST: # 模型测试配置
73 EVAL_PERIOD: 30
74 IMS_PER_BATCH: 128
75
76CUDNN_BENCHMARK: True
77MODEL:
78 META_ARCHITECTURE: Baseline
79
80 BACKBONE: # 模型骨干结构
81 NAME: build_resnet_backbone
82 NORM: BN
83 DEPTH: 50x
84 LAST_STRIDE: 1
85 FEAT_DIM: 2048
86 WITH_IBN: False
87 PRETRAIN: True
88
89 HEADS: # 模型头
90 NAME: EmbeddingHead
91 NORM: BN
92 WITH_BNNECK: True
93 POOL_LAYER: GlobalAvgPool
94 NECK_FEAT: before
95 CLS_LAYER: Linear
96
97 LOSSES: # 训练loss
98 NAME: ("CrossEntropyLoss", "TripletLoss",)
99
100 CE:
101 EPSILON: 0.1
102 SCALE: 1.
103
104 TRI:
105 MARGIN: 0.3
106 HARD_MINING: True
107 NORM_FEAT: False
108 SCALE: 1.
109
110INPUT: # 模型输入图像处理方式
111 SIZE_TRAIN: [ 256, 128 ]
112 SIZE_TEST: [ 256, 128 ]
113
114 REA:
115 ENABLED: True
116 PROB: 0.5
117
118 FLIP:
119 ENABLED: True
120
121 PADDING:
122 ENABLED: True
123
124DATALOADER: # 模型读取图像方式
125 SAMPLER_TRAIN: NaiveIdentitySampler
126 NUM_INSTANCE: 4
127 NUM_WORKERS: 8
128
129SOLVER: # 模型训练配置文件
130 AMP:
131 ENABLED: True
132 OPT: Adam
133 MAX_EPOCH: 120
134 BASE_LR: 0.00035
135 WEIGHT_DECAY: 0.0005
136 WEIGHT_DECAY_NORM: 0.0005
137 IMS_PER_BATCH: 64
138
139 SCHED: MultiStepLR
140 STEPS: [ 40, 90 ]
141 GAMMA: 0.1
142
143 WARMUP_FACTOR: 0.1
144 WARMUP_ITERS: 2000
145
146 CHECKPOINT_PERIOD: 30
147
148TEST: # 模型测试配置
149 EVAL_PERIOD: 30
150 IMS_PER_BATCH: 128
151
152CUDNN_BENCHMARK: True
bagtricks_R50.yml
注意我加了预训练模型路径。
1_BASE_: ../Base-bagtricks.yml # 链接父目录下的Base-bagtricks.yml
2
3DATASETS:
4 NAMES: ("Market1501",) # 数据集路径
5 TESTS: ("Market1501",) # 测试集路径
6
7OUTPUT_DIR: logs/market1501/bagtricks_R50 # 输出结果路径
8
9MODEL:
10 WEIGHTS: model/market_bot_R50.pth # 预训练模型路径,这句是我自己加的