BBN:Bilateral-Branch network with cumulative learning for long-tailed visual recognition
1.Abstract
class re-balancing strategies,包括两种,re-weighting and re-sampling,是解决长尾现象的极端类别不平衡有效的方法。we firstly discover that these rebalancing methods achieving satisfactory recognition accuracy owe to that they could significantly promote the classifier learning of deep networks. However, at the same time, they will unexpectedly damage the representative ability of the learned deep features to some extent. 这些方法之所以有效,是因为他们能够显著的促进深层网络的分类器学习能力,然后某种程度上也会破坏所学深度特征的表征能力。因此,作者把处理长尾问题看成了两个阶段,第一个是表征能力,第二个是分类器的判别能力。
2.Introduction
在文献中,处理长尾问题最有效的方式是class re-balancing strategies,re-sampling 和cost-sensitive re-weighting. 这些方法可以调整网络训练,通过重新抽样的例子或重新加权的损失,在小批量的例子,这是在期望更接近测试分布。因此,类重平衡可以直接影响深度网络分类器权值的更新,从而促进分类器的学习。这就是重新平衡可以在长尾数据上获得令人满意的识别精度的原因。
然而,尽管重新平衡方法具有良好的最终预测,但我们认为这些方法仍然具有不利影响,即它们也会在一定程度上意外地损害所学习的深层特征(即表征学习)的表征能力。具体地说,当数据不平衡严重时,重采样有过度拟合尾部数据(通过过度采样)的风险,也有欠拟合整个数据分布(通过欠采样)的风险。对于重新加权,它将通过直接改变甚至反转数据呈现频率来扭曲原始分布。这里很重要,这是作者的理论来源,后面会专门做实验去证明一点,一般常规的class re-balancing策略是因为对分类器的权重进行更新从而有了更好的结论,但是他们对于特征的表征能力实际上是削弱了,能不能把两者的有点结合起来,既要更好的表征能力,又要很好的更新分类器的权重呢?
3.related work
re-sampling的两种方式:1) Over-sampling by simply repeating data for minority classes and 2) under-sampling by abandoning data for dominant classes.
re-weighting:这个应该和pytorch中的 WeightedRandomSampler 是类似的。
mixup
4.how class re-balancing strategies work?
这里是作者对上述猜想的实验?猜想是re-balancing的策略促进了分类的学习,但是损害了特征的表征能力?先介绍几个词,CE:常规的长尾数据训练,RS:re-sampling,RW:re-weighting.作者做了两个阶段,第一阶段,用CE/RS/RW三种方式训练数据,得到了这些学习方式相对应的不同类别的特征提取器。第二阶段,我们固定前一阶段学习到的特征提取器的参数,然后用不同的方式训练分类器。
看上面这个表,纵坐标是分类器视角(对应第二阶段),横坐标是特征表征能力视角(对应第一阶段) ,比如CE-CE,就是用常规方式训练的特征提取器,再用常规方式训练的分类器的错误率是58.62,CE-RW,就是用常规方式训练的特征提取器,用RW训练的分类器的错误率是56.53.从分类器学习的视角看,也就是在纵向上,RS,RW总是有更低的分类错误率,说明,控制了特征提取之后,class re-balancing的方式对分类器的学习起到了很积极的作用,从特征表征能力的视角看,也就是横向上,发现CE总是有更低的错误率,说明在统一的分类器训练方式下,CE的特征表征能力更强。这就和猜想对上了。
5.Methodology
这个网络的提出就顺理成章了, 一条支路是在长尾数据上的常规训练,一条支路是用reversed sampler方式采样的训练方式,让模型更关注尾部数据。采样方式类似于WeightedRandomSampler。除此之外,这里面还有个alpha,用来对两条支路学习的f进行加权。在前向的时候也是经过两条支路的,不过alpha取的0.5。
代码:
def bbn_mix(self, model, criterion, image, label, meta, **kwargs):
image_a, image_b = image.to(self.device), meta["sample_image"].to(self.device)
label_a, label_b = label.to(self.device), meta["sample_label"].to(self.device)
feature_a, feature_b = (
model(image_a, feature_cb=True),
model(image_b, feature_rb=True),
)
l = 1 - ((self.epoch - 1) / self.div_epoch) ** 2 # parabolic decay
# l = 0.5 # fix
# l = math.cos((self.epoch-1) / self.div_epoch * math.pi /2) # cosine decay
# l = 1 - (1 - ((self.epoch - 1) / self.div_epoch) ** 2) * 1 # parabolic increment
# l = 1 - (self.epoch-1) / self.div_epoch # linear decay
# l = np.random.beta(self.alpha, self.alpha) # beta distribution
# l = 1 if self.epoch <= 120 else 0 # seperated stage
mixed_feature = 2 * torch.cat((l * feature_a, (1 - l) * feature_b), dim=1)
output = model(mixed_feature, classifier_flag=True)
loss = l * criterion(output, label_a) + (1 - l) * criterion(output, label_b)
now_result = torch.argmax(self.func(output), 1)
now_acc = (
l * accuracy(now_result.cpu().numpy(), label_a.cpu().numpy())[0]
+ (1 - l) * accuracy(now_result.cpu().numpy(), label_b.cpu().numpy())[0]
)
return loss, now_acc
image_a,image_b是两条支路出来的输入图,不同的采样方式,进入model,model的权重是共享的(resnet),但是主干网出来各有一个basicblock不是共享的,对应图中的红蓝块,得到feature_a,feature_b,用alpha做一下concat。
6. validation experiments of our proposals
这里有个有意思的实验,作者可视化了每个类别的l2范数,不同类别的l2范数可以证明分类器的偏好,即具有最大l2范数的分类器的权重倾向于判断一个示例的类别。可以从上图中看到BBN-CB后面几类的l2范数小,BBN-RB明显后面几类l2大,说明采用方式还是让长尾数据得到了很好的利用。
本文从分类器学习和特征表征学习两个角度出发,进行相对应的实验,建立了两支路,一路走CE,保存特征表征能力,一路走采样,保留分类器学习能力,很合理。
对一个领域问题,比如长尾,类别不均衡,细粒度,噪声样本,还是要有理论上的支撑,不然有的时候不知道从什么方面下手,就在不断的试模型,其实是收益很少的工作,算法工程师是接触实际业务一线的人,要在理论支撑下了解普遍业务场景的共性,举一反三,其实一个现实场景中的问题往往也不是一类问题,想要很好的定义并且去找相应的办法其实也不简单,比如之前做的banner品类分类,大品类是从一些图片网站上爬下来的标签,本身噪声标签就很多,其次,十几二十类的数据也不可能是均衡的,基本都是长尾的,并且有些类别并不是严格可分的,类似双十一618这种标签有些图片就是很难做区分的。
工程类论文之所以好看,还是因为其有一个在实际业务中的问题或者发现作为基础,不会产生空中楼阁那样的理论,比如之前ghostnet,rfbnet这类文章,就编故事,虽然有效,但没那味。这篇论文的代码也是值得学习的,两路其实在工程上也算是比较好的融合特征提取方案,比如很早小视科技开源的活体检测方案也是用了双路,还有我之前做的篡改检测识别模型也是用了双路,一路DCT转换,一路rgb,因此这篇文章的代码以及工程上一些组合,特征mix也是很值得学习的。
更新:在一个27品类,总数为112942,样本分布极不平衡的数据集上未收敛,BBN论文原始参数,同期在resnet上top1为83.8054,top5为96.8726,作者本文支撑理论的实验靠谱?或者说是个普通的现象吗?待议。此外,本文训练时间极慢,本来双支路是会有点慢,但是比我想象中慢太多了,并且显存占用大概是resnet的3倍左右,不像是个理想的工程方案。
一杯白开水儿: 请问测试下来哪个效果更好? 相比RVM和BGMV2怎么样
星空真懒: 优质好文,支持支持。【我也写了一些相关领域的文章,希望能够得到博主的指导,共同进步!】
G鲲鹏展翅Y: 我的oa为啥返回是0?
weixin_45199492: 博主大大你好,请问一下你这个代码里的labels是什么样的,我用这种做labels,text_emb的维度是[batch_size, hidden_size],label_emb的维度是[num_classes, hidden_size] 我的原始labels= ["金融", "房地产", "股票", "教育", "科学", "社会", "政治", "体育", "游戏", "娱乐"],非常期待您的回答,感恩
weixin_47486447: 请问system_message的作用是什么?是作为输入GPT用来生成158k指令跟随数据的Prompt吗? 另外Caption中输入的5句描述是来自于哪?也是GPT生成的吗?求解答