模型:

flax-sentence-embeddings/stackoverflow_mpnet-base

英文

stackoverflow_mpnet-base

这是一个在StackOverflow上使用18,562,443个标题和正文对进行训练的microsoft/mpnet-base模型。

SentenceTransformers是一组模型和框架,可以从给定数据中训练和生成句子嵌入。生成的句子嵌入可以用于聚类、语义搜索和其他任务。我们使用了一个预训练的 microsoft/mpnet-base 模型,并使用Siamese网络结构和对比学习目标进行训练。训练数据使用了来自StackOverflow的18,562,443个标题和正文对。对于这个模型,使用了隐藏状态的均值池化作为句子嵌入。请参阅该存储库中的data_config.json和train_script.py以了解模型的训练方式和使用的数据集。

我们在 Hugging Face 组织的 Community week using JAX/Flax for NLP & CV 中开发了这个模型。我们将这个模型作为项目 Train the Best Sentence Embedding Model Ever with 1B Training Pairs 的一部分进行开发。我们受益于高效的硬件基础设施来运行项目:7个 TPU v3-8,并从Google的Flax、JAX和Cloud团队成员那里得到了有关高效深度学习框架的帮助。

感兴趣的用途

我们的模型旨在用作搜索引擎的句子编码器。给定一个输入句子,它输出一个向量,捕捉句子的语义信息。该句向量可用于语义搜索、聚类或句子相似度任务。

如何使用

这是如何使用此模型在 SentenceTransformers 库中获取给定文本的特征的方法:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('flax-sentence-embeddings/stackoverflow_mpnet-base')
text = "Replace me by any question / answer you'd like."
text_embbedding = model.encode(text)
# array([-0.01559514,  0.04046123,  0.1317083 ,  0.00085931,  0.04585106,
#        -0.05607086,  0.0138078 ,  0.03569756,  0.01420381,  0.04266302 ...],
#        dtype=float32)

训练过程

预训练

我们使用了预训练的 microsoft/mpnet-base 。有关预训练过程的更详细信息,请参阅模型卡。

微调

我们使用对比目标对模型进行微调。具体而言,我们计算批次中每对可能的句子的余弦相似度。然后,通过与真实配对进行比较,应用交叉熵损失。

超参数

我们在 TPU v3-8 上对模型进行了训练。我们在80k步骤中训练模型,使用批量大小为1024(每个TPU核心128个)。我们使用了500次学习率预热。序列长度限制为128个标记。我们使用了AdamW优化器和2e-5的学习率。完整的训练脚本可以在当前存储库中获取。

训练数据

我们使用了来自StackOverflow的18,562,443个标题和正文对作为训练数据。

Dataset Paper Number of training tuples
StackOverflow title body pairs - 18,562,443