多标签分类:Python的Scikit-Learn简介

2023年08月07日 由 camellia 发表 542 0

学习如何在工作中开发多标签分类器。


在机器学习任务中,分类是一种监督学习方法,用于根据输入数据预测标签。例如,我们想通过使用历史特征预测某人是否对销售产品感兴趣。通过使用可用的训练数据训练机器学习模型,我们可以对输入数据进行分类任务。

我们经常遇到经典的分类任务,例如二分类(两个标签)和多类分类(多于两个标签)。在这种情况下,我们会训练分类器,并且模型会尝试从所有可用标签中预测一个标签。用于分类的数据集类似于下面的图片。

3.1

上图显示目标(销售产品/服务)在二分类中包含两个标签,而在多分类中包含三个标签。模型将从可用的特征中进行训练,然后仅输出一个标签。

多标签分类不同于二分类或多类分类。在多标签分类中,我们不仅尝试预测一个输出标签。相反,多标签分类将尝试预测尽可能多的标签应用于输入数据。输出的标签可以从没有标签到最大数量的可用标签。

多标签分类通常用于文本数据分类任务。例如,下面是一个用于多标签分类的示例数据集。

3.2

在上面的示例中,想象一下文本1到文本5是可以归类为四个类别的句子:事件、运动、流行文化和自然。使用上述训练数据,多标签分类任务会预测哪个标签适用于给定的句子。每个类别之间并不相互对立,因为它们不是互斥的;每个标签可以被视为独立的。

更详细地说,我们可以看到文本1的标签是运动和流行文化,而文本2的标签是流行文化和自然。这表明每个标签都是互斥的,多标签分类的预测输出可以是没有标签,也可以是所有标签同时输出。

有了这个介绍,让我们尝试使用Scikit-Learn构建多标签分类器。

使用Scikit-Learn进行多标签分类


本教程将使用Kaggle上公开可用的生物医学PubMed多标签分类数据集。该数据集将包含各种特征,但我们只使用摘要文本特征和它们的MeSH分类(A:解剖学,B:生物体,C:疾病等)。示例数据如下图所示。

3.3

上述数据集显示每篇论文可以分类为多个类别,这是多标签分类的情况。利用这个数据集,我们可以使用Scikit-Learn构建多标签分类器。在我们训练模型之前准备数据集。

3.4

在上面的代码中,我们将文本数据转换为TF-IDF表示,以便我们的Scikit-Learn模型可以接受训练数据。为了简化教程,我跳过了预处理数据的步骤,如删除停用词。

在数据转换之后,我们将数据集分割为训练集和测试集。

3.5

完成准备工作后,我们将开始训练多标签分类器。在Scikit-Learn中,我们将使用MultiOutputClassifier对象来训练MultiOutput Classifier模型。这个模型的策略是为每个标签训练一个分类器。基本上,每个标签都有自己的分类器。

在这个示例中,我们将使用逻辑回归作为模型,并且MultiOutput Classifier会将其扩展到所有的标签上。

3.6

我们可以改变模型并调整传递给MultiOutput Clasiffier的模型参数,以根据需求进行管理。训练完成后,让我们使用模型对测试数据进行预测。

3.7
3.8

预测结果是每个MeSH类别的标签数组。每行代表一个句子,每列代表一个标签。

最后,我们需要评估我们的多标签分类器。我们可以使用准确率指标对模型进行评估。

3.9

准确率得分:0.145

准确率得分为0.145,这表明模型只能在少于14.5%的情况下预测正确的标签组合。然而,准确率得分对于多标签预测评估存在一定的局限性。准确率得分要求每个句子的所有标签都位于准确确切位置,否则将被认为是错误的。

例如,第一行的预测与测试数据之间只有一个标签的不同。由于标签组合不同,这将被认为是错误的预测,因此我们的模型得分较低。

3.10

为了解决这个问题,我们必须评估标签的预测而不是它们的标签组合。在这种情况下,我们可以依靠汉明损失评估指标。汉明损失是通过将错误预测数量除以总标签数量来计算的。因为汉明损失是一个损失函数,得分越低越好(0表示没有错误预测,1表示所有预测都错误)。

3.11

汉明损失:0.13

我们的多标签分类器的汉明损失得分为0.13,这意味着我们的模型独立预测错误的概率为13%。这意味着每个标签的预测可能有13%的错误率。

结论


多标签分类是一个机器学习任务,其中输出可能是没有标签或给定输入数据的所有可能标签。这与二分类或多类分类不同,其中标签输出是互斥的。

使用Scikit-Learn的MultiOutput Classifier,我们可以开发多标签分类器,其中我们为每个标签训练一个分类器。对于模型评估,最好使用汉明损失指标,因为准确率可能不能正确地给出完整的情况。


文章来源:https://www.kdnuggets.com/2023/08/multilabel-classification-introduction-python-scikitlearn.html
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消