模型:
krotima1/mbart-ht2a-cs
这个模型是基于 facebook/mbart-large-cc25 在捷克新闻数据集上微调得到的,用于生成捷克语摘要的检查点。
该模型处理的任务是“标题+文本生成摘要”(HT2A),即从捷克新闻文本中生成被视为摘要的多句话。
该模型是在一个大型捷克新闻数据集上进行训练的,该数据集由两个数据集拼接而成,一个是由捷克新闻中心提供的私有CNC数据集,另一个是 SumeCzech 数据集。该数据集包含大约175万个基于捷克新闻的文档,分为标题、摘要和全文三个部分。编码器的截断和填充设置为512个标记,解码器的截断和填充设置为128个标记。
该模型在1x NVIDIA Tesla A100 40GB的GPU上进行了60小时的训练,以及4x NVIDIA Tesla A100 40GB的GPU上进行了40小时的训练。在训练过程中,该模型共看到了12896K个文档,相当于大约8.4个时期。
假设您正在使用提供的Summarizer.ipynb文件。
def summ_config(): cfg = OrderedDict([ # summarization model - checkpoint from website ("model_name", "krotima1/mbart-ht2a-cs"), ("inference_cfg", OrderedDict([ ("num_beams", 4), ("top_k", 40), ("top_p", 0.92), ("do_sample", True), ("temperature", 0.89), ("repetition_penalty", 1.2), ("no_repeat_ngram_size", None), ("early_stopping", True), ("max_length", 128), ("min_length", 10), ])), #texts to summarize ("text", [ "Input your Czech text", ] ), ]) return cfg cfg = summ_config() #load model model = AutoModelForSeq2SeqLM.from_pretrained(cfg["model_name"]) tokenizer = AutoTokenizer.from_pretrained(cfg["model_name"]) # init summarizer summarize = Summarizer(model, tokenizer, cfg["inference_cfg"]) summarize(cfg["text"])