标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

加入极市专业CV交流群,与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度 等名校名企视觉开发者互动交流!

同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注 极市平台 公众号 ,回复 加群,立刻申请入群~

作者|JayLou娄杰,https://zhuanlan.zhihu.com/p/146557232
本文已获作者授权,不得二次转载

监督学习虽好,错误标注却能让模型翻车~

在实际工作中,你是否遇到过这样一个问题或痛点:无论是通过哪种方式获取的标注数据,数据标注质量可能不过关,存在一些错误?亦或者是数据标注的标准不统一、存在一些歧义?特别是badcase反馈回来,发现训练集标注的居然和badcase一样?如下图所示,QuickDraw、MNIST和Amazon Reviews数据集中就存在错误标注。

为了快速迭代,大家是不是常常直接人工去清洗这些“脏数据”?(笔者也经常这么干~)。但数据规模上来了咋整?有没有一种方法能够自动找出哪些错误标注的样本呢?基于此,本文尝试提供一种可能的解决方案——置信学习

本文的组织架构是:


1、置信学习的定义

那什么是置信学习呢?这个概念来自一篇由MIT和Google联合提出的paper:《Confident Learning: Estimating Uncertainty in Dataset Labels[1] 》。论文提出的置信学习(confident learning,CL是一种新兴的、具有原则性的框架,以识别标签错误、表征标签噪声并应用于带噪学习(noisy label learning)。

笔者注:笔者乍一听「置信学习」挺陌生的,但回过头来想想,好像干过类似的事情,比如:在某些场景下,对训练集通过交叉验证来找出一些可能存在错误标注的样本,然后交给人工去纠正。此外,神经网络的成功通常建立在大量、干净的数据上,标注错误过多必然会影响性能表现,带噪学习可是一个大的topic,有兴趣可参考这些文献 https://github.com/subeeshvasu/Awesome-Learning-with-Label-Noise。

废话不说,首先给出这种置信学习框架的优势

  • 最大的优势:可以用于发现标注错误的样本!
  • 无需迭代,开源了相应的python包,方便快速使用!在ImageNet中查找训练集的标签错误仅仅需要3分钟!
  • 可直接估计噪声标签与真实标签的联合分布,具有理论合理性
  • 不需要超参数,只需使用交叉验证来获得样本外的预测概率。
  • 不需要做随机均匀的标签噪声的假设(这种假设在实践中通常不现实)。
  • 与模型无关,可以使用任意模型,不像众多带噪学习与模型和训练过程强耦合。

笔者注:置信学习找出的「标注错误的样本」,不一定是真实错误的样本,这是一种基于不确定估计的选择方法。

2、置信学习开源工具:cleanlab

论文最令人惊喜的一点就是作者这个置信学习框架进行了开源,并命名为cleanlab,我们可以pip install cleanlab使用,具体文档说明在这里cleanlab文档说明。


      
      
    
from cleanlab.pruning import get_noise_indices# 输入# s:噪声标签# psx: n x m 的预测概率概率,通过交叉验证获得ordered_label_errors = get_noise_indices(    s=numpy_array_of_noisy_labels,    psx=numpy_array_of_predicted_probabilities,    sorted_index_method='normalized_margin', # Orders label errors )

我们来看看cleanlab在MINIST数据集中找出的错误样本吧,是不是感觉很牛~

标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

如果你不只是想找到错误标注的样本,还想把这些标注噪音clean掉之后重新继续学习,那3行codes也可以搞定,这时候连交叉验证都省了~:


      
      
    
from cleanlab.classification import LearningWithNoisyLabelsfrom sklearn.linear_model import LogisticRegression
# 其实可以封装任意一个你自定义的模型.lnl = LearningWithNoisyLabels(clf=LogisticRegression()) lnl.fit(X=X_train_data, s=train_noisy_labels) # 对真实世界进行验证.predicted_test_labels = lnl.predict(X_test)

笔者注:上面虽然只给出了CV领域的例子,但置信学习也适用于NLP啊~此外,cleanlab可以封装任意一个你自定义的模型,以下机器学习框架都适用:scikit-learn, PyTorch, TensorFlow, FastText。

3、置信学习的3个步骤

置信学习开源工具cleanlab操作起来比较容易,但置信学习背后也是有着充分的理论支持的。事实上,一个完整的置信学习框架,需要完成以下三个步骤(如图1所示):

  1. Count:估计噪声标签和真实标签的联合分布;
  2. Clean:找出并过滤掉错误样本;
  3. Re-Training:过滤错误样本后,重新调整样本类别权重,重新训练;
标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

图1 置信学习框架

下面对上述3个步骤进行详细阐述:

3.1 Count:估计噪声标签和真实标签的联合分布

我们定义噪声标签为 

,即经过初始标注(也许是人工标注)、但可能存在错误的样本;定义真实标签为 

,但事实上我们并不会获得真实标签,通常可通过交叉验证对真实标签进行估计。此外,定义样本总数为 

,类别总数为 

为了估计联合分布,共需要4步:

  • step 1 : 交叉验证:

    • 首先需要通过对数据集集进行交叉验证计算第 

      样本在第 

      个类别下的概率 


    • 然后计算每个人工标定类别 

      下的平均概率 

      作为置信度阈值;

    • 最后对于样本 

      ,其真实标签 

      为 

      个类别中的最大概率 

      ,并且 

      ;

  • step 2: 计算计数矩阵 

    (类似于混淆矩阵),如图1中的

    意味着,人工标记为dog但实际为fox的样本为40个。具体的操作流程如图2所示:

标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

图2 计数矩阵C计算流程

  • step 3 : 标定计数矩阵:目的就是为了让计数总和与人工标记的样本总数相同。计算公式如下面所示,其中 

    为人工标记标签 

    的样本总个数:

公式1

  • step 4 : 估计噪声标签

    和真实标签

    的联合分布

    ,可通过下式求得:

公式2

看到这里,也许你会问为什么要估计这个联合分布呢?其实这主要是为了下一步方便我们去clean噪声数据。此外,这个联合分布其实能充分反映真实世界中噪声(错误)标签和真实标签的分布,随着数据规模的扩大,这种估计方法与真实分布越接近(原论文中有着严谨的证明,由于公式推导繁杂这里不再赘述,有兴趣的同学可以详细阅读原文~,后文的图7也有相关实验进行证明)。

看到这里,也许你还感觉公式好麻烦,那下面我们通过一个具体的例子来展示上述计算过程:

  • step 1 : 通过交叉验证获取第 

    样本在第 

    个类别下的概率 

    ;为说明问题,这里假设共10个样本、2个类别,每个类别有5个样本。经过计算每个人工标签类别 

    下的平均概率 

    分别为: 

    .

标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

图3 P[i][j]和t[j]计算

  • step2: 根据图2的计算流程,我们得到计数矩阵 

    为:

标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

图4 计数矩阵C计算

  • step3: 标定后的计数矩阵 

    为(计数总和与人工标记的样本总数相同),将原来的样本总数进行加权即可,以 

    为例,根据公式1,其计算为 

    )::

标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注
  • step4:联合分布 

    为:(根据公式2直接进行概率归一化即可)

标注数据存在错误怎么办?MIT&Google开源置信学习新方法,找出错误标注

图5 联合分布Q计算

3.2 Clean:找出并过滤掉错误样本

在得到噪声标签和真实标签的联合分布 

,论文共提出了5种方法过滤错误样本。

  • Method 1: 

    ,选取 

    的样本进行过滤,即选取 

    最大概率对应的下标 

    与人工标签不一致的样本。

  • Method 2: 

    ,选取构造计数矩阵 

    过程中、进入非对角单元的样本进行过滤。

  • Method 3: Prune by Class (PBC) ,即对于人工标记的每一个类别 

    ,选取 

    个样本过滤,并按照最低概率 

    排序。

  • Method 4: Prune by Noise Rate (PBNR) ,对于计数矩阵 

    的非对角单元,选取 

    个样本进行过滤,并按照最大间隔 

    排序。

  • Method 5: C+NR,同时采用Method 3和Method 4.

我们仍然以图3给出的示例进行说明:

  • Method 1:过滤掉i=2,3,4,8,9共5个样本;
  • Method 2:进入到计数矩阵非对角单元的样本分别为i=3,4,9,将这3个样本过滤;
  • Method 3:对于类别0,选取 

    个样本过滤,按照最低概率排序,选取i=2,3,4;对于类别1,选取 

    个样本过滤,按照最低概率排序选取i=9;综上,共过滤i=2,3,4,9共4个样本;

  • Method 4:对于非对角单元 

    选取i=2,3,4过滤,对 

    选取i=9过滤。

上述这些过滤样本的方法在cleanlab也有提供,我们只要提供2个输入、1行code即可clean错误样本:


      
      
    
import cleanlab# 输入# s:噪声标签# psx: n x m 的预测概率概率,通过交叉验证获得# Method 3:Prune by Class (PBC)baseline_cl_pbc = cleanlab.pruning.get_noise_indices(s, psx, prune_method='prune_by_class',n_jobs=1)# Method 4:Prune by Noise Rate (PBNR)baseline_cl_pbnr = cleanlab.pruning.get_noise_indices(s, psx, prune_method='prune_by_noise_rate',n_jobs=1)# Method 5:C+NRbaseline_cl_both = cleanlab.pruning.get_noise_indices(s, psx, prune_method='both',n_jobs=1)




发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注