
Jerry Liu • 2023-09-06
为任何嵌入模型微调线性适配器
我们在 LlamaIndex 中增加了功能,允许您在由任何模型(sentence_transformers
、OpenAI 等)生成的嵌入之上微调线性适配器。
这允许您将嵌入表示转换为新的潜在空间,该空间针对您的特定数据和查询检索进行了优化。这可以带来检索性能的小幅提升,进而转化为性能更好的 RAG 系统。
一个不错的额外优势是:使用此适配器,您无需重新嵌入文档!只需转换查询即可。
我们提供了完整的端到端指南,展示如何生成合成数据集、微调线性适配器并评估其性能。
背景
微调嵌入模型的概念非常强大。事实上,我们受到启发,不仅添加了完整的示例仓库/博客文章,还在 LlamaIndex 中添加了原生抽象,展示了如何使用我们的SentenceTransformersFinetuneEngine
在任何非结构化文本语料库上微调 sentence_transformers 模型。
然而,这种方法有一些局限性
SentenceTransformersFinetuneEngine
仅限于微调sentence_transformers
模型。- 微调嵌入模型后,您需要重新嵌入文档语料库。
在上周五的微调 + RAG 网络研讨会上,Jo (Vespa) 提到了完全相同的问题:微调嵌入模型需要您重新索引文档。然而,他在 Vespa 的工作中探索了使用基础模型“冻结”文档嵌入的概念,并转而训练查询嵌入的转换。
这启发我们探索一种类似的嵌入微调方法,该方法同时更通用,并且允许我们冻结现有的文档嵌入。
方法
我们全新的 EmbeddingAdapterFinetuneEngine
在任何模型生成的查询嵌入之上微调一个 线性适配器。线性适配器 仅仅是一个线性变换,它专门转换查询嵌入,同时保持文档嵌入不变。
线性适配器可以用于任何现有的嵌入模型之上:SBERT 嵌入、OpenAI 嵌入、Cohere 嵌入等。因此,您可以直接将其添加到您已经在使用的任何嵌入模型之上!

由于文档嵌入未改变,这意味着您始终可以在为文档生成嵌入后微调此线性适配器。您可以选择随意地根据变化的数据分布重新训练此适配器,而无需重新嵌入所有文档。
技术细节
如上所述,线性适配器仅在查询嵌入之上执行线性变换,同时保持文档嵌入不变(使用权重矩阵 W + 偏置项 b)。

就是这样!如果文档嵌入可以表示为 (n x d) 矩阵 D,其中 n 是文档数量,d 是嵌入维度,那么嵌入相似度仅由以下方式衡量:

线性适配器使用类似于 sentence_transformers
中的 MultipleNegativesRankingLoss
函数的损失项进行训练 — 给定一批正向(问题、上下文)示例,该函数在底层使用交叉熵损失来惩罚真实(问题、上下文)对距离太远,并惩罚交换对距离太近。
补充说明: 我们最终使用纯 PyTorch 编写了大部分微调逻辑,但大量借鉴了 sentence_transformers
源代码。我们无法直接使用 sentence_transformers,因为我们接收的是嵌入作为输入,而不是原始文本。您可以在此处查看我们的一些训练代码。
Notebook 演练
在这个 notebook 演练中,我们遵循与我们之前关于嵌入微调的博客文章相似的步骤集:
- 生成用于训练和评估的合成问答-上下文数据集。
- 在现有模型(例如 SBERT)之上微调我们的线性适配器。
- 获取嵌入模型并对其进行评估。
与之前的文章一样,我们使用 UBER 和 LYFT 10K 作为示例数据。我们使用 Lyft 生成训练数据集,使用 Uber 生成评估数据集。
生成用于训练和评估的合成数据集
我们使用辅助抽象 generate_qa_embedding_pairs
来生成训练和评估数据集。此函数接受任意一组文本节点(块),并生成包含(问题,上下文)对的结构化数据集。
from llama_index.finetuning import (
generate_qa_embedding_pairs,
EmbeddingQAFinetuneDataset,
)
# generate
train_dataset = generate_qa_embedding_pairs(train_nodes)
val_dataset = generate_qa_embedding_pairs(val_nodes)
# save
train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")
# load
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")
微调我们的线性适配器
然后,我们在现有嵌入模型之上微调线性适配器。我们导入新的 EmbeddingAdapterFinetuneEngine
抽象,它接受现有嵌入模型和一组训练参数。
在此示例中,我们使用 bge-small-en
sentence-transformers 模型,但我们也可以使用 LlamaIndex/LangChain 中的任何嵌入模型。
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.embeddings import resolve_embed_model
import torch
base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")
# alternative: use OpenAI
# from llama_index.embeddings import OpenAIEmbedding
# openai = OpenAIEmbedding()
finetune_engine = EmbeddingAdapterFinetuneEngine(
train_dataset,
base_embed_model,
model_output_path="<model_output_path>",
epochs=4,
verbose=True,
# can optionally pass along any parameters that go into `train_model`
# optimizer_class=torch.optim.SGD,
# optimizer_params={"lr": 0.01}
)
然后我们可以调用 fine-tune 来启动微调任务。训练线性模型相当简单,不需要昂贵的设备 — 这可以在 Macbook 上轻松运行。
finetune_engine.finetune()
获取嵌入模型并对其进行评估
微调任务完成后,我们可以获取我们的嵌入模型。
我们可以直接从 finetune_engine
获取它,或者导入新的 LinearAdapterEmbeddingModel
并以更手动的方式构建它。
选项 1
embed_model = finetune_engine.get_finetuned_model()
选项 2
from llama_index.embeddings import LinearAdapterEmbeddingModel
embed_model = LinearAdapterEmbeddingModel(base_embed_model, "<model_output_path>")
下一步是评估它。我们将微调模型与基础模型以及 text-embedding-ada-002
进行比较。
我们使用两个排名指标进行评估
- 命中率指标: 对于每个(查询,上下文)对,我们使用查询检索排名前 k 的文档。如果结果包含真实上下文,则视为命中。
- 平均倒数排名 (MRR):这是一个稍微更细粒度的排名指标,它考察真实上下文在检索到的排名前 k 集合中的“倒数排名”。倒数排名定义为 1/排名。当然,如果结果不包含上下文,则倒数排名为 0。
一些补充说明
- 我们在 Lyft 文档上运行了 4 个 epochs
- 我们使用 Adam 作为优化器,学习率采用默认值(我们尝试了 SGD,效果不太好)
结果

就命中率而言,基础模型在验证数据集上达到 78.7%,微调模型达到 79.8%。与此同时,text-embedding-ada-002
达到 87.0%。
就 MRR 而言,基础模型达到 64.3%,微调模型达到 66%。text-embedding-ada-002
达到 68.4%。
微调模型带来了一些性能提升,尽管承认幅度很小 — 比直接在最新数据集上微调 sentence_transformers 所获得的性能提升要小。
尽管如此,性能提升仍然是性能提升,而且启动和亲自尝试的成本非常低廉!因此您可以决定这是否对您有意义。
结论
我们在 LlamaIndex 中创建了一个全新的模块,允许您在任何嵌入模型之上微调线性适配器。
它可以帮助您在检索指标上获得一些微小的改进;重要的是,它允许您保持文档嵌入不变,只转换查询。
资源
训练代码(如果您想自己查看):https://github.com/jerryjliu/llama_index/blob/main/llama_index/finetuning/embeddings/adapter_utils.py