一个可学习通道转换引发的惨案

最近在做异常检测的算法,主实验是在 BMAD 这套数据集上做 leave one out,BMAD 这套数据集是由MRI、CT、OCT等一共六个数据集组成,这些图像的通道数差异巨大。

为了适应统一的模型输入,我用了一个可学习的 channel_proj 层,它的作用是将不同通道数的输入转换为统一的3通道格式。

1
2
3
4
5
6
7
# 可训练的通道转换层
self.channel_proj = nn.Conv2d(4, self.target_channels, kernel_size=1, bias=False)
# 初始化卷积权重(凯明初始化,适合ReLU类激活)
nn.init.kaiming_normal_(self.channel_proj.weight, mode='fan_out', nonlinearity='relu')
...
# 维度转换
x_out = self.channel_proj(x_flat)

问题

然而,这个看似简单的转换层,在跨域任务中,带来极其灾难的后果。

如果不做跨域,那其实这个可学习的通道转换也许是有利的(因为模型在尽量拟合同域的数据分布),但是问题在于我做的是跨域的,也就是在A上训,B上推。

具体表现就是:虽然源域训练中没有问题,但在目标域上,随着训练的进行,验证准确率发生下降。

因为 channel_proj 层是一个可学习的1×1卷积,它的作用是对输入图像进行通道映射,而他又是在模型的入口(且我的backbone是冻结的),因此在源域数据的训练过程中,这个模块学到了一些对目标域不适用的映射,而在目标域推理时,就会导致对输入图像进行了错误的偏移。

解决

因为BMAD这套数据集里即使有四通道的,第四个通道也是无意义信息,所以直接丢弃即可。

1
x_out = x_flat[:, :3, :, :].contiguous()

后续

改了之后明显跨域性能得到提高,但这个问题导致浪费了很多时间,因为我后面的很多改进都是在受到这个负面影响的指标下在做改进。