模型:
flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A
SentenceTransformers是一组模型和框架,能够对给定的数据进行训练和生成句子嵌入。生成的句子嵌入可以用于聚类、语义搜索和其他任务。我们使用了两个单独的预训练模型,并使用对比学习目标对它们进行了训练。我们使用了来自StackExchange和其他数据集的问题和答案对作为训练数据,以使模型对问题/答案嵌入相似性具有鲁棒性。
我们在Hugging Face组织的比赛中开发了该模型。我们将该模型作为项目的一部分开发: Train the Best Sentence Embedding Model Ever with 1B Training Pairs 。我们获得了高效的硬件基础设施来运行该项目:7个TPU v3-8,以及来自Google的Flax、JAX和Cloud团队成员关于高效深度学习框架的帮助。
这个模型集用作搜索引擎的句子编码器。给定一个输入句子,它输出一个捕捉句子语义信息的向量。这个句子向量可以用于语义搜索、聚类或句子相似性任务。
为了进行语义搜索,应同时使用两个模型。
以下是使用 SentenceTransformers 库得到给定文本特征的方法:
from sentence_transformers import SentenceTransformer model_Q = SentenceTransformer('flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q') model_A = SentenceTransformer('flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A') question = "Replace me by any question you'd like." question_embbedding = model_Q.encode(text) answer = "Replace me by any answer you'd like." answer_embbedding = model_A.encode(text) answer_likeliness = cosine_similarity(question_embedding, answer_embedding)
我们使用了预训练的 Mpnet-base 。有关预训练过程的详细信息,请参阅模型卡片。
我们使用对比目标对模型进行微调。形式上,我们从每个批次的可能句子对计算余弦相似度。然后,通过与真实句子对进行比较,应用交叉熵损失。
我们在TPU v3-8上训练了模型。使用批处理大小为1024(每个TPU内核为128),在训练期间进行了80k步训练。我们使用了500个学习率预热。序列长度被限制为128个令牌。我们使用了AdamW优化器和2e-5的学习率。完整的训练脚本可在当前存储库中访问。
我们使用多个Stackexchange问题-答案数据集的连接来微调我们的模型。还使用了MSMARCO、NQ和其他问题-答案数据集。
Dataset | Paper | Number of training tuples |
---|---|---|
1238321 | - | 4,750,619 |
1239321 | - | 364,001 |
12310321 | - | 73,346 |
12311321 | 12312321 | 87,599 |
12313321 | - | 103,663 |
12314321 | 12315321 | 325,475 |
12316321 | 12317321 | 64,371,441 |
12318321 | 12319321 | 77,427,422 |
12320321 | 12321321 | 9,144,553 |
12322321 | 12323321 | 3,012,496 |
12324321 Question/Answer | 12325321 | 681,164 |
SearchQA | - | 582,261 |
12326321 | 12327321 | 100,231 |