少样本学习突破:用0.1%数据实现90%文本分类效果

2023年12月28日 由 alex 发表 436 0

现代大型语言模型(LLM)的零样本(zero-shot)能力确实令人鼓舞,让我们感觉到通用人工智能(AGI)非常接近了。然而,它需要庞大的网络和在大量数据上的预训练。但这还不够。你需要特别针对你的业务应用来微调模型。这里的区别在于,你需要多少示例就能获得合理的结果。在我们的团队中,我们开发了一个零样本文本分类模型,只需每个标签8个例子,就可以达到全面微调的模型在数千个例子上训练的90%效果。在这个教程中,我们将向你展示如何使用我们的开源零样本文本分类模型,实现同样的结果。


需求和数据


首先,确保你已经安装了以下库:


pip install datasets transformers accelerate setfit


  • 数据集:一个统一的接口用来管理和访问多样化的机器学习数据集。
  • 变压器:Hugging Face库提供预训练模型和工具,用于自然语言处理任务。
  • 加速器:一个库,通过添加仅四行代码,可以让相同的PyTorch代码跨任何分布式配置运行。
  • SetFit:一个高效且无需提示的框架,用于少量样本微调句子变换器。


好的,我们现在需要下载一个将会用到的数据集,“情绪”数据集,它包含描述文本的6种不同情绪类别。然后我们将数据集分为测试集和训练集,并从训练集中随机选择48个样例,每个标签平均8个样例。


from datasets import load_dataset
#emotion
emotion_dataset = load_dataset("dair-ai/emotion")
test_dataset = emotion_dataset['test']
classes = test_dataset.features["label"].names
N = 8
train_dataset = emotion_dataset['train'].shuffle(seed=41)
                                        .select(range(len(classes)*N))


设定适合


首先,我们将看看使用 SetFit 可以取得什么结果——这是一种使用文本嵌入进行分类的替代性少样本学习方法。然后,我们将看到我们的方法如何比 SetFit 更强大。


from setfit import SetFitModel, Trainer, TrainingArguments
from sklearn.metrics import classification_report

model = SetFitModel.from_pretrained("BAAI/bge-base-en-v1.5")

args = TrainingArguments(
   batch_size=32,
   num_epochs=1,
)

trainer = Trainer(
   model=model,
   args=args,
   train_dataset=train_dataset,
   eval_dataset=test_dataset,
)
trainer.train()

为了测试该模型,我们运行以下命令:
preds = model.predict(test_dataset['text'])
print(classification_report(test_dataset['label'], preds, 
                                target_names=classes, digits=4))


6


我们得到的结果比SetFit在零样本设置下的结果还要糟糕。其原因是数据集的类别分布不平衡,因此我们无法保证在均匀采样时所有类别都会出现,结果就是我们无法正确训练逻辑模型。我们的方法更加通用,模型的微调不需要训练额外的分类头。


领悟法


现在让我们尝试我们的方法,首先你需要初始化模型:


from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = 'knowledgator/comprehend_it-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(model_name)


我们的方法基于一个文本分类模型,该模型经过训练,用以区分两个陈述是中性的、相互矛盾的还是包含关系的。


现在,让我们初始化所有数据处理功能:


from transformers import TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
from datasets import Dataset
import random
import torch
import evaluate
import numpy as np
accuracy = evaluate.load("accuracy")
def transform_dataset(dataset, classes, template = '{}'):
   new_dataset = {'sources':[], 'targets': [], 'labels': []}
   texts = dataset['text']
   labels = dataset['label']
   label2count = {}
   for label in labels:
       if label not in label2count:
           label2count[label]=1
       else:
           label2count[label]+=1
   count = len(labels)
   label2prob = {label:lc/count for label, lc in label2count.items()}
   unique_labels = list(label2prob)
   probs = list(label2prob.values())
   ids = list(range(len(labels)))
   for text, label_id in zip(texts, labels):
       label = classes[label_id]
       for i in range(len(classes)-1):
           new_dataset['sources'].append(text)
           new_dataset['targets'].append(template.format(label))
           new_dataset['labels'].append(1.)
       for i in range(len(classes)-1):
           neg_class_ = label
           while neg_class_==label:
               # neg_class_ = random.sample(classes, k=1)[0]
               neg_lbl = np.random.choice(unique_labels, p=probs)
               neg_class_ = classes[neg_lbl]
           new_dataset['sources'].append(text)
           new_dataset['targets'].append(template.format(neg_class_))
           new_dataset['labels'].append(-1.)
   return Dataset.from_dict(new_dataset)

def compute_metrics(eval_pred):
   predictions, labels = eval_pred
   predictions = np.argmax(predictions, axis=1)
   return accuracy.compute(predictions=predictions, references=labels)
def tokenize_and_align_label(example):
   hypothesis = example['targets']
   seq = example["sources"]+hypothesis
   tokenized_input = tokenizer(seq, truncation=True, max_length=512, 
                                                    padding="max_length")
   label = example['labels']
   if label==1.0:
       label = torch.tensor(1)
   elif label==0.0:
       label = torch.tensor(2)
   else:
       label = torch.tensor(0)
   tokenized_input['label'] = label
   return tokenized_input


我们来处理训练数据集并运行训练:


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
dataset = transform_dataset(train_dataset, classes)
tokenized_dataset = dataset.map(tokenize_and_align_label)
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.1)

training_args = TrainingArguments(
   output_dir='comprehendo',
   learning_rate=3e-5,
   per_device_train_batch_size=8,
   per_device_eval_batch_size=8,
   num_train_epochs=3,
   weight_decay=0.01,
   evaluation_strategy="epoch",
)
trainer = Trainer(
   model=model,
   args=training_args,
   train_dataset=tokenized_dataset["train"],
   eval_dataset=tokenized_dataset['test'],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)
trainer.train()
trainer.save_model('comprehender')


要使用我们的模型进行推理,我们可以利用Hugging Face管道进行零样本分类。


from transformers import pipeline
from sklearn.metrics import classification_report
from tqdm import tqdm
classifier = pipeline("zero-shot-classification",
                     model='comprehendo',tokenizer=tokenizer, device=device)


那么,让我们测试一下这个模型:


preds = []
label2idx = {label: id for id, label in enumerate(classes)}
for example in tqdm(test_dataset):
   pred = classifier(example['text'],classes)['labels'][0]
   idx = label2idx[pred]
   preds.append(idx)
print(classification_report(test_dataset['label'], preds, 
                                        target_names=classes, digits=4))


尽管我们的数据集中并不包括所有的标签,但我们取得了令人印象深刻的结果,并且在微平均F1得分方面的结果比我们在零样本设置中的模型高出了8%。


7


结论


因此,我们的方法显著优于SetFit;然而,重要的是要注意,SetFit的运行速度会根据模型大小和标签数量的不同而有所不同。我们的方法取决于标签的数量,因为它需要在文本和标签之间进行全面注意,因此我们需要运行模型N次,这个次数与标签的数量相等。所以,选择取决于性能、准确性以及你拥有的训练样本数量之间的平衡。

文章来源:https://medium.com/@knowledgrator/achieve-90-results-in-few-shot-text-classification-with-just-0-1-data-6bebdec1e08f
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消