英文

T5中文对联生成模型(t5-chinese-couplet)

T5中文对联生成模型

t5-chinese-couplet评估对联测试数据:

T5在对联测试中的总体表现:

prefix input_text target_text pred
对联: 春回大地,对对黄莺鸣暖树 日照神州,群群紫燕衔新泥 福至人间,家家紫燕舞和风

在对联测试集上,T5生成的结果满足字数相同、词性对齐、词面对齐和形似要求,但语义对仗工整和平仄合律方面还不够满足。

T5的网络结构(原生T5):

使用方法

本项目在文本生成项目中开源: textgen ,支持T5模型的调用如下:

安装包:

pip install -U textgen
from textgen import T5Model
model = T5Model("t5", "shibing624/t5-chinese-couplet")
r = model.predict(["对联:丹枫江冷人初去"])
print(r) # ['白石矶寒客不归']

使用方法(HuggingFace Transformers)

如果没有 textgen ,可以按照以下方式使用模型:

首先,将输入传递给转换器模型,然后获得生成的句子。

安装包:

pip install transformers 
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("shibing624/t5-chinese-couplet")
model = T5ForConditionalGeneration.from_pretrained("shibing624/t5-chinese-couplet")


def batch_generate(input_texts, max_length=64):
    features = tokenizer(input_texts, return_tensors='pt')
    outputs = model.generate(input_ids=features['input_ids'],
                             attention_mask=features['attention_mask'],
                             max_length=max_length)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


r = batch_generate(["对联:丹枫江冷人初去"])
print(r)

输出:

['白石矶寒客不归']

模型文件组成:

t5-chinese-couplet
    ├── config.json
    ├── model_args.json
    ├── pytorch_model.bin
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    ├── spiece.model
    └── vocab.txt

训练数据集

中文对联数据集

数据格式:

head -n 1 couplet_files/couplet/train/in.txt
晚 风 摇 树 树 还 挺 

head -n 1 couplet_files/couplet/train/out.txt
晨 露 润 花 花 更 红 

如果需要训练T5模型,请参考 https://github.com/shibing624/textgen/blob/main/docs/%E5%AF%B9%E8%81%94%E7%94%9F%E6%88%90%E6%A8%A1%E5%9E%8B%E5%AF%B9%E6%AF%94.md

引用

@software{textgen,
  author = {Xu Ming},
  title = {textgen: Implementation of Text Generation models},
  year = {2022},
  url = {https://github.com/shibing624/textgen},
}