英文

mBART fine-tuned model for Czech abstractive summarization (HT2A-CS)

这个模型是基于 facebook/mbart-large-cc25 在捷克新闻数据集上微调得到的,用于生成捷克语摘要的检查点。

任务

该模型处理的任务是“标题+文本生成摘要”(HT2A),即从捷克新闻文本中生成被视为摘要的多句话。

数据集

该模型是在一个大型捷克新闻数据集上进行训练的,该数据集由两个数据集拼接而成,一个是由捷克新闻中心提供的私有CNC数据集,另一个是 SumeCzech 数据集。该数据集包含大约175万个基于捷克新闻的文档,分为标题、摘要和全文三个部分。编码器的截断和填充设置为512个标记,解码器的截断和填充设置为128个标记。

训练

该模型在1x NVIDIA Tesla A100 40GB的GPU上进行了60小时的训练,以及4x NVIDIA Tesla A100 40GB的GPU上进行了40小时的训练。在训练过程中,该模型共看到了12896K个文档,相当于大约8.4个时期。

使用

假设您正在使用提供的Summarizer.ipynb文件。

def summ_config():
    cfg = OrderedDict([
        # summarization model - checkpoint from website
        ("model_name", "krotima1/mbart-ht2a-cs"),
        ("inference_cfg", OrderedDict([
            ("num_beams", 4),
            ("top_k", 40),
            ("top_p", 0.92),
            ("do_sample", True),
            ("temperature", 0.89),
            ("repetition_penalty", 1.2),
            ("no_repeat_ngram_size", None),
            ("early_stopping", True),
            ("max_length", 128),
            ("min_length", 10),
        ])),
        #texts to summarize
        ("text",
            [
                "Input your Czech text",
            ]
        ),
    ])
    return cfg
cfg = summ_config()
#load model    
model = AutoModelForSeq2SeqLM.from_pretrained(cfg["model_name"])
tokenizer = AutoTokenizer.from_pretrained(cfg["model_name"])
# init summarizer
summarize = Summarizer(model, tokenizer, cfg["inference_cfg"])
summarize(cfg["text"])