宣布 LlamaCloud 全面可用(以及我们的 1900 万美元 A 轮融资)!
LlamaIndex

Jerry Liu 2023-08-17

轻松为您的 Text-to-SQL 应用微调 Llama 2

Llama 2 是开源大型语言模型发展的一个巨大里程碑。最大模型及其微调变体在 Hugging Face 开源 LLM 排行榜上名列前茅。多项基准测试表明,其性能正在接近 GPT-3.5(甚至在某些情况下超越)。所有这一切都意味着开源大型语言模型在从 RAG 系统到智能体的复杂大型语言模型应用中,正成为越来越可行和可靠的选择。

背景:Llama-2–7B 不擅长 Text-to-SQL

然而,最小的 Llama 2 模型(70亿参数)的一个缺点是它不擅长生成 SQL,这使得它不适用于结构化分析用例。例如,我们尝试使用以下提示模板提示 Llama 2 生成正确的 SQL 语句

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. 

You must output the SQL query that answers the question.

### Input:
{input}

### Context:
{context}

### Response:

这里我们使用了来自 sql-create-context 数据集的一个示例条目。

input: In 1981 which team picked overall 148?
context: CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)

同时,以下是生成的输出与正确输出的对比

Generated output: SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;

Correct output: SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"

这显然不理想。与 ChatGPT 和 GPT-4 不同,Llama 2 不能稳定地生成格式良好且正确的 SQL 输出。

这正是微调发挥作用的地方——给定一个合适的 Text-to-SQL 数据语料库,我们可以训练 Llama 2 更好地从自然语言生成 SQL 输出。从高层次来看,微调涉及在某种程度上修改模型的权重。有不同的微调模型的方法,从更新网络的所有参数,到更新参数的子集,再到只微调额外参数(例如 LoRA 的工作原理)。

模型微调完成后,仍然可以将其集成到下游的大型语言模型应用中。这正是本教程旨在展示的内容。与我们现有的教程相比,它涉及更多步骤,我们现有的教程主要关注“情境学习”和“检索增强”用例——冻结模型本身,但专注于将数据编排到输入提示中。微调的学习曲线可能很高,并且需要大量计算资源。本教程旨在尽可能让入门变得简单。

教程概述

在本教程中,我们将向您展示如何在一个 Text-to-SQL 数据集上微调 Llama 2,然后利用 LlamaIndex 的功能将其用于针对任何 SQL 数据库的结构化分析。

以下是我们使用的技术栈

特别感谢 Anyscale 出色的 Llama 2 教程,它为本项目提供了灵感

所有材料都可以在我们的 Github 仓库中找到:https://github.com/run-llama/modal_finetune_sql(再次强调,这是改编自 doppel-bot)。此外,完整教程可以在我们的 Jupyter Notebook 指南中找到。务必查看!

如上所述,进行微调确实需要不少步骤。我们的目标是使其尽可能简单易懂,并开箱即用。我们不会涵盖 Modal、PEFT、微调过程本身的细枝末节等等,但我们会提供一个粗略的概述。

当然也有更高级的 API 我们可以用来完成这项任务(例如 OpenAI、Lamini)。未来可以有更多后续教程来涵盖这些主题!

步骤 1:加载用于微调 LLaMa 的训练数据

第一步是打开 Jupyter Notebook。该 Notebook 组织成一系列可运行的脚本,每个脚本都执行加载数据所需的步骤。

我们的代码在编排的每个步骤都使用了 Modal,而 Modal 最适合直接在 Python 脚本之上使用。这就是为什么这些单元格中很多本身不包含 Python 代码块。

首先,我们使用 Modal 加载 b-mc2/sql-create-context 数据集。这是一个简单的任务,只需加载数据集并将其格式化为 .jsonl 文件。

modal run src.load_data_sql --data-dir "data_sql"

如我们所见,底层任务相当直接

# Modal stubs allow our function to run remotely
@stub.function(
    retries=Retries(
        max_retries=3,
        initial_delay=5.0,
        backoff_coefficient=2.0,
    ),
    timeout=60 * 60 * 2,
    network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
    cloud="gcp",
)
def load_data_sql(data_dir: str = "data_sql"):
    from datasets import load_dataset

    dataset = load_dataset("b-mc2/sql-create-context")

    dataset_splits = {"train": dataset["train"]}
    out_path = get_data_path(data_dir)

    out_path.parent.mkdir(parents=True, exist_ok=True)

    for key, ds in dataset_splits.items():
        with open(out_path, "w") as f:
            for item in ds:
                newitem = {
                    "input": item["question"],
                    "context": item["context"],
                    "output": item["answer"],
                }
                f.write(json.dumps(newitem) + "\n")

步骤 2:运行微调脚本

下一步是对解析后的数据集运行我们的微调脚本。

modal run src.finetune_sql --data-dir "data_sql" --model-dir "model_sql"

微调脚本执行以下步骤。

将数据集拆分为训练集和验证集

train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

将每个拆分格式化为 (输入提示, 标签) 元组:输入查询和上下文被格式化为相同的输入提示。然后对输入提示进行分词,并将标签设置为与输入提示完全相同——这允许模型训练下一个词元预测。

def generate_and_tokenize_prompt(data_point):
  full_prompt = generate_prompt_sql(
      data_point["input"],
      data_point["context"],
      data_point["output"],
  )
  tokenized_full_prompt = tokenize(full_prompt)
  if not train_on_inputs:
      raise NotImplementedError("not implemented yet")
  return tokenized_full_prompt

输入提示与本博客顶部给出的完全相同。

运行微调脚本时,模型保存在由 model_dir 指定的远程云目录中(如果未指定,则使用默认值)。

步骤 3:评估

模型已微调完成,可以从云端提供服务。我们可以使用来自 sql-create-context 的样本数据进行一些基本评估,比较微调模型与基准 Llama 2 模型的性能。

modal run src.eval_sql::main

结果表明微调模型有了巨大改进

Input 1: {'input': 'Which region (year) has Abigail at number 7, Sophia at number 1 and Aaliyah at number 5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" AND
no_5 = "aaliyah"'}
Output 1 (finetuned model): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "aaliyah" AND no_5 = "sophia"
Output 1 (base model): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';


Input 2: {'input': 'Name the result/games for 54741', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}
Output 2 (finetuned model): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"
Output 2 (base model): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;

基础模型生成格式错误的输出或不正确的 SQL 语句,

而微调模型能够生成更接近预期输出的结果。

步骤 4:将微调模型与 LlamaIndex 集成

现在我们可以在 LlamaIndex 中使用此模型对任何数据库执行 Text-to-SQL。

我们首先定义一个测试 SQL 数据库,然后可以使用它来测试模型的推理能力。

我们创建一个模拟的 city_stats 表,其中包含城市名称、人口和国家信息,并用一些示例城市填充它。

db_file = "cities.db"
engine = create_engine(f"sqlite:///{db_file}")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)

这存储在一个 cities.db 文件中。

然后我们可以使用 Modal 将微调模型和此数据库文件加载到 LlamaIndex 的 NLSQLTableQueryEngine 中——这个查询引擎允许用户轻松地对给定数据库执行 Text-to-SQL。

modal run src.inference_sql_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db" --model-dir "model_sql" --use-finetuned-model True

我们得到了如下响应

SQL Query: SELECT MAX(population) FROM city_stats WHERE country = "United States"
Response: [(2679000,)]

结论

基本就是这样!本教程提供了一个非常高层次的方法,帮助您开始微调 Llama 2 模型以生成 SQL 语句,并端到端展示如何通过 LlamaIndex 将其集成到您的 Text-to-SQL 工作流程中。

资源

为了完整起见,我们再次在此链接所有资源。

教程仓库:https://github.com/run-llama/modal_finetune_sql(改编自 doppel-bot)。

Jupyter Notebook 指南.

技术栈

特别鸣谢:来自 Anyscale 的 Llama 2 教程