模型:
flax-sentence-embeddings/multi-qa_v1-distilbert-mean_cos
任务:
句子相似度SentenceTransformers 是一组模型和框架,可以从给定的数据中训练和生成句子嵌入。生成的句子嵌入可以用于聚类、语义搜索和其他任务。我们使用预训练的 distilbert-base-uncased 模型,并使用连体网络设置和对比学习目标进行训练。我们使用来自 StackExchange 的问题和答案对作为训练数据,以使模型对问题/答案嵌入相似性更加稳健。对于该模型,使用隐藏状态的均值池化作为句子嵌入方式。
我们在由Hugging Face组织的 Community week using JAX/Flax for NLP & CV 中开发了这个模型。我们开发这个模型是作为项目 Train the Best Sentence Embedding Model Ever with 1B Training Pairs 的一部分。我们从Google的Flax、JAX和Cloud团队成员那里获得了关于高效深度学习框架的硬件基础设施和帮助,包括7个TPU v3-8。
我们的模型旨在用作搜索引擎的句子编码器。给定一个输入句子,它输出一个捕捉句子语义信息的向量。句向量可以用于语义搜索、聚类或句子相似性任务。
以下是如何使用该模型来获取给定文本的特征的步骤,使用 SentenceTransformers 库:
from sentence_transformers import SentenceTransformer model = SentenceTransformer('flax-sentence-embeddings/multi-qa_v1-distilbert-mean_cos') text = "Replace me by any question / answer you'd like." text_embbedding = model.encode(text) # array([-0.01559514, 0.04046123, 0.1317083 , 0.00085931, 0.04585106, # -0.05607086, 0.0138078 , 0.03569756, 0.01420381, 0.04266302 ...], # dtype=float32)
我们使用了预训练的 distilbert-base-uncased 模型。有关预训练过程的更详细信息,请参阅模型卡片。
我们使用对比目标对模型进行微调。具体而言,我们计算批次中每对可能的句子的余弦相似性。然后,通过与真实对进行比较应用交叉熵损失。
我们在TPU v3-8上进行了模型训练。我们使用批次大小为1024(每个TPU核心为128)进行了80k次步骤的训练。我们使用了500个学习率预热。序列长度限制为128个标记。我们使用了学习率为2e-5的AdamW优化器。完整的训练脚本可以在当前存储库中找到。
我们使用了多个Stackexchange问题-答案数据集的拼接来微调我们的模型。还使用了MSMARCO、NQ和其他问答数据集。
Dataset | Paper | Number of training tuples |
---|---|---|
1236321 | - | 4,750,619 |
1237321 | - | 364,001 |
1238321 | - | 73,346 |
1239321 | 12310321 | 87,599 |
12311321 | - | 103,663 |
12312321 | 12313321 | 325,475 |
12314321 | 12315321 | 64,371,441 |
12316321 | 12317321 | 77,427,422 |
12318321 | 12319321 | 9,144,553 |
12320321 | 12321321 | 3,012,496 |
12322321 Question/Answer | 12323321 | 681,164 |
SearchQA | - | 582,261 |
12324321 | 12325321 | 100,231 |