该项目旨在使用自监督对比学习目标在非常大的句子级数据集上训练句子嵌入模型。我们使用预训练的 'MiniLM-L6-H384-uncased' 模型,并在10亿个句子对的数据集上进行了微调。我们使用对比学习目标:给定一对句子中的一句,模型应该预测在我们的数据集中与之配对的一组其他随机抽样句子中的哪一个。
我们在由Hugging Face组织的 Community week using JAX/Flax for NLP & CV 中开发了这个模型。我们开发这个模型是作为 Train the Best Sentence Embedding Model Ever with 1B Training Pairs 项目的一部分。我们从谷歌的Flax、JAX和Cloud团队成员那里得到了关于高效深度学习框架的指导,并且获得了有效的硬件基础设施来运行该项目,包括7个TPU v3-8。
我们的模型旨在用作句子编码器。给定一个输入句子,它会输出一个向量,其中包含句子的语义信息。句向量可以用于信息检索、聚类或句子相似性任务。
使用 SentenceTransformers 库来获取给定文本的特征的方法如下:
from sentence_transformers import SentenceTransformer model = SentenceTransformer('flax-sentence-embeddings/all_datasets_v4_MiniLM-L6') text = "Replace me by any text you'd like." text_embbedding = model.encode(text) # array([-0.01559514, 0.04046123, 0.1317083 , 0.00085931, 0.04585106, # -0.05607086, 0.0138078 , 0.03569756, 0.01420381, 0.04266302 ...], # dtype=float32)
我们使用预训练的 'MiniLM-L6-H384-uncased' 模型,它是 'microsoft/MiniLM-L12-H384-uncased' 的6层版本,只保留了每第二层。有关预训练过程的更多详细信息,请参阅模型卡。
我们使用对比目标对模型进行微调。形式上,我们计算批次中每个可能的句子对的余弦相似度。然后,通过与真实配对进行比较,应用交叉熵损失。
我们在 TPU v3-8 上训练了我们的模型。我们使用批量大小为1024(每个TPU核心为128)进行了540k步的训练。我们使用了500的学习率预热。序列长度限制为128个标记。我们使用了AdamW优化器和2e-5的学习率。完整的训练脚本可在当前存储库中获得。
我们使用多个数据集的串联来微调我们的模型。句子对的总数超过10亿个句子。我们根据数据配置文件( data_config.json )中详细的配置,对每个数据集进行了加权概率抽样。
Dataset | Paper | Number of training tuples |
---|---|---|
1237321 | 1238321 | 3,012,496 |
1239321 | - | 364,001 |
12310321 | 12311321 | 317,695 |
[COCO 2020](COCO 2020) | 12312321 | 828,395 |
12313321 | - | 1,151,414 |
12314321 | - | 73,346 |
12315321 | 12316321 | 87,599 |
12317321 | 12318321 | 100,231 |
12319321 | 12320321 | 102,225 |
12321321 | - | 103,663 |
12322321 | 12323321 | 112,696 |
12324321 | 12325321 | 128,542 |
12326321 | 12327321 | 180,000 |
AllNLI ( 12328321 and 12329321 | 12330321 , 12331321 | 277,230 |
12332321 | 12333321 | 325,475 |
12334321 | 12335321 | 684,100 |
12336321 Title/Abstract | 12337321 | 41,769,185 |
12336321 Citation/Citation | 12337321 | 52,603,982 |
12336321 Citation/Abstract | 12337321 | 116,288,806 |
12342321 | 12343321 | 64,371,441 |
12344321 | 12345321 | 77,427,422 |
SearchQA | - | 582,261 |
12346321 Title/Answer | 12347321 | 1,198,260 |
12346321 Title/Question | 12347321 | 659,896 |
12346321 Question/Answer | 12347321 | 681,164 |
12352321 | 12353321 | 9,144,553 |
12354321 | 12355321 | 726,484,430 |
total | 1,097,953,922 |