跳转至

SimCLR

SimCLR通过使用对比学习在ImageNet上达到了69.3%的精度。

动机

学习有效的视觉表达而无需人类监督是一个长期存在的问题。主流方法通常可被归类为两种:生成类和判别类。

生成方法学习在输入空间中生成或建模像素。然而,像素级生成非常的计算昂贵,并且对于表达学习来说也可能有些过了。

判别方法使用类似于监督学习中的目标函数来学习表达,区别在于他们使用前置任务来训练网络,而输入和标签均衍生自无标签数据集。许多这些方法都依赖启发式方法以设计前置任务,这可能会限制所学表达的通用性。

近来,基于在隐空间进行对比学习的判别方法表现出了极好的前景,达到了SOTA结果。

贡献

本文提出了一个对比学习的简单的框架。他无需特殊的架构,也不需要内存银行。

  1. 多个数据增强的组合对于定义产生有效表达的预测任务来说至关重要。与监督学习相比,无监督学习从数据增强中获得的收益更大。

  2. 在表达与对比损失之间增加一个科学系的非线性变换可以很大程度上提高学到的表达的质量。

  3. 使用对比交叉熵损失函数的表达学习可以从归一化的嵌入和对温度参数的恰当调整中受益。

  4. 相较于监督学习,对比学习受益于更大的batch size与更长的训练。与监督学习类似,对比学习受益于更深更宽的网络。

方法

框架

本文从隐空间的对比损失来最大化同一数据样本的不同增强视图的一致性以学习表达。本框架由四个主要部分构成。

  1. 一个随机 数据增强 模块,对任一给定数据样本产生两个相关的视图,表示为\(\tilde{x}_i\)\(\tilde{x}_j\),我们将其称为正样本对。在本工作中, 我们顺序的应用三个简单的增强:随即裁剪(镜像)并将其缩放回原始大小、随机色彩失真、随机高斯模糊。 其中,裁剪和颜色失真的组合对性能有至关重要的影响。

  2. 一个神经网络 基础编码器 \(f(\cdot)\),从增强后的数据样本中提取出表达向量。本文的框架允许使用多种网络结构而没有限制。本文出于简单的考量使用了广泛使用的ResNet来获取\(\mathbf{h}_i = f(\tilde{x}_i) = \mathrm{ResNet}(\tilde{x}_i)\),其中\(\mathbf{h}_i \in \mathbb{R}^d\)是平均池化层的输出。

  3. 一个小的神经网络 映射头 \(g(\cdot)\),将表达映射到对比损失被应用的空间上。本文使用一个具有一个隐藏层的MLP来获得\(\mathcal{z}_i = g(\mathbf{h}_i) = W^{(2)} \sigma (W^{(1)} h_i)\),其中\(\sigma\)代表ReLU非线性层。我们发现在\(\mathcal{z}_i\)上定义对比损失比在\(\mathbf{h}_i\)上更好。

  4. 一个 对比损失函数 NT-Xent \(\mathcal{l}_{i, j} = -\log \frac{\exp(\mathrm{sim}(\mathcal{z}_i, \mathcal{z}_j) / \tau)}{\sum^{2N}_{k=1} \mathbf{1}_{[k \neq i]} \exp(\mathrm{sim}(\mathcal{z}_i, \mathcal{z}_k) / \tau)}\),其中\(N\)为batch size,\(\mathrm{sim}(u, v) = u^Tv / \left\Vert u \right\Vert \left\Vert v \right\Vert\)表示\(\mathcal{l}_2\)正则后的\(u\)\(v\)的点积(也即余弦相似度),\(\mathbf{1}_{[k \neq i]}\)为一个指示器函数,当且仅当\(k \neq i\)时为1,\(\tau\)表示温度参数。损失在所有正样本对上被计算,也即\((i, j)\)\((j, i)\)

大batch size训练

标准的SGD/Momentum结合线性学习率调整对于超大规模网络训练来说可能不太稳定。本工作使用LARS优化器以解决这个问题。

由于正样本对通常在同一张卡上,模型可能利用局部信息的泄露而在不提升表达的情况下提升性能。本文通过聚合BN的均值和方差来实现Global BN以解决这个问题。其他方法还包括跨设备随机排布数据、使用layer norm替代BN。

实验设置

本工作在ImageNet ILSVRC-2012上训练,数据增强包括随机剪裁并缩放回原始大小(伴随随机镜像)、色彩失真和高斯模糊。模型使用ResNet-50作为编码器,和一个2层MLP映射头来将表达映射到一个128维的隐空间上。本工作使用NT-Xent损失,通过LARS进行优化。学习率被设置为4.8(\(0.3 \times \text{batch size} / 256\)),weight decay设置为1e-6。Batch size为4096,训练100个周期s。此外,线性预热在前10个周期被应用,学习率按照cosine decay schedule衰减,没有重启。

讨论

  • 数据增强的构成对学习好的表达至关重要
  • 相较于监督学习,对比学习需要更强的数据增强
  • 对比学习从更大的模型中获益(更多)
  • 非线性映射头提升在其之前的表达的质量
  • 正则化的交叉熵损失与可调的温度参数比其他方法效果更好
  • 对比学习从更大的batch sizes和更长的训练中获益(更多)