跳转至

SimCLRv2

SimCLRv2提出了一个三步的半监督算法,首先使用SimCLRv2无监督预训练一个大模型,然后使用一小部分标注数据监督fine-tune,最后通过无标签样本蒸馏来refine和迁移任务相关知识。对于ImageNet,在使用1%的标签时,本工作达到了73.9%的top-1准确率,并在使用10%的标签时达到了77.5%的top-1准确率。

动机

使用无标签数据进行任务无关的预训练,然后使用标签数据进行任务相关的fine-tune在自然语言处理中被广泛应用,但在计算机视觉中还非主流。在计算机视觉当中更常见的另一个方法时在监督学习中直接利用无标签数据,作为一种规则化方法。这种方法在任务特定方式中使用无标签数据以鼓励在不同模型之间或者不同数据增强下对无标签数据进行类标签预测的一致性。

贡献

本文提出了一个半监督学习的框架,该框架由三个部分组成:1. 无监督或自监督预训练,2. 监督 fine-tune,3. 使用无标签数据蒸馏。

  1. 对于半监督学习(通过任务无关的使用无标签数据),标签越少,越有可能受益于更大的模型。更大的自监督模型拥有更高的标签效率,在很少的标签样本上fine-tune之后有显著更好的表现,尽管他们的容量更大因而更可能过拟合。

  2. 尽管大模型对于学习通用(视觉)表达而言十分重要,但是额外的容量对于特定目标任务而言可能是不必须的。因此,通过使用任务特定的使用无标签数据,模型的预测性能可以被进一步提升并被迁移到更小的网络中。

  3. 更深的映射头不但能提升通过线性评估测得的表达质量,而且还能提升从映射头的一个 中间层 进行微调时的半监督性能。

方法

框架

本文从隐空间的对比损失来最大化同一数据样本的不同增强视图的一致性以学习表达。本框架由四个主要部分构成。

  1. 一个随机 数据增强 模块,对任一给定数据样本产生两个相关的视图,表示为\(\tilde{x}_i\)\(\tilde{x}_j\),我们将其称为正样本对。在本工作中, 我们顺序的应用三个简单的增强:随即裁剪(镜像)并将其缩放回原始大小、随机色彩失真、随机高斯模糊。 其中,裁剪和颜色失真的组合对性能有至关重要的影响。

  2. 一个神经网络 基础编码器 \(f(\cdot)\),从增强后的数据样本中提取出表达向量。本文的框架允许使用多种网络结构而没有限制。本文出于简单的考量使用了广泛使用的ResNet来获取\(\mathbf{h}_i = f(\tilde{x}_i) = \mathrm{ResNet}(\tilde{x}_i)\),其中\(\mathbf{h}_i \in \mathbb{R}^d\)是平均池化层的输出。

  3. 一个小的神经网络 预测头 \(g(\cdot)\),将表达映射到对比损失被应用的空间上。本文使用一个具有一个隐藏层的MLP来获得\(\mathcal{z}_i = g(\mathbf{h}_i) = W^{(2)} \sigma (W^{(1)} h_i)\),其中\(\sigma\)代表ReLU非线性层。我们发现在\(\mathcal{z}_i\)上定义对比损失比在\(\mathbf{h}_i\)上更好。

  4. 一个 对比损失函数 NT-Xent \(\mathcal{l}_{i, j} = -\log \frac{\exp(\mathrm{sim}(\mathcal{z}_i, \mathcal{z}_j) / \tau)}{\sum^{2N}_{k=1} \mathbf{1}_{[k \neq i]} \exp(\mathrm{sim}(\mathcal{z}_i, \mathcal{z}_k) / \tau)}\),其中\(N\)为batch size,\(\mathrm{sim}(u, v) = u^Tv / \left\Vert u \right\Vert \left\Vert v \right\Vert\)表示\(\mathcal{l}_2\)正则后的\(u\)\(v\)的点积(也即余弦相似度),\(\mathbf{1}_{[k \neq i]}\)为一个指示器函数,当且仅当\(k \neq i\)时为1,\(\tau\)表示温度参数。损失在所有正样本对上被计算,也即\((i, j)\)\((j, i)\)

大batch size训练

标准的SGD/Momentum结合线性学习率调整对于超大规模网络训练来说可能不太稳定。本工作使用LARS优化器以解决这个问题。

由于正样本对通常在同一张卡上,模型可能利用局部信息的泄露而在不提升表达的情况下提升性能。本文通过聚合BN的均值和方差来实现Global BN以解决这个问题。其他方法还包括跨设备随机排布数据、使用layer norm替代BN。

实验设置

本工作在ImageNet ILSVRC-2012上训练,数据增强包括随机剪裁并缩放回原始大小(伴随随机镜像)、色彩失真和高斯模糊。模型使用ResNet-50作为编码器,和一个2层MLP映射头来将表达映射到一个128维的隐空间上。本工作使用NT-Xent损失,通过LARS进行优化。学习率被设置为4.8(\(0.3 \times \text{batch size} / 256\)),weight decay设置为1e-6。Batch size为4096,训练100个周期s。此外,线性预热在前10个周期被应用,学习率按照cosine decay schedule衰减,没有重启。

讨论

  • 数据增强的构成对学习好的表达至关重要
  • 相较于监督学习,对比学习需要更强的数据增强
  • 对比学习从更大的模型中获益(更多)
  • 非线性映射头提升在其之前的表达的质量
  • 正则化的交叉熵损失与可调的温度参数比其他方法效果更好
  • 对比学习从更大的batch sizes和更长的训练中获益(更多)