Improving Accuracy of SQL Agents (Text-to-SQL) with Memory Tuning¶

The goal of this notebook is to demonstrate how to improve the accuracy of large language model (LLM) outputs using a new fine-tuning method called memory tuning. Specifically, I focus on the task of converting a user's natural language queries into valid SQL statements, but memory tuning works for any generative task. The process works by training an open-source model like Llama 3.1 with precise facts to improve its factual accuracy (from ~30% to 95%) and reduce hallucinations by 10x, while preserving the model's ability to generalize.

Open In Colab

Why should you read this notebook?¶

You want to:

  • Give non-technical people on your team the ability to ask questions of your SQL database without the need for a developer (AKA "data democratization").
  • Learn how to create SQL agents (i.e. convert natural language questions into SQL queries).
  • Learn how to create evaluation pipelines for checking and improving the accuracy of your SQL generations.
  • Learn how to generate synthetic data and fine-tune an open-source LLM to improve the accuracy of your model.
  • Learn about a new method of fine-tuning called memory tuning.

Table of Contents¶

  • Introduction
  • Create an SQL Agent
  • Create Evaluations
  • Generate Synthetic Data for Fine-tuning
  • Fine-tune Llama 3.1 with Lamini

Summary of Results¶

  • Baseline Accuracy: Initial manual evaluation of SQL query generation yielded an accuracy of ~30%.

  • Prompt Engineering: Techniques like self-reflection and chain-of-thought reasoning improved model performance to 55%.

  • Evaluation Pipeline: Development of reusable evaluations wrapped in Python classes, with instruction fine-tuning increasing accuracy to ~73%.

  • Memory Tuning: Final implementation of memory tuning techniques achieved ~95% accuracy after 3 iterations of fine-tuning.

Attribution¶

This notebook is based on the DeepLearning.ai course Improving Accuracy of LLM Applications, taught by Sharon Zhou, CEO at Lamini, and Amit Sangani, Sr. Director of Partner Engineering at Meta.

Common LLM Failure Modes¶

  • Hallucination - When an LLM generates plausible, but false, information.
  • Prompt repetition - Repeating all or part of the prompt, or restating it in different words.
  • Degeneration - Especially for long contexts or question-answer sessions, LLM responses can degrade in quality as the model stops focusing on the most relevant parts of the conversation.
  • Catastrophic forgetting - When training or fine-tuning an LLM, there is a risk of overwriting knowledge that it previously acquired, thus permanently "forgetting" something it previously knew.
  • Prompt injection/jailbreaking - Tricking the LLM into responding in a way it was not intended to, such as convincing it to generate explicit content.
  • Lack of common sense - Despite their many strengths, sometimes LLMs fail to understand simple concepts. For example, a classic failure in common sense is when an LLM thinks 1 kg of feathers weighs less than 1 kg of bricks.

Methods of Improving LLM Accuracy¶

  1. Prompt engineering - With this method, we simply include in the system prompt specific instructions for the LLM, e.g. "Only respond with SQL in your answer."
  2. Self-reflection - A variation of prompt engineering, self-reflection instructs the LLM to critique its past responses and identify any errors or areas of improvement.
  3. Retrieval augmented generation (RAG) - Performing semantic search across a vector database of embeddings can supplement model accuracy by grounding the model with an external knowledge base.
  4. Instruction fine-tuning - A supervised fine-tuning method where we give an LLM a set of prompts and corresponding outputs (see the example record after this list). It's common to instruction fine-tune pretrained models to enforce a Question-Answer format, but this method can also be used for enforcing function-calling or other objectives.
  5. Memory tuning - This method is the focus of this notebook and is discussed in further detail below.
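
For example, a single instruction fine-tuning record for our text-to-SQL task is just a prompt/completion pair. This is a hypothetical record: it uses the make_llama_3_prompt helper defined later in this notebook and mirrors the {"input": ..., "output": ...} format that load_training_data() builds in the Utils section.

# Hypothetical instruction fine-tuning record for text-to-SQL
example_record = {
    "input": make_llama_3_prompt(
        user="Who is the highest paid NBA player?",
        system="Write a sqlite query to answer the following question.",
    ),
    # Target completion: the correct SQL plus the end-of-turn token
    "output": "SELECT salary, name FROM nba_roster WHERE salary != '--' "
              "ORDER BY CAST(REPLACE(REPLACE(salary, '$', ''), ',', '') AS INTEGER) "
              "DESC LIMIT 1;<|eot_id|>",
}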

Memory Tuning¶

Memory tuning is a new method of embedding facts into LLMs to improve factual accuracy and reduce hallucinations. It addresses the following challenge: how can you enforce deterministic outputs (i.e. memorized facts) in a model while still preserving its ability to generalize well?

Memory tuning works in the following way:

  1. A collection of LoRA adapters (see my QLoRA fine-tuning notebook) is trained with precise facts until the loss is zero, i.e. the adapters are intentionally overfit.

NOTE: Overfitting is the term used to describe a model that has been overtrained to the point that instead of learning a concept, it has simply memorized the data it has seen—like memorizing the answers to a practice test instead of actually learning the material in a textbook. The problem with overfitting to the training data is that the model will not generalize well to unseen samples, in the same way that you'll probably fail the real test because you just memorized the answers to the practice test questions.

  2. After the adapters are tuned, the data used for tuning is embedded, indexed, and stored in a vector database.

  3. When a user queries the LLM-powered application, the app searches the vector database with the user's prompt, then loads the adapter that corresponds to the index of the retrieved embedding.
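
Conceptually, the inference-time routing looks something like the sketch below. This is a simplified, hypothetical illustration: embed, vector_db, and load_adapter are stand-ins, not real Lamini APIs.

# Hypothetical sketch of memory-tuning routing at inference time
def answer(user_prompt, base_model, vector_db):
    # 1. Embed the user's prompt and search the index built from the tuning data
    match = vector_db.search(embed(user_prompt), top_k=1)

    # 2. Load the LoRA adapter that was overfit on the matching facts
    model = base_model.with_adapter(load_adapter(match.adapter_index))

    # 3. Generate: the adapter supplies the memorized facts, the base model
    #    supplies the general language ability
    return model.generate(user_prompt)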

Set up¶

In [ ]:
!pip install lamini==3.0.5
In [3]:
# If using Google Colab
import os
from google.colab import userdata
os.environ["LAMINI_API_KEY"] = userdata.get('LAMINI_API_KEY')

Introduction¶

We can use Llama 3.1 to generate SQL.

In [58]:
import lamini
In [59]:
llm = lamini.Lamini(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
In [42]:
# Format prompt util
def make_llama_3_prompt(user, system=""):
    system_prompt = ""
    if system != "":
        system_prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        )
    return f"<|begin_of_text|>{system_prompt}<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
In [43]:
question = (
    "Given an arbitrary table named `sql_table`, "
    "write a query to return how many rows are in the table."
    )
prompt = make_llama_3_prompt(question)
print(llm.generate(prompt, max_new_tokens=200))
The query to return the number of rows in a table named `sql_table` is:

```sql
SELECT COUNT(*) 
FROM sql_table;
```

This query uses the `COUNT(*)` function to count the number of rows in the table. The `*` is a wildcard that means "all columns", but in this case, we don't need to specify any specific columns because we're only interested in counting the number of rows.
In [62]:
question = """Given an arbitrary table named `sql_table`,
help me calculate the average `height` where `age` is above 20."""
prompt = make_llama_3_prompt(question)
print(llm.generate(prompt, max_new_tokens=200))
Assuming you are using SQL, you can use the following query to calculate the average `height` where `age` is above 20:

```sql
SELECT AVG(height) 
FROM sql_table 
WHERE age > 20;
```

This query will return the average `height` for all rows in `sql_table` where the `age` is greater than 20.
In [63]:
question = """Given an arbitrary table named `sql_table`,
Can you calculate the p95 `height` where the `age` is above 20?"""
prompt = make_llama_3_prompt(question)
print(llm.generate(prompt, max_new_tokens=200))
Assuming you are using a SQL database, you can use the following query to calculate the 95th percentile of the `height` column where the `age` is above 20:

```sql
SELECT PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY height) 
FROM sql_table 
WHERE age > 20;
```

This query uses the `PERCENTILE_CONT` function to calculate the 95th percentile of the `height` column. The `WITHIN GROUP (ORDER BY height)` clause specifies that the percentile should be calculated within the group of rows ordered by the `height` column. The `WHERE age > 20` clause filters the rows to only include those where the `age` is above 20.

Note that the exact syntax may vary depending on the specific database management system you are using. For example, in MySQL, you would use `PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY height) OVER
In [64]:
question = ("Given an arbitrary table named `sql_table`, "
            "Can you calculate the p95 `height` "
            "where the `age` is above 20? Use sqlite.")
prompt = make_llama_3_prompt(question)

print(llm.generate(prompt, max_new_tokens=200))
You can use the following SQL query to calculate the 95th percentile of the `height` column where the `age` is above 20:
```
SELECT PERCENTILE(height) WITHIN GROUP (ORDER BY height) AS p95_height
FROM sql_table
WHERE age > 20;
```
This query uses the `PERCENTILE` function to calculate the 95th percentile of the `height` column, and the `WITHIN GROUP` clause to specify that the percentile should be calculated within the group of rows where `age` is greater than 20.

Note that the `PERCENTILE` function is only available in SQLite 3.25 and later versions. If you are using an earlier version of SQLite, you can use the `NTILE` function instead:
```
SELECT NTILE(100, height) AS p95_height
FROM (
  SELECT height
  FROM sql_table
  WHERE age > 20
) AS subquery
ORDER

Create a database of NBA players¶

In [53]:
import sqlite3

# Define sample data
players = [
    {
        "Team": "Toronto Raptors",
        "NAME": "Otto Porter Jr.",
        "Jersey": "0",
        "POS": "PF",
        "AGE": 29,
        "HT": "6' 8\"",
        "WT": "198 lbs",
        "COLLEGE": "Georgetown",
        "SALARY": "$6,000,000"
    },
    {
        "Team": "Golden State Warriors",
        "NAME": "Stephen Curry",
        "Jersey": "30",
        "POS": "PG",
        "AGE": 36,
        "HT": "6' 2\"",
        "WT": "185 lbs",
        "COLLEGE": "Davidson",
        "SALARY": "$53,838,416"
    },
    {
        "Team": "Los Angeles Lakers",
        "NAME": "LeBron James",
        "Jersey": "6",
        "POS": "SF",
        "AGE": 39,
        "HT": "6' 9\"",
        "WT": "250 lbs",
        "COLLEGE": "--",
        "SALARY": "$44,474,988"
    },
    {
        "Team": "Milwaukee Bucks",
        "NAME": "Giannis Antetokounmpo",
        "Jersey": "34",
        "POS": "PF",
        "AGE": 29,
        "HT": "6' 11\"",
        "WT": "242 lbs",
        "COLLEGE": "--",
        "SALARY": "$45,640,084"
    },
    {
        "Team": "Toronto Raptors",
        "NAME": "Scottie Barnes",
        "Jersey": "4",
        "POS": "SF",
        "AGE": 22,
        "HT": "6' 7\"",
        "WT": "225 lbs",
        "COLLEGE": "Florida State",
        "SALARY": "$7,644,600"
    },
    {
        "Team": "Brooklyn Nets",
        "NAME": "Kevin Durant",
        "Jersey": "7",
        "POS": "SF",
        "AGE": 35,
        "HT": "6' 10\"",
        "WT": "240 lbs",
        "COLLEGE": "Texas",
        "SALARY": "$48,554,830"
    },
    {
        "Team": "Dallas Mavericks",
        "NAME": "Luka Dončić",
        "Jersey": "77",
        "POS": "PG",
        "AGE": 25,
        "HT": "6' 7\"",
        "WT": "230 lbs",
        "COLLEGE": "--",
        "SALARY": "$42,492,492"
    },
    {
        "Team": "Denver Nuggets",
        "NAME": "Nikola Jokić",
        "Jersey": "15",
        "POS": "C",
        "AGE": 29,
        "HT": "6' 11\"",
        "WT": "284 lbs",
        "COLLEGE": "--",
        "SALARY": "$47,607,350"
    },
    {
        "Team": "Boston Celtics",
        "NAME": "Jayson Tatum",
        "Jersey": "0",
        "POS": "SF",
        "AGE": 26,
        "HT": "6' 8\"",
        "WT": "210 lbs",
        "COLLEGE": "Duke",
        "SALARY": "$32,600,060"
    },
    {
        "Team": "Philadelphia 76ers",
        "NAME": "Joel Embiid",
        "Jersey": "21",
        "POS": "C",
        "AGE": 30,
        "HT": "7' 0\"",
        "WT": "280 lbs",
        "COLLEGE": "Kansas",
        "SALARY": "$47,607,350"
    }
]

# Connect to a SQLite database (or create it)
conn = sqlite3.connect('nba_roster.db')

# Create a cursor object
cursor = conn.cursor()

# Create the players table
cursor.execute('''CREATE TABLE IF NOT EXISTS nba_roster (
    Team TEXT,
    NAME TEXT,
    Jersey TEXT DEFAULT 'NA',
    POS TEXT,
    AGE INTEGER,
    HT TEXT,
    WT TEXT,
    COLLEGE TEXT DEFAULT '--',
    SALARY TEXT DEFAULT '--'
)''')

# Insert the sample data into the table
for player in players:
    cursor.execute('''INSERT INTO nba_roster (
        Team, NAME, Jersey, POS, AGE, HT, WT, COLLEGE, SALARY
    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
    (player["Team"], player["NAME"], player["Jersey"], player["POS"], player["AGE"],
     player["HT"], player["WT"], player["COLLEGE"], player["SALARY"]))

# Commit the transaction
conn.commit()

# Close the connection
conn.close()

Utils¶

In [45]:
# Can be saved in a utils.py file

# Schema util
def get_schema():
    return """\
0|Team|TEXT eg. "Toronto Raptors"
1|NAME|TEXT eg. "Otto Porter Jr."
2|Jersey|TEXT eg. "0" and when null has a value "NA"
3|POS|TEXT eg. "PF"
4|AGE|INT eg. "22" in years
5|HT|TEXT eg. `6' 7"` or `6' 10"`
6|WT|TEXT eg. "232 lbs"
7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"
8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"
"""

# Format prompt util
def make_llama_3_prompt(user, system=""):
    system_prompt = ""
    if system != "":
        system_prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        )
    return f"<|begin_of_text|>{system_prompt}<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

# Logging util
import logging

def setup_logging():
    # Remove all handlers associated with the root logger object.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        level=logging.WARNING,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )

def get_default_finetune_args():
    return {
        "learning_rate": 3e-4,
        "max_steps": 3000,
        "early_stopping": False,
        "load_best_model_at_end": False,
        "use_cached_model": False,
        "peft_args": {"r_value": 32},
    }

def get_rubric():
    prompt = (
        "Read this scoring rubric carefully and follow the instructions precisely:\n"
    )
    prompt += (
        "A score of 5 means that the model's value is the same as the gold answer's value.\n"
    )
    prompt += "A score of 4 means that the model's answer is the same or a paraphrase of the gold answer, but the value may not be an exact match.\n"
    prompt += "A score of 3 means that the model's answer is similar to the gold answer's description, but the value may be wrong. For example, both answers may indicate that revenue increased, but the gold says 12 percent and the model says 50 million USD.\n"
    prompt += "A score of 2 means that the model's answer is not similar to the gold answer, but the answer is plausible.\n"
    prompt += "A score of 1 means that the model's answer is not similar to the gold answer, and the answer doesn't make sense.\n"

    prompt += "Assign a 5 for a correct value even if other fields are missing.\n"

    return prompt


import random
import jsonlines

def load_training_data(args, make_question):
    path = f"data/training_data/{args.training_file_name}"

    limit = 1000

    with jsonlines.open(path) as reader:
        for index, obj in enumerate(reversed(list(reader))):
            if index >= limit:
                break

            yield {
                "input": make_llama_3_prompt(**make_question(obj)),
                "output": obj["sql"] + "<|eot_id|>",
            }


def get_dataset(args, make_question):
    # Duplicate the data 10x so each example is seen many times during tuning
    dataset = list(load_training_data(args, make_question)) * 10
    random.seed(42)
    random.shuffle(dataset)
    return dataset

from argparse import ArgumentParser


def parse_arguments():
    parser = ArgumentParser()

    # The max number of examples to evaluate
    parser.add_argument(
        "--max-examples",
        type=int,
        default=100,
        help="The max number of examples to evaluate",
        required=False,
    )

    parser.add_argument(
        "--sql-model-name",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="The model to use for text2sql",
        required=False,
    )

    parser.add_argument(
        "--gold-file-name",
        type=str,
        default="gold-test-set.jsonl",
        help="The gold dataset to use as seed",
        required=False,
    )

    parser.add_argument(
        "--training-file-name",
        type=str,
        default="generated_queries.jsonl",
        help="The training dataset",
        required=False,
    )

    return parser.parse_args()

Create an SQL Agent¶

In [46]:
import logging
import sqlite3
import pandas as pd

logger = logging.getLogger(__name__)
engine = sqlite3.connect("nba_roster.db")
setup_logging()
In [108]:
user = """Who is the highest paid NBA player?"""
In [109]:
system = f"""You are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:
{get_schema()}

Write a sqlite query to answer the following question. Follow instructions exactly"""
In [110]:
print(system)
You are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:
0|Team|TEXT eg. "Toronto Raptors"
1|NAME|TEXT eg. "Otto Porter Jr."
2|Jersey|TEXT eg. "0" and when null has a value "NA"
3|POS|TEXT eg. "PF"
4|AGE|INT eg. "22" in years
5|HT|TEXT eg. `6' 7"` or `6' 10"`
6|WT|TEXT eg. "232 lbs" 
7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"
8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"


Write a sqlite query to answer the following question. Follow instructions exactly
In [111]:
prompt = make_llama_3_prompt(user, system)
In [112]:
print(llm.generate(prompt, max_new_tokens=200))
To answer this question, we can use the following SQL query:

```sql
SELECT NAME, SALARY
FROM nba_roster
WHERE SALARY!= '--'
ORDER BY CAST(SALARY AS REAL) DESC
LIMIT 1;
```

This query first filters out the rows where the salary is '--' (i.e., the players who don't have a salary listed). Then, it orders the remaining rows by the salary in descending order (highest to lowest). Finally, it returns the top row, which corresponds to the highest paid NBA player.
In [47]:
def get_updated_schema():
    return """\
0|Team|TEXT eg. "Toronto Raptors"
1|NAME|TEXT eg. "Otto Porter Jr."
2|Jersey|TEXT eg. "0" and when null has a value "NA"
3|POS|TEXT eg. "PF"
4|AGE|INT eg. "22" in years
5|HT|TEXT eg. `6' 7"` or `6' 10"`
6|WT|TEXT eg. "232 lbs"
7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"
8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"
"""
In [114]:
system = f"""You are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:
{get_updated_schema()}

Write a sqlite query to answer the following question. Follow instructions exactly"""
In [115]:
prompt = make_llama_3_prompt(user, system)
In [118]:
print(prompt)
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:
0|Team|TEXT eg. "Toronto Raptors"
1|NAME|TEXT eg. "Otto Porter Jr."
2|Jersey|TEXT eg. "0" and when null has a value "NA"
3|POS|TEXT eg. "PF"
4|AGE|INT eg. "22" in years
5|HT|TEXT eg. `6' 7"` or `6' 10"`
6|WT|TEXT eg. "232 lbs" 
7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"
8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"


Write a sqlite query to answer the following question. Follow instructions exactly<|eot_id|><|start_header_id|>user<|end_header_id|>

Who is the highest paid NBA player?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


In [119]:
print(llm.generate(prompt, max_new_tokens=200))
To answer this question, we can use the following SQL query:

```sql
SELECT NAME, SALARY
FROM nba_roster
WHERE SALARY!= '--'
ORDER BY CAST(SALARY AS REAL) DESC
LIMIT 1;
```

This query first filters out the rows where the salary is '--' (i.e., the players who don't have a salary listed). Then, it orders the remaining rows by the salary in descending order (highest to lowest). Finally, it returns the top row, which corresponds to the highest paid NBA player.

Structured Output¶

We'd like to get just the SQL output so we don't have to parse the query from the model response. For this, we can use structured output.

In [48]:
result = llm.generate(prompt, output_type={"sqlite_query": "str"}, max_new_tokens=200)
In [51]:
result
Out[51]:
{'sqlite_query': "SELECT NAME, SALARY FROM nba_roster WHERE SALARY != '--' ORDER BY CAST(SALARY AS REAL) DESC LIMIT 1"}
In [ ]:
df = pd.read_sql(result['sqlite_query'], con=engine)
In [123]:
df
Out[123]:
NAME SALARY
0 Otto Porter Jr. $6,000,000

Diagnose Hallucinations¶

The wrong query looks like this:

SELECT NAME, SALARY
FROM nba_roster
WHERE salary != '--'
ORDER BY CAST(SALARY AS REAL) DESC
LIMIT 1;

The correct query is:

SELECT salary, name
FROM nba_roster
WHERE salary != '--'
ORDER BY CAST(REPLACE(REPLACE(salary, '$', ''), ',','') AS INTEGER) DESC
LIMIT 1;
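
The root cause: SALARY is stored as TEXT with a leading "$", and SQLite's CAST only reads a leading numeric prefix, so every salary casts to 0.0 and the ORDER BY is meaningless. You can verify this directly:

print(pd.read_sql("SELECT CAST('$53,838,416' AS REAL) AS cast_result", con=engine))
# cast_result is 0.0 -- the REPLACE calls in the correct query strip
# the '$' and commas before casting, which is why that version works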
In [54]:
query="""SELECT salary, name
FROM nba_roster
WHERE salary != '--'
ORDER BY CAST(REPLACE(REPLACE(salary, '$', ''), ',','') AS INTEGER) DESC
LIMIT 1;"""
df = pd.read_sql(query, con=engine)
print(df)
        SALARY           NAME
0  $53,838,416  Stephen Curry

Create Evaluations¶

Characteristics of good evals:

  1. Quantifiable
  2. Actionable
  3. Scalable and automatable (use LLMs!)

Tips:

  • Start small (20-100 samples)
  • Quality > quantity; expand iteratively
  • Focus on small areas of improvement, e.g. hallucinations

Practical steps:

  • Find the easiest examples that still fail
  • Use an "adversarial playground" to find the boundaries of accuracy
  • Set the next accuracy target for your LLM: before making your evals harder, first get to 90%+ accuracy, then increase the difficulty of the evals
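
Before building the full evaluation pipeline, a minimal eval loop over the gold test set (the gold-test-set.jsonl file written below) might look like the following sketch. It counts a generation as correct only when the generated and reference queries return identical dataframes; the pipeline later in this notebook relaxes this brittle exact match with an LLM judge.

# Minimal sketch of an eval loop (exact-match comparison only)
correct = total = 0
with jsonlines.open("data/gold-test-set.jsonl") as reader:
    for obj in reader:
        prompt = make_llama_3_prompt(obj["question"], system)
        generated = llm.generate(
            prompt, output_type={"sqlite_query": "str"}, max_new_tokens=200
        )
        total += 1
        try:
            df = pd.read_sql(generated["sqlite_query"], con=engine)
            ref_df = pd.read_sql(obj["sql"], con=engine)
            if str(df).lower() == str(ref_df).lower():
                correct += 1
        except Exception:
            pass  # invalid SQL counts as a failure
print(f"Accuracy: {correct}/{total}")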

Use an LLM to score your output¶

  1. Get the LLM to output a numerical score.
  2. Provide the question, generated response, and scoring method through the prompt of your eval LLM.
  3. Use structured output to return the score as an int, float, List[int], etc.

Example:

system_prompt = "Compare the following two dataframes."  
system_prompt += "They are similar if they are almost identical, "
system_prompt += "or if they convey the same information about the nba_roster dataset"
system_prompt += "Respond with valid JSON  {'explanation' : str, 'similar': bool}"
user_prompt = (
  f"============ Dataframe 1 ============\n{str(obj.data..get('df','None')).lower()}\n\n"
)
user_prompt += (
  f"============ Dataframe 2 ============\n{str(obj.data..get('reference_df')).lower()}\n\n"
)
user_prompt += f"Can you tell me if these dataframes are similar?"
In [55]:
%mkdir data
In [56]:
%%writefile data/gold-test-set.jsonl
{"question": "What is the 99th percentile salary in the NBA?", "answer": "46741590", "sql": "SELECT (CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as percentile FROM nba_roster WHERE SALARY!= '--' order by percentile limit 1 offset (select count(*) from nba_roster where SALARY != '--')*99/100-1;"}
{"question": "What is the 75th percentile salary in the NBA?", "answer": "13932008", "sql": "SELECT (CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as percentile FROM nba_roster WHERE SALARY!= '--' order by percentile limit 1 offset (select count(*) from nba_roster where SALARY != '--')*75/100-1;"}
{"question": "What is the 25th percentile salary in the NBA?", "answer": "2413304", "sql": "SELECT (CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as percentile FROM nba_roster WHERE SALARY!= '--' order by percentile limit 1 offset (select count(*) from nba_roster where SALARY != '--')*25/100-1;"}
{"question": "What is the median weight in the NBA?", "answer": "215", "sql": "select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"}
{"question": "What is the average weight in the NBA?", "answer": "214.98", "sql": "SELECT AVG(CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER)) FROM nba_roster;"}
{"question": "What is the median height in the NBA?", "answer": "6.58333333333333", "sql": "select CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER)+ CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12 as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"}
{"question": "What is the average height in the NBA?", "answer": "6.54986111111111", "sql": "select AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER)+ CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as height from nba_roster;"}
{"question": "Can you tell me how many players are in the NBA?", "answer": "600", "sql": "select count(*) from nba_roster;"}
{"question": "Would you please let me know what the highest paid players are for each position?", "answer": "The highest paid players are Nikola Jokic (C), Paul George (F), Norman Powell (G), Kevin Durant (PF), Stephen Curry (PG), LeBron James (SF), Bradley Beal (SG).", "sql": "SELECT name, pos, MAX(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as max_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY POS;"}
{"question": "Is Jalen Johnson 23 years old?", "answer": "No, Jalen Johnson is 21 years old", "sql" : "Select name, age from nba_roster where name='Jalen Johnson';"}
{"question": "Who is the oldest player on the Brooklyn Nets?", "answer": "Spencer Dinwiddie, Dorian Finney-Smith, Royce O'Neale", "sql" : "SELECT NAME FROM nba_roster WHERE TEAM = 'Brooklyn Nets' AND AGE = (SELECT MAX(AGE) FROM nba_roster WHERE TEAM = 'Brooklyn Nets');"}
{"question": "Who has the higest salary on the Memphis Grizzlies?", "answer": "Ja Morant", "sql" : "select salary, name from nba_roster where team='Memphis Grizzlies' and SALARY!= '--' ORDER BY CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER) DESC LIMIT 1;"}
{"question": "Which player has the higest salary on the Cleveland Cavaliers?", "answer": "Darius Garland", "sql" : "select salary, name from nba_roster where team='Cleveland Cavaliers' and SALARY!= '--' ORDER BY CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER) DESC LIMIT 1;"}
{"question": "Who is the highest paid center on the Dallas Mavericks?", "answer": "Dereck Lively II", "sql" : "select salary, name from nba_roster where team='Dallas Mavericks' and POS='C' and SALARY!= '--' ORDER BY CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER) DESC LIMIT 1;"}
{"question": "How much is Marcus Smart getting paid?", "answer": "$18,833,712", "sql" : "select salary from nba_roster where name='Marcus Smart';"}
{"question": "What's the average age of the Trail Blazers?", "answer": "24", "sql" : "select avg(age) from nba_roster where team='Portland Trail Blazers';"}
{"question": "What's the median age of the NBA?", "answer": "25", "sql" : "select CAST(AGE as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"}
{"question": "What's the median age of the Miami Heat?", "answer": "26", "sql" : "select CAST(AGE as INTEGER) as percentile from nba_roster where team='Miami Heat' order by percentile limit 1 offset (select count(*) from nba_roster where team='Miami Heat')/2;"}
{"question": "What are the 5 teams with the oldest average age in the NBA", "answer": "Golden State Warriors, Milwaukee Bucks, Miami Heat, LA Clippers, Phoenix Suns", "sql": "SELECT team, AVG(AGE) AS average_age FROM nba_roster GROUP BY team ORDER BY average_age DESC LIMIT 5;"}
{"question": "What is the average salary of Power Forward players in the NBA", "answer": "$10948045", "sql": "select avg(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary from nba_roster where POS = 'PF';"}
Writing data/gold-test-set.jsonl
In [57]:
question = "What is the median weight in the NBA?"
In [58]:
import lamini
In [59]:
llm = lamini.Lamini(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
In [60]:
system = f"""You are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:
{get_schema()}

Write a sqlite query to answer the following question. Follow instructions exactly"""
prompt = make_llama_3_prompt(question, system)
In [61]:
generated_query = llm.generate(prompt, output_type={"sqlite_query": "str"}, max_new_tokens=200)
print(generated_query)
{'sqlite_query': "SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) FROM nba_roster WHERE WT IS NOT NULL"}
In [62]:
import pandas as pd
import sqlite3
engine = sqlite3.connect("nba_roster.db")

Example of an error¶

The model outputs an invalid query.

In [ ]:
df = pd.read_sql(generated_query['sqlite_query'], con=engine)
In [65]:
import pandas as pd
import sqlite3
engine = sqlite3.connect("./nba_roster.db")
try:
    df = pd.read_sql(generated_query['sqlite_query'], con=engine)
    print(df)
except Exception as e:
    print(e)
Execution failed on sql 'SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) FROM nba_roster WHERE WT IS NOT NULL': near "FROM": syntax error

Agent Reflection to improve the query¶

In [66]:
# Give the model the error we're seeing in our query
reflection = f"Question: {question}. Query: {generated_query['sqlite_query']}. This query is invalid (gets the error Execution failed on sql 'SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) FROM nba_roster WHERE WT IS NOT NULL': near \"FROM\": syntax error), so it cannot answer the question. Write a corrected sqlite query."
In [67]:
reflection_prompt = make_llama_3_prompt(reflection, system)
In [68]:
reflection_prompt
Out[68]:
'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:\n0|Team|TEXT eg. "Toronto Raptors"\n1|NAME|TEXT eg. "Otto Porter Jr."\n2|Jersey|TEXT eg. "0" and when null has a value "NA"\n3|POS|TEXT eg. "PF"\n4|AGE|INT eg. "22" in years\n5|HT|TEXT eg. `6\' 7"` or `6\' 10"`\n6|WT|TEXT eg. "232 lbs" \n7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"\n8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"\n\n\nWrite a sqlite query to answer the following question. Follow instructions exactly<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nQuestion: What is the median weight in the NBA?. Query: SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,\'\') + 1) AS INTEGER) FROM nba_roster WHERE WT IS NOT NULL. This query is invalid (gets the error Execution failed on sql \'SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,\'\') + 1) AS INTEGER) FROM nba_roster WHERE WT IS NOT NULL\': near "FROM": syntax error), so it cannot answer the question. Write a corrected sqlite query.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
In [69]:
reflection_query = llm.generate(reflection_prompt, output_type={"sqlite_query": "str"}, max_new_tokens=200)
In [70]:
reflection_query
Out[70]:
{'sqlite_query': "SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) FROM nba_roster WHERE WT IS NOT NULL"}
In [142]:
try:
    df = pd.read_sql(reflection_query['sqlite_query'], con=engine)
    print(df)
except Exception as e:
    print(e)
Execution failed on sql 'SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) FROM nba_roster WHERE WT IS NOT NULL': near "FROM": syntax error

Look at the right answer¶

In [71]:
correct_sql = "select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"
In [72]:
correct_sql
Out[72]:
"select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"
In [73]:
df_corrected = pd.read_sql(correct_sql, con=engine)
print(df_corrected)
   percentile
0         240

Evaluate on a larger dataset¶

In [78]:
%mkdir -p data/training_data/archive
In [145]:
%%writefile data/training_data/archive/generated_queries_large.jsonl
{"question": "What is the average height of NBA players", "sql": "SELECT AVG(CAST(SUBSTRING(HT, 0, INSTR(HT,'')-1) AS INTEGER) + CAST(SUBSTRING(HT, INSTR(HT,'')+1) AS INTEGER)/12) as average_height FROM nba_roster WHERE HT!= 'NA';"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster"}
{"question": "What are the oldest players on each team with a roster size of 6 or more", "sql": "SELECT NAME FROM nba_roster WHERE AGE IN (SELECT MAX(AGE) FROM nba_roster WHERE TEAM IN (SELECT TEAM FROM nba_roster GROUP BY TEAM HAVING COUNT(*) > 5))"}
{"question": "What is the average height of the players on the Toronto Raptors", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER)+ CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as height FROM nba_roster WHERE team='Toronto Raptors';"}
{"question": "What is the highest-paid Toronto Raptors player who attended college", "sql": "SELECT name, salary FROM nba_roster WHERE team='Toronto Raptors' AND COLLEGE!='--' AND SALARY!='--' ORDER BY CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER) DESC LIMIT 1"}
{"question": "What is the most common height among NBA players", "sql": "SELECT HT, COUNT(*) as count FROM nba_roster WHERE HT IS NOT NULL GROUP BY HT ORDER BY count DESC LIMIT 1"}
{"question": "What is the most represented college in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE IS NOT NULL GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) AS average_age FROM nba_roster"}
{"question": "What is the average height of NBA players", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER) + CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) AS average_height FROM nba_roster"}
{"question": "What is the average age of the players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE AGE IS NOT NULL"}
{"question": "What is the position with the most players in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster WHERE SALARY!= '--' GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "What is the average height of players on each NBA team, excluding players with unknown heights", "sql": "SELECT TEAM, AVG(CAST(SUBSTRING(HT, 0, INSTR(HT,'')-1) AS INTEGER)) as avg_height FROM nba_roster WHERE HT!= 'NA' GROUP BY TEAM ORDER BY avg_height DESC"}
{"question": "What are the 5 most common heights among NBA players", "sql": "SELECT HT, COUNT(*) AS count FROM nba_roster GROUP BY HT ORDER BY count DESC LIMIT 5"}
{"question": "What are the top 5 colleges with the most players in the NBA", "sql": "SELECT COLLEGE, COUNT(*) AS count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 5"}
{"question": "What is the average age of the players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE AGE IS NOT NULL"}
{"question": "Which players in the NBA have attended the most colleges", "sql": "SELECT NAME, COLLEGE, COUNT(*) as num_colleges FROM nba_roster WHERE COLLEGE!= '--' GROUP BY NAME, COLLEGE ORDER BY num_colleges DESC;"}
{"question": "What is the average age of the players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster"}
{"question": "Who are the top 5 highest-paid players in the NBA", "sql": "SELECT * FROM nba_roster WHERE SALARY!= '--' ORDER BY CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER) DESC LIMIT 5"}
{"question": "What is the average height of players on each NBA team", "sql": "SELECT team, AVG(CAST(SUBSTRING(HT, 1, INSTR(HT,'')-1) AS INTEGER) + CAST(SUBSTRING(HT, INSTR(HT,'')+1) AS INTEGER) / 12.0) as avg_height FROM nba_roster WHERE HT!= 'NA' GROUP BY team"}
{"question": "Who are the top 3 highest-paid players in the NBA", "sql": "SELECT name, SUM(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as total_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY name ORDER BY total_salary DESC LIMIT 3"}
{"question": "Which team has the most players in the NBA", "sql": "SELECT team, COUNT(*) as num_players FROM nba_roster GROUP BY team ORDER BY num_players DESC LIMIT 1"}
{"question": "What is the total salary of all players in the NBA who are 6'8", "sql": "SELECT SUM(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as total_salary FROM nba_roster WHERE CAST(SUBSTR(HT, 1, INSTR(HT,'')-1) AS INTEGER) = 68;"}
{"question": "What is the average age of players on each team in the NBA", "sql": "SELECT team, AVG(AGE) as avg_age FROM nba_roster WHERE SALARY!= '--' GROUP BY team"}
{"question": "How many players in the NBA have a non-null salary and college information, and play one of the five main positions", "sql": "SELECT COUNT(*) as num_players FROM nba_roster WHERE POS IN ('PG', 'SG', 'SF', 'PF', 'C') AND SALARY!= '--' AND COLLEGE!= '--'"}
{"question": "What is the most common position in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "What is the average height of NBA players", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER) + CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as average_height FROM nba_roster;"}
{"question": "What is the average salary of NBA players who are at least 5 years old", "sql": "SELECT AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary FROM nba_roster WHERE AGE > 5"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster"}
{"question": "What is the most common age range among NBA players", "sql": "SELECT AGE, COUNT(*) AS count FROM nba_roster GROUP BY AGE ORDER BY count DESC LIMIT 1"}
{"question": "Which team has the most players in the NBA", "sql": "SELECT Team, COUNT(*) as num_players FROM nba_roster GROUP BY Team ORDER BY num_players DESC LIMIT 1"}
{"question": "What is the average salary of NBA players", "sql": "SELECT AVG(CAST(SUBSTR(SALARY, 1, INSTR(SALARY, '$')-1) AS INTEGER)) FROM nba_roster WHERE SALARY!= '--';"}
{"question": "How many players in the NBA are 68 inches tall", "sql": "SELECT COUNT(*) FROM nba_roster WHERE CAST(SUBSTR(HT, 1, INSTR(HT,'')-1) AS INTEGER) = 68;"}
{"question": "What is the average salary of Power Forwards in the NBA who are at least 25 years old", "sql": "SELECT AVG(CAST(SUBSTR(SALARY, 1, INSTR(SALARY, '$')-1) AS INTEGER)) AS average_salary FROM nba_roster WHERE AGE >= 25 AND POS = 'PF';"}
{"question": "What is the average age of 6-foot Power Forwards in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER) = 6 AND POS='PF';"}
{"question": "What is the heaviest Power Forward in the NBA", "sql": "SELECT NAME, AVG(CAST(SUBSTR(WT, 1, INSTR(WT,' ')) AS INTEGER)) AS avg_weight FROM nba_roster WHERE POS='PF' GROUP BY NAME ORDER BY avg_weight DESC LIMIT 1"}
{"question": "What is the number of players on each team in the NBA", "sql": "SELECT Team, COUNT(*) as num_players FROM nba_roster GROUP BY Team"}
{"question": "What is the average height of NBA players who are 25 years old or older", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER)+ CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as height FROM nba_roster WHERE age >= 25"}
{"question": "What are the top 3 teams with the highest average salaries in the NBA", "sql": "SELECT team, AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as avg_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY team ORDER BY avg_salary DESC LIMIT 3"}
{"question": "What is the most common position in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "What are the names of the players in the NBA who are exactly 6 feet 8 inches tall", "sql": "SELECT NAME, HT FROM nba_roster WHERE CAST(SUBSTRING(HT, 0, INSTR(HT,'')-1) AS INTEGER) = 68 ORDER BY HT ASC;"}
{"question": "What is the college with the most players in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster"}
{"question": "What is the most represented college in the NBA", "sql": "SELECT COLLEGE, COUNT(*) AS frequency FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY frequency DESC LIMIT 1"}
{"question": "What is the average age of the players in the NBA", "sql": "SELECT AVG(AGE) as average_age FROM nba_roster WHERE AGE IS NOT NULL"}
{"question": "What is the average height of NBA players who have a recorded height", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER) + CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as average_height FROM nba_roster WHERE HT IS NOT NULL"}
{"question": "What is the average salary of NBA players who are 25 years or older", "sql": "SELECT AVG(CAST(SUBSTR(SALARY, 1, INSTR(SALARY, '$') - 1) as INTEGER)) FROM nba_roster WHERE CAST(AGE as INTEGER) >= 25"}
{"question": "What is the most represented college in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "What is the number of players on each team in the NBA", "sql": "SELECT Team, COUNT(*) as num_players FROM nba_roster GROUP BY Team"}
{"question": "What is the average salary for each position in the NBA, excluding players with unknown salaries", "sql": "SELECT POS, AVG(CAST(SUBSTR(SALARY, 1, INSTR(SALARY, '$') - 1) as INTEGER)) as avg_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY POS"}
{"question": "What is the most common position in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "What is the average age of players on each team in the NBA", "sql": "SELECT team, AVG(AGE) as avg_age FROM nba_roster WHERE SALARY!= '--' GROUP BY team"}
{"question": "What are the top 3 positions with the highest total salary expenditure in the NBA", "sql": "SELECT pos, name, SUM(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as total_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY pos ORDER BY total_salary DESC LIMIT 3"}
{"question": "Which colleges have the most players in the NBA", "sql": "SELECT COLLEGE, COUNT(*) AS num_players FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY num_players DESC;"}
{"question": "What is the average salary for each team in the NBA", "sql": "SELECT team, AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as avg_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY team"}
{"question": "What is the age range of players on each team in the NBA", "sql": "SELECT team, MIN(AGE) as youngest_player, MAX(AGE) as oldest_player FROM nba_roster WHERE AGE IS NOT NULL GROUP BY team"}
{"question": "Which team has the most players who are 6'8", "sql": "SELECT team, COUNT(*) as num_players FROM nba_roster WHERE CAST(SUBSTR(HT, 1, INSTR(HT,'')-1) AS INTEGER) = 68 GROUP BY team ORDER BY num_players DESC LIMIT 1"}
{"question": "How many players in the NBA are over the age of 25", "sql": "SELECT COUNT(*) FROM nba_roster WHERE AGE > 25"}
{"question": "What is the average height of NBA players under the age of 25", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER)+ CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as average_height FROM nba_roster WHERE AGE <= 25"}
{"question": "What is the total salary of all players in the NBA who are more than 5 years older than the average age of all players", "sql": "SELECT SUM(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as total_salary FROM nba_roster WHERE (AGE - (SELECT AVG(AGE) FROM nba_roster)) > 5"}
{"question": "What is the median weight in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "What are the top 5 teams with the oldest average age of players", "sql": "SELECT team, AVG(AGE) AS average_age FROM nba_roster GROUP BY team ORDER BY average_age DESC LIMIT 5"}
{"question": "What is the average height of NBA players", "sql": "SELECT AVG(CAST(SUBSTRING(HT, 0, INSTR(HT,'')-1) AS INTEGER)) AS average_height FROM nba_roster WHERE HT!= 'NA';"}
{"question": "What is the average salary of the Los Angeles Lakers players", "sql": "SELECT AVG(CAST(SALARY AS INTEGER) ) AS average_salary FROM nba_roster WHERE team='Los Angeles Lakers';"}
{"question": "What is the college that has produced the most players currently playing for the Boston Celtics", "sql": "SELECT COLLEGE, COUNT(*) AS count FROM nba_roster WHERE team='Boston Celtics' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "What college has the most players in the NBA who are 30 years old or older", "sql": "SELECT COLLEGE, COUNT(*) AS count FROM nba_roster WHERE AGE >= 30 GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "How many players in the NBA are at least 5 years older than the youngest player in the league", "sql": "SELECT COUNT(*) as num_players FROM nba_roster WHERE AGE - (SELECT MIN(AGE) FROM nba_roster) > 5"}
{"question": "What are the 5 colleges that have produced the most players in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as num_players FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY num_players DESC LIMIT 5"}
{"question": "What are the most common positions in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster WHERE POS!= '--' GROUP BY POS ORDER BY count DESC"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) as average_age FROM nba_roster WHERE AGE IS NOT NULL"}
{"question": "What are the teams with the highest average salaries in the NBA", "sql": "SELECT team, AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as avg_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY team ORDER BY avg_salary DESC"}
{"question": "What is the average height of NBA players", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER) + CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as average_height FROM nba_roster"}
{"question": "What is the average salary of all NBA players", "sql": "SELECT AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary FROM nba_roster"}
{"question": "What is the average age of the players on the Toronto Raptors", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE team='Toronto Raptors';"}
{"question": "Which three teams have the most players from a single college", "sql": "SELECT team, COLLEGE, COUNT(*) AS num_players FROM nba_roster GROUP BY team, COLLEGE ORDER BY num_players DESC LIMIT 3"}
{"question": "What is the average salary of NBA players 25 years or older", "sql": "SELECT AVG(CAST(SUBSTR(SALARY, 1, INSTR(SALARY, '$')-1) AS INTEGER)) FROM nba_roster WHERE AGE >= 25"}
{"question": "What is the total salary of all NBA players", "sql": "SELECT SUM(CAST(SUBSTR(SALARY, 1, INSTR(SALARY, '$')-1) AS INTEGER)*1000000) FROM nba_roster"}
{"question": "What are the most common positions in the NBA", "sql": "SELECT POS, COUNT(*) AS num_players FROM nba_roster GROUP BY POS;"}
{"question": "What is the average salary for each age group in the NBA", "sql": "SELECT AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary, AGE as age_group FROM nba_roster WHERE SALARY!= '--' GROUP BY AGE ORDER BY age_group"}
{"question": "What are the top 5 colleges that have produced the most NBA players", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 5"}
{"question": "What is the most common position for players under the age of 25 in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster WHERE AGE <= 25 GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "How many players in the NBA are 5 years or younger than the oldest player in the league", "sql": "SELECT COUNT(*) FROM nba_roster WHERE AGE + 5 <= (SELECT MAX(AGE) FROM nba_roster);"}
{"question": "What are the top 5 colleges that have produced the most NBA players", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 5"}
{"question": "What are the most common positions in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster"}
{"question": "What are the most common heights in the NBA", "sql": "SELECT HT, COUNT(*) AS frequency FROM nba_roster GROUP BY HT ORDER BY frequency DESC LIMIT 5"}
{"question": "What are the most common positions in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC"}
{"question": "What is the average salary for each team in the NBA, excluding teams with unknown salaries", "sql": "SELECT TEAM, AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY TEAM ORDER BY average_salary DESC"}
{"question": "What is the college that has produced the most NBA players", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "Who is the highest paid player in the NBA", "sql": "SELECT name, salary FROM nba_roster WHERE salary!= '--' ORDER BY CAST(REPLACE(REPLACE(salary, '$', ''), ',', '') AS INTEGER) DESC LIMIT 1"}
{"question": "What is the average age of players who are 6'8", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER) = 68"}
{"question": "What is the average age of the players in the NBA who are more than 5 years older than the average age of all players", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE AGE + (SELECT AVG(AGE) FROM nba_roster) > 5*12"}
{"question": "What is the average age of the players in the NBA who are older than 5 years old", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE AGE > 5*12"}
{"question": "What are the top colleges that produce the most NBA players", "sql": "SELECT COLLEGE, COUNT(*) as num_players FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY num_players DESC;"}
{"question": "How many players in the NBA are 6'8", "sql": "SELECT COUNT(*) FROM nba_roster WHERE CAST(SUBSTR(HT, 1, INSTR(HT,'')-1) AS INTEGER) = 68;"}
{"question": "What is the average salary for each team in the NBA", "sql": "SELECT Team, AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary FROM nba_roster GROUP BY Team"}
{"question": "What are the top colleges represented in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as num_players FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY num_players DESC;"}
{"question": "What is the most represented college in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "What are the 5 teams with the highest average salary in the NBA", "sql": "SELECT team, AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) AS average_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY team ORDER BY average_salary DESC"}
{"question": "What is the average age of players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster"}
{"question": "What is the most common height in the NBA", "sql": "SELECT SUBSTR(HT, 1, INSTR(HT,'')-1) as height, COUNT(*) as count FROM nba_roster GROUP BY SUBSTR(HT, 1, INSTR(HT,'')-1) ORDER BY count DESC LIMIT 1"}
{"question": "What is the position with the most players in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "What is the 75th percentile salary in the NBA", "sql": "SELECT HT, AVG(WT) as avg_weight FROM nba_roster WHERE HT IS NOT NULL AND WT IS NOT NULL GROUP BY HT ORDER BY avg_weight DESC LIMIT 1"}
{"question": "Which college has produced the most NBA players", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "What is the average salary of NBA players who are older than 25 years old", "sql": "SELECT AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary FROM nba_roster WHERE AGE > 25"}
{"question": "What is the average age of the players on the Toronto Raptors", "sql": "SELECT AVG(AGE) FROM nba_roster WHERE TEAM = 'Toronto Raptors';"}
{"question": "What is the average height of the players on the Los Angeles Lakers", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,'')-1) AS INTEGER) + CAST(SUBSTR(HT, INSTR(HT,'')+1) AS FLOAT)/12) AS height FROM nba_roster WHERE TEAM = 'Los Angeles Lakers';"}
{"question": "What is the position with the most players in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "What is the average age of all players in the NBA who are older than 5 years old", "sql": "SELECT AVG(AGE) as average_age FROM nba_roster WHERE AGE > 5"}
{"question": "How many players on each team have a height of 6'8", "sql": "SELECT team, COUNT(*) as num_players FROM nba_roster WHERE CAST(SUBSTRING(HT, 1, INSTR(HT,'')-1) AS INTEGER) = 68 GROUP BY team"}
{"question": "What is the most common position in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster GROUP BY POS ORDER BY count DESC LIMIT 1"}
{"question": "What is the average height of NBA players", "sql": "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,' ')-1) AS INTEGER) + CAST(SUBSTR(HT, INSTR(HT,' ')+1) AS FLOAT)/12) as average_height FROM nba_roster;"}
{"question": "What is the average salary of NBA players who are at least 5 years old", "sql": "SELECT AVG(CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER)) as average_salary FROM nba_roster WHERE AGE > 5"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) FROM nba_roster"}
{"question": "What is the most common age range among NBA players", "sql": "SELECT AGE, COUNT(*) AS count FROM nba_roster GROUP BY AGE ORDER BY count DESC LIMIT 1"}
{"question": "What is the median weight in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as count FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY count DESC LIMIT 1"}
{"question": "How many players in the NBA are at least 5 years older than the youngest player in the league", "sql": "SELECT COUNT(*) as num_players FROM nba_roster WHERE AGE - (SELECT MIN(AGE) FROM nba_roster) > 5"}
{"question": "What are the 5 colleges that have produced the most players in the NBA", "sql": "SELECT COLLEGE, COUNT(*) as num_players FROM nba_roster WHERE COLLEGE!= '--' GROUP BY COLLEGE ORDER BY num_players DESC LIMIT 5"}
{"question": "What are the most common positions in the NBA", "sql": "SELECT POS, COUNT(*) as count FROM nba_roster WHERE POS!= '--' GROUP BY POS ORDER BY count DESC"}
{"question": "What is the average age of all players in the NBA", "sql": "SELECT AVG(AGE) as average_age FROM nba_roster WHERE AGE IS NOT NULL"}
Writing data/training_data/archive/generated_queries_large.jsonl
In [80]:
import logging
import os
from datetime import datetime
from pprint import pprint
from typing import AsyncIterator, Iterator, Union
import sqlite3
from tqdm import tqdm

import pandas as pd
import jsonlines
from lamini.generation.base_prompt_object import PromptObject
from lamini.generation.generation_node import GenerationNode
from lamini.generation.generation_pipeline import GenerationPipeline

logger = logging.getLogger(__name__)
engine = sqlite3.connect("nba_roster.db")
setup_logging()

class Args:
    def __init__(self,
                 max_examples=100,
                 sql_model_name="meta-llama/Meta-Llama-3-8B-Instruct",
                 gold_file_name="gold-test-set.jsonl",
                 training_file_name="archive/generated_queries.jsonl",
                 num_to_generate=10):
        self.sql_model_name = sql_model_name
        self.max_examples = max_examples
        self.gold_file_name = gold_file_name
        self.training_file_name = training_file_name
        self.num_to_generate = num_to_generate
In [81]:
def load_gold_dataset(args):
    path = f"data/{args.gold_file_name}"

    with jsonlines.open(path) as reader:
        for index, obj in enumerate(reversed(list(reader))):
            if index >= args.max_examples:
                break
            yield PromptObject(prompt="", data=obj)
In [82]:
path = "data/gold-test-set.jsonl"

with jsonlines.open(path) as reader:
    data = [obj for obj in reader]
In [83]:
datapoint = data[4]
In [84]:
datapoint
Out[84]:
{'question': 'What is the average weight in the NBA?',
 'answer': '214.98',
 'sql': "SELECT AVG(CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER)) FROM nba_roster;"}
In [85]:
datapoint = data[7]
In [86]:
datapoint
Out[86]:
{'question': 'Can you tell me how many players are in the NBA?',
 'answer': '600',
 'sql': 'select count(*) from nba_roster;'}
In [87]:
system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
system += "Consider the nba_roster table with the following schema:\n"
system += get_schema() + "\n"
system += (
    "Write a sqlite SQL query that would help you answer the following question:\n"
)
user = datapoint["question"]
prompt = make_llama_3_prompt(user, system)
generated_sql = llm.generate(prompt, output_type={"sqlite_query": "str"}, max_new_tokens=200)
print(generated_sql)
{'sqlite_query': "SELECT COUNT(*) FROM nba_roster WHERE Jersey!= 'NA';"}
In [88]:
df = pd.read_sql(generated_sql['sqlite_query'], con=engine)
print(df)
   COUNT(*)
0        30

Determine if generated response is valid SQL¶

In [89]:
query_succeeded = False
try:
    df = pd.read_sql(generated_sql['sqlite_query'], con=engine)
    query_succeeded = True
    print("Query is valid")
except Exception as e:
    print(
        f"Failed to run SQL query: {generated_sql}"
    )
Query is valid
In [90]:
reference_sql = datapoint["sql"]
ref_df = pd.read_sql(reference_sql, con=engine)
print(ref_df)
   count(*)
0        30
In [91]:
# Here's how to wrap that all up in a runnable class

class QueryStage(GenerationNode):
    def __init__(self, model_name):
        super().__init__(
            model_name=model_name,
            max_new_tokens=200,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        results = super().generate(
            prompt,
            output_type={"sqlite_query": "str"},
            *args,
            **kwargs,
        )
        return results


    def postprocess(self, obj: PromptObject):
        # Run both the generated and reference (Gold Dataset) SQL queries
        # Assessing whether the SQL queries succeeded in hitting the database (not correctness yet!)

        query_succeeded = False

        try:
            logger.error(f"Running SQL query '{obj.response['sqlite_query']}'")
            obj.data["generated_query"] = obj.response["sqlite_query"]
            df = pd.read_sql(obj.response["sqlite_query"], con=engine)
            obj.data['df'] = df
            logger.error(f"Got data: {df}")
            query_succeeded = True

        except Exception as e:
            logger.error(
                f"Failed to run SQL query: {obj.response['sqlite_query']}"
            )

        logger.info(f"Running reference SQL query '{obj.data['sql']}'")
        df = pd.read_sql(obj.data["sql"], con=engine)
        logger.info(f"Got data: {df}")
        obj.data['reference_df'] = df

        logger.info(f"For question: {obj.data['question']}")
        logger.info(f"For query: {obj.response['sqlite_query']}")

        obj.data["query_succeeded"] = query_succeeded

    def preprocess(self, obj: PromptObject):
        new_prompt = make_llama_3_prompt(**self.make_prompt(obj.data))
        obj.prompt = new_prompt

    def make_prompt(self, data: dict):
        system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
        system += "Consider the nba_roster table with the following schema:\n"
        system += get_schema() + "\n"
        system += (
            "Write a sqlite SQL query that would help you answer the following question:\n"
        )
        user = data["question"]
        return {
            "user": user,
            "system": system,
        }

Compare strings¶

In [92]:
str(df).lower() == str(ref_df).lower()
Out[92]:
True

Use an LLM to compare¶

In [93]:
system_prompt = "Compare the following two dataframes. They are similar if they are almost identical, or if they convey the same information about the nba_roster dataset"
system_prompt += "Respond with valid JSON {'explanation' : str, 'similar' : bool}"
system_prompt
Out[93]:
"Compare the following two dataframes. They are similar if they are almost identical, or if they convey the same information about the nba_roster datasetRespond with valid JSON {'explanation' : str, 'similar' : bool}"
In [94]:
user_prompt = (
    f"========== Dataframe 1 =========\n{str(df).lower()}\n\n"
)
user_prompt += (
    f"========== Dataframe 2 =========\n{str(ref_df).lower()}\n\n"
)
user_prompt += f"Can you tell me if these dataframes are similar?"
In [95]:
llm_similarity_prompt = make_llama_3_prompt(user_prompt, system_prompt)
In [96]:
llm_similarity = llm.generate(llm_similarity_prompt, output_type={"explanation": "str", "similar": "bool"}, max_new_tokens=200)
In [97]:
llm_similarity
Out[97]:
{'explanation': 'Both dataframes have the same count of rows, which is 30. This suggests that they are identical in terms of the number of rows they contain',
 'similar': True}
In [98]:
str(df).lower() == str(ref_df).lower() or llm_similarity["similar"]
Out[98]:
True
In [99]:
# How to wrap it up in a class

class ScoreStage(GenerationNode):
    def __init__(self):
        super().__init__(
            model_name="meta-llama/Meta-Llama-3-8B-Instruct",
            max_new_tokens=150,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        logger.debug("ScoreStage Generate")
        results = super().generate(
            prompt,
            output_type={"explanation": "str", "similar": ["true", "false"]},
            *args,
            **kwargs,
        )
        logger.debug(f"ScoreStage Results {results}")

        return results

    def preprocess(self, obj: PromptObject):
        obj.prompt = make_llama_3_prompt(**self.make_prompt(obj))
        logger.info(f"Scoring Stage Prompt:\n{obj.prompt}")

    def postprocess(self, obj: PromptObject):
        logger.info(f"Postprocess")
        obj.data['is_matching'] = self.is_matching(obj.data, obj.response)
        obj.data['explanation'] = obj.response["explanation"]
        obj.data['similar'] = obj.response["similar"] == "true"


    def is_matching(self, data, response):
        return (str(data.get('df',"None")).lower() == str(data['reference_df']).lower()
                or response['similar'] == "true")

    def make_prompt(self, obj: PromptObject):
        # Your evaluation model compares SQL output from the generated and reference SQL queries, using another LLM in the pipeline
        system_prompt = "Compare the following two dataframes. They are similar if they are almost identical, or if they convey the same information about the nba_roster dataset"
        system_prompt += "Respond with valid JSON {'explanation' : str, 'similar' : bool}"
        user_prompt = (
            f"========== Dataframe 1 =========\n{str(obj.data.get('df','None')).lower()}\n\n"
        )
        user_prompt += (
            f"========== Dataframe 2 =========\n{str(obj.data['reference_df']).lower()}\n\n"
        )
        user_prompt += f"Can you tell me if these dataframes are similar?"
        return {
            "system": system_prompt,
            "user": user_prompt
        }
In [100]:
class EvaluationPipeline(GenerationPipeline):
    def __init__(self, args):
        super().__init__()
        self.query_stage = QueryStage(args.sql_model_name)
        self.score_stage = ScoreStage()

    def forward(self, x):
        x = self.query_stage(x)
        x = self.score_stage(x)
        return x
In [101]:
async def run_eval(dataset, args):
    results = await run_evaluation_pipeline(dataset, args)
    print("Total results:", len(results))
    return results

async def run_evaluation_pipeline(dataset, args):
    results = EvaluationPipeline(args).call(dataset)
    result_list = []

    pbar = tqdm(desc="Saving results", unit=" results")
    async for result in results:
        result_list.append(result)
        pbar.update()
    return result_list
In [102]:
def save_eval_results(results, args):
    base_path = "./data/results"
    now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    experiment_name = f"nba_sql_pipeline_{now}"
    experiment_dir = os.path.join(base_path, experiment_name)
    os.makedirs(experiment_dir)

    # Write args to file
    args_file_name = f"{experiment_dir}/args.txt"
    with open(args_file_name, "w") as writer:
        pprint(args.__dict__, writer)


    def is_correct(r):
        if (
            (r.data["query_succeeded"] and r.data['is_matching']) or
            r.data["generated_query"] == r.data['sql']
        ):
            return True
        return False

    # Write sql results and errors to file
    results_file_name = f"{experiment_dir}/sql_results.jsonl"
    with jsonlines.open(results_file_name, "w") as writer:
        for result in results:
            if not is_correct(result):
                continue
            writer.write(
                {
                    "question": result.data['question'],
                    "query": result.data["generated_query"],
                    "query_succeeded": result.data["query_succeeded"],
                    "reference_sql": result.data['sql'],
                    "df": str(result.data.get('df', 'None')),
                    "reference_df": str(result.data['reference_df']),
                    'is_matching': result.data['is_matching'],
                    'similar': result.data['similar'],
                }
            )

    results_file_name = f"{experiment_dir}/sql_errors.jsonl"
    with jsonlines.open(results_file_name, "w") as writer:
        for result in results:
            if is_correct(result):
                continue
            writer.write(
                {
                    "question": result.data['question'],
                    "query": result.data["generated_query"],
                    "query_succeeded": result.data["query_succeeded"],
                    "df": str(result.data.get('df', 'None')),
                    "reference_df": str(result.data['reference_df']),
                    'is_matching': result.data['is_matching'],
                    'similar': result.data['similar'],
                }
            )

    # Write statistics to file
    average_sql_succeeded = sum(
        [result.data["query_succeeded"] for result in results]
    ) / len(results)
    average_correct = sum(
        [result.data["query_succeeded"] and result.data['is_matching'] for result in results]
    ) / len(results)

    file_name = f"{experiment_dir}/summary.txt"
    with open(file_name, "w") as writer:
        print(f"Total size of eval dataset: {len(results)}", file=writer)
        print(f"Total size of eval dataset: {len(results)}")
        print(f"Percent Valid SQL Syntax: {average_sql_succeeded*100}", file=writer)
        print(f"Percent Valid SQL Syntax: {average_sql_succeeded*100}")
        print(f"Percent Correct SQL Query: {average_correct*100}", file=writer)
        print(f"Percent Correct SQL Query: {average_correct*100}")
In [103]:
args = Args()
dataset = load_gold_dataset(args)
results = await run_eval(dataset, args)
save_eval_results(results, args)
Saving results: 0 results [00:00, ? results/s]2024-08-28 20:59:43,044 [ERROR] Running SQL query 'SELECT AVG(CAST(SUBSTR(SALARY, 2) AS INTEGER) AS average_salary FROM nba_roster WHERE POS = 'PF' AND SALARY!= '--';'
2024-08-28 20:59:43,047 [ERROR] Failed to run SQL query: SELECT AVG(CAST(SUBSTR(SALARY, 2) AS INTEGER) AS average_salary FROM nba_roster WHERE POS = 'PF' AND SALARY!= '--';
2024-08-28 20:59:43,058 [ERROR] Running SQL query 'SELECT Team, AVG(AGE) AS Average_Age FROM nba_roster GROUP BY Team ORDER BY Average_Age DESC LIMIT 5'
2024-08-28 20:59:43,063 [ERROR] Got data:                     Team  Average_Age
0     Los Angeles Lakers         39.0
1  Golden State Warriors         36.0
2          Brooklyn Nets         35.0
3     Philadelphia 76ers         30.0
4        Milwaukee Bucks         29.0
2024-08-28 20:59:43,070 [ERROR] Running SQL query 'SELECT AVG(AGE) FROM nba_roster WHERE Team = 'Miami Heat';'
2024-08-28 20:59:43,075 [ERROR] Got data:   AVG(AGE)
0     None
2024-08-28 20:59:43,082 [ERROR] Running SQL query 'SELECT AVG(AGE) FROM nba_roster'
2024-08-28 20:59:43,087 [ERROR] Got data:    AVG(AGE)
0      30.0
2024-08-28 20:59:43,093 [ERROR] Running SQL query 'SELECT AVG(AGE) FROM nba_roster WHERE Team = 'Portland Trail Blazers';'
2024-08-28 20:59:43,098 [ERROR] Got data:   AVG(AGE)
0     None
2024-08-28 20:59:48,652 [ERROR] Running SQL query 'SELECT SALARY FROM nba_roster WHERE NAME = 'Marcus Smart';'
2024-08-28 20:59:48,656 [ERROR] Got data: Empty DataFrame
Columns: [SALARY]
Index: []
2024-08-28 20:59:48,662 [ERROR] Running SQL query 'SELECT NAME, SALARY FROM nba_roster WHERE TEAM = 'Dallas Mavericks' AND POS = 'C' AND SALARY!= '--' ORDER BY CAST(SALARY AS REAL) DESC LIMIT 1'
2024-08-28 20:59:48,667 [ERROR] Got data: Empty DataFrame
Columns: [NAME, SALARY]
Index: []
2024-08-28 20:59:48,672 [ERROR] Running SQL query 'SELECT NAME, SALARY FROM nba_roster WHERE Team = 'Cleveland Cavaliers' AND SALARY!= '--' ORDER BY SALARY DESC LIMIT 1'
2024-08-28 20:59:48,677 [ERROR] Got data: Empty DataFrame
Columns: [NAME, SALARY]
Index: []
2024-08-28 20:59:48,683 [ERROR] Running SQL query 'SELECT NAME, SALARY FROM nba_roster WHERE Team = 'Memphis Grizzlies' AND SALARY!= '--' ORDER BY SALARY DESC LIMIT 1'
2024-08-28 20:59:48,689 [ERROR] Got data: Empty DataFrame
Columns: [NAME, SALARY]
Index: []
2024-08-28 20:59:48,695 [ERROR] Running SQL query 'SELECT NAME FROM nba_roster WHERE Team = 'Brooklyn Nets' AND AGE = (SELECT MAX(AGE) FROM nba_roster WHERE Team = 'Brooklyn Nets')'
2024-08-28 20:59:48,700 [ERROR] Got data:            NAME
0  Kevin Durant
1  Kevin Durant
2  Kevin Durant
2024-08-28 21:00:00,634 [ERROR] Running SQL query 'SELECT * FROM nba_roster WHERE NAME = 'Jalen Johnson' AND AGE = 23'
2024-08-28 21:00:00,641 [ERROR] Got data: Empty DataFrame
Columns: [Team, NAME, Jersey, POS, AGE, HT, WT, COLLEGE, SALARY]
Index: []
2024-08-28 21:00:00,647 [ERROR] Running SQL query 'SELECT POS, MAX(CAST(SUBSTR(SALARY, 2) AS INTEGER) AS Salary FROM nba_roster WHERE SALARY!= '--' GROUP BY POS'
2024-08-28 21:00:00,650 [ERROR] Failed to run SQL query: SELECT POS, MAX(CAST(SUBSTR(SALARY, 2) AS INTEGER) AS Salary FROM nba_roster WHERE SALARY!= '--' GROUP BY POS
2024-08-28 21:00:00,657 [ERROR] Running SQL query 'SELECT COUNT(*) FROM nba_roster WHERE Jersey!= 'NA';'
2024-08-28 21:00:00,661 [ERROR] Got data:    COUNT(*)
0        30
2024-08-28 21:00:00,667 [ERROR] Running SQL query 'SELECT AVG(CAST(SUBSTRING(HT, 0, INSTR(HT,'')) AS INTEGER) FROM nba_roster WHERE HT IS NOT NULL'
2024-08-28 21:00:00,669 [ERROR] Failed to run SQL query: SELECT AVG(CAST(SUBSTRING(HT, 0, INSTR(HT,'')) AS INTEGER) FROM nba_roster WHERE HT IS NOT NULL
2024-08-28 21:00:00,676 [ERROR] Running SQL query 'SELECT AVG(CAST(SUBSTR(HT, 0, INSTR(HT,'')-1) AS INTEGER) FROM nba_roster WHERE HT IS NOT NULL'
2024-08-28 21:00:00,677 [ERROR] Failed to run SQL query: SELECT AVG(CAST(SUBSTR(HT, 0, INSTR(HT,'')-1) AS INTEGER) FROM nba_roster WHERE HT IS NOT NULL
2024-08-28 21:00:14,162 [ERROR] Running SQL query 'SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) AS weight FROM nba_roster WHERE WT IS NOT NULL'
2024-08-28 21:00:14,165 [ERROR] Failed to run SQL query: SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) AS weight FROM nba_roster WHERE WT IS NOT NULL
2024-08-28 21:00:14,171 [ERROR] Running SQL query 'SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) FROM nba_roster WHERE WT!= 'NA';'
2024-08-28 21:00:14,173 [ERROR] Failed to run SQL query: SELECT AVG(CAST(SUBSTR(WT, INSTR(WT,'') + 1) AS INTEGER) FROM nba_roster WHERE WT!= 'NA';
2024-08-28 21:00:14,178 [ERROR] Running SQL query 'SELECT PERCENTILE(SALARY, 0.25) FROM nba_roster WHERE SALARY!= '--';'
2024-08-28 21:00:14,180 [ERROR] Failed to run SQL query: SELECT PERCENTILE(SALARY, 0.25) FROM nba_roster WHERE SALARY!= '--';
2024-08-28 21:00:14,185 [ERROR] Running SQL query 'SELECT PERCENTILE(salary, 0.75) FROM (SELECT CAST(SUBSTR(salary, 2) AS INTEGER) AS salary FROM nba_roster WHERE salary!= '--') AS subquery'
2024-08-28 21:00:14,186 [ERROR] Failed to run SQL query: SELECT PERCENTILE(salary, 0.75) FROM (SELECT CAST(SUBSTR(salary, 2) AS INTEGER) AS salary FROM nba_roster WHERE salary!= '--') AS subquery
2024-08-28 21:00:14,192 [ERROR] Running SQL query 'SELECT PERCENTILE(salary, 0.99) FROM nba_roster WHERE salary IS NOT NULL'
2024-08-28 21:00:14,194 [ERROR] Failed to run SQL query: SELECT PERCENTILE(salary, 0.99) FROM nba_roster WHERE salary IS NOT NULL
Saving results: 15 results [02:53, 11.59s/ results]
Total results: 15
Total size of eval dataset: 15
Percent Valid SQL Syntax: 73.33333333333333
Percent Correct SQL Query: 60.0

Memory Fine-tuning¶

We can improve accuracy further with memory fine-tuning.

Why fine-tune?¶

  • Fit more data than what fits in a prompt by storing information in the weights of the model
  • Learn from data, rather than just getting access to it
  • Deeper control of the LLM to achieve the behavior you want
  • No accuracy ceiling

Instruction tuning¶

  • Take a pretrained model (like Llama 3) and teach it to respond in a question-answer conversational style (like Llama 3 Instruct).
  • This type of fine-tuning yields a smaller loss than an untrained model (whose output probabilities are spread in a roughly normal, bell-curve distribution), but the loss never reaches zero.

Memory tuning¶

  • Take an instruction-tuned model and fine-tune it further on specific facts, driving the loss on those facts to 0.0 (i.e., the probability of the correct answer approaches 100% instead of being spread across a distribution); see the toy sketch below.
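To make the distinction concrete, here is a toy sketch (illustrative only, not Lamini's implementation) of the loss difference: instruction tuning drives the average cross-entropy down, while memory tuning keeps optimizing selected facts until their individual loss is ~0.

import math

# Toy illustration: cross-entropy loss of the correct next token under two
# output distributions. Instruction tuning leaves the model "confident-ish"
# (non-zero loss); memory tuning drives the loss on a stored fact to ~0,
# i.e. the correct token gets ~100% of the probability mass.
def cross_entropy(logits, target_idx):
    exps = [math.exp(l) for l in logits]
    return -math.log(exps[target_idx] / sum(exps))

print(cross_entropy([2.0, 1.0, 0.5], 0))   # ~0.46  (instruction-tuned: still sampling)
print(cross_entropy([12.0, 0.0, 0.0], 0))  # ~1.2e-05 (memory-tuned: fact locked in)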

Why not just use prompt engineering and RAG?¶

Myth 1: Prompting and RAG are just as good¶

  • Few-shot prompting examples will get the model to do a little better (as we saw above), but not perfect. The model is still fundamentally outputting a probability distribution and autocompleting the internet.
  • Fact recall (RAG) - RAG can help shift the probabilities, but you are still sampling from a distribution of similar, and potentially wrong, facts.

Myth 2: Fine-tuning is too expensive¶

  • Reality: Fine-tuning is cheaper than running huge prompts in RAG.
  • Parameter-efficient fine-tuning (PEFT) has reduced the cost of fine-tuning by 10,000x while maintaining equivalent accuracy.
  • Mixture of Memory Experts (MoME) turns any LLM into a million-way mixture of memory-expert adapters, reducing training time by 240x.

Caveat: You must implement fine-tuning correctly or else it can be extremely expensive.

Myth 3: Fine-tuning is too hard¶

  • Rolling your own fine-tuning can be hard.
  • It is not efficient: it takes more compute, can't parallelize efficiently across GPUs, crashes, and can't continuously fine-tune and serve inference together in production. The LLM doesn't improve, since it's hard to tune per use case, per model, and per dataset.
  • It is not easy to use and hard to scale (GPU and memory issues).
  • Integrating fine-tuning with inference is full of bugs (e.g., transferring model weights across different formats isn't bug-free).
  • It's easy to use the wrong tool for the job (e.g., instruction fine-tuning doesn't necessarily solve hallucinations, since it still optimizes for the average error over all samples, i.e., it doesn't bring the loss to 0).

Myth 4: I don't have enough data to fine-tune¶

  • You probably have more data than you think
  • Data quality > data quantity

PEFT and MoME¶

Lamini's tuning service handles these optimizations under the hood; in this notebook they show up only as the peft_args passed to the fine-tuning job below.

In [104]:
import lamini
In [105]:
import logging
import random
from typing import AsyncIterator, Iterator, Union
import sqlite3
import copy
from tqdm import tqdm

import pandas as pd
import jsonlines
from lamini.generation.base_prompt_object import PromptObject
from lamini.generation.generation_node import GenerationNode
from lamini.generation.generation_pipeline import GenerationPipeline

logger = logging.getLogger(__name__)
engine = sqlite3.connect("./nba_roster.db")
setup_logging()

class Args:
    def __init__(self,
                 max_examples=100,
                 sql_model_name="meta-llama/Meta-Llama-3-8B-Instruct",
                 gold_file_name="gold-test-set.jsonl",
                 training_file_name="generated_queries.jsonl",
                 num_to_generate=10):
        self.sql_model_name = sql_model_name
        self.max_examples = max_examples
        self.gold_file_name = gold_file_name
        self.training_file_name = training_file_name
        self.num_to_generate = num_to_generate

Generate Synthetic Data for Fine-tuning¶

From schema and example, generate new SQL query¶

In [106]:
system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
system += (
    "Consider a table called 'nba_roster' with the following schema (columns)\n"
)
system += get_schema()
system += "Consider the following questions, and queries used to answer them:\n"
In [107]:
system
Out[107]:
'You are an NBA analyst with 15 years of experience writing complex SQL queries.\nConsider a table called \'nba_roster\' with the following schema (columns)\n0|Team|TEXT eg. "Toronto Raptors"\n1|NAME|TEXT eg. "Otto Porter Jr."\n2|Jersey|TEXT eg. "0" and when null has a value "NA"\n3|POS|TEXT eg. "PF"\n4|AGE|INT eg. "22" in years\n5|HT|TEXT eg. `6\' 7"` or `6\' 10"`\n6|WT|TEXT eg. "232 lbs" \n7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"\n8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"\nConsider the following questions, and queries used to answer them:\n'
In [108]:
question = """What is the median weight in the NBA?"""
sql = "select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"

system += "Question: " + question + "\n"
system += "Query: " + sql + "\n"
In [109]:
print(system)
You are an NBA analyst with 15 years of experience writing complex SQL queries.
Consider a table called 'nba_roster' with the following schema (columns)
0|Team|TEXT eg. "Toronto Raptors"
1|NAME|TEXT eg. "Otto Porter Jr."
2|Jersey|TEXT eg. "0" and when null has a value "NA"
3|POS|TEXT eg. "PF"
4|AGE|INT eg. "22" in years
5|HT|TEXT eg. `6' 7"` or `6' 10"`
6|WT|TEXT eg. "232 lbs" 
7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"
8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"
Consider the following questions, and queries used to answer them:
Question: What is the median weight in the NBA?
Query: select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;

In [110]:
user = "Write two queries that are similar but different to those above.\n"
user += "Format the queries as a JSON object, i.e.\n"
user += '{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.\n'
In [111]:
print(user)
Write two queries that are similar but different to those above.
Format the queries as a JSON object, i.e.
{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.

In [112]:
user += "First write an explanation of why you decided to write these new queries in about 3-5 sentences, then write valid sqlite SQL queries for each of the 2 new queries. Make sure each query is complete and ends with a ;\n"
In [113]:
print(user)
Write two queries that are similar but different to those above.
Format the queries as a JSON object, i.e.
{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.
First write an explanation of why you decided to write these new queries in about 3-5 sentences, then write valid sqlite SQL queries for each of the 2 new queries. Make sure each query is complete and ends with a ;

In [114]:
prompt = make_llama_3_prompt(user, system)
In [115]:
llm = lamini.Lamini(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
result = llm.generate(prompt, output_type={ "explanation": "str", "sql_query_1" : "str", "sql_query_2": "str" }, max_new_tokens=200)
print(result)
{'explanation': "I decided to write these new queries to provide more insights into the NBA players' data. The first query calculates the average height of NBA players, while the second query finds the most common college attended by NBA players. These queries are similar to the original query in that they involve extracting specific", 'sql_query_1': "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,'')-1) AS INTEGER) as average_height FROM nba_roster", 'sql_query_2': 'SELECT COLLEGE, COUNT(*) as frequency FROM nba_roster GROUP BY COLLEGE ORDER BY frequency DESC LIMIT 1'}
In [116]:
def check_sql_query(query):
    try:
        pd.read_sql(query, con=engine)
    except Exception as e:
        logger.debug(f"Error in SQL query: {e}")
        return False

    logger.info(f"SQL query {query} is valid")

    return True
In [117]:
check_sql_query(result["sql_query_1"])
Out[117]:
False
In [118]:
check_sql_query(result["sql_query_2"])
Out[118]:
True
In [119]:
# Wrap it all up together in a class

class ModelStage(GenerationNode):
    def __init__(self):
        super().__init__(
            model_name="meta-llama/Meta-Llama-3-8B-Instruct",
            max_new_tokens=300,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        prompt = self.add_template(prompt)

        results = super().generate(
            prompt,
            output_type={
                "explanation": "str",
                "sql_query_1": "str",
                "sql_query_2": "str",
            },
            *args,
            **kwargs,
        )

        return results

    async def add_template(self, prompts):
        async for prompt in prompts:
            new_prompt = make_llama_3_prompt(**self.make_prompt(prompt.data))
            yield PromptObject(prompt=new_prompt, data=prompt.data)

    async def process_results(self, results):
        async for result in results:
            if result is None:
                continue

            if result.response is None:
                continue

            logger.info("=====================================")
            logger.info(f"Generated query 1: {result.response['sql_query_1']}")
            logger.info(f"Generated query 2: {result.response['sql_query_2']}")
            logger.info("=====================================")

            if self.check_sql_query(result.response["sql_query_1"]):
                new_result = PromptObject(prompt="", data=copy.deepcopy(result.data))
                new_result.data.generated_sql_query = result.response["sql_query_1"]
                yield new_result

            if self.check_sql_query(result.response["sql_query_2"]):
                new_result = PromptObject(prompt="", data=copy.deepcopy(result.data))
                new_result.data.generated_sql_query = result.response["sql_query_2"]
                yield new_result

    def make_prompt(self, data):
        system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
        system += (
            "Consider a table called 'nba_roster' with the following schema (columns)\n"
        )
        system += get_schema()
        system += "Consider the following questions, and queries used to answer them:\n"
        for example in data.sample:
            system += "Question: " + example["question"] + "\n"
            system += "Query: " + example["sql"] + "\n"

        # Important: generate relevant queries to your reference data
        # Ideally, close to those that are failing so you can show the model examples of how to do it right!
        user = "Write two queries that are similar but different to those above.\n"
        user += "Format the queries as a JSON object, i.e.\n"
        user += '{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.\n'

        # Next, use Chain of Thought (CoT) and prompt-engineering to help with generating SQL queries
        user += "First write an explanation of why you decided to write these new queries in about 3-5 sentences, then write valid sqlite SQL queries for each of the 2 new queries. Make sure each query is complete and ends with a ;\n"

        return {"system": system, "user": user}

    def check_sql_query(self, query):
        try:
            pd.read_sql(query, con=engine)
        except Exception as e:
            logger.debug(f"Error in SQL query: {e}")
            return False

        logger.info(f"SQL query {query} is valid")

        return True

Now that you have queries, generate questions for those queries¶

In [120]:
system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
system += (
    "Consider a table called 'nba_roster' with the following schema (columns)\n"
)
system += get_schema() + "\n"
system += "Queries, and questions that they are used to answer:\n"

example_question = """What is the median weight in the NBA?"""
example_sql = "select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"

system += "Question: " + example_question + "\n"
system += "Query: " + example_sql + "\n"
In [121]:
generated_sql = result["sql_query_2"]
In [122]:
user = "Now consider the following query.\n"
user += "Query: " + generated_sql + "\n"
user += "Write a question that this query could be used to answer.\n"
In [123]:
user += "Format your response as a JSON object, i.e.\n"
user += '{ "explanation": str, "question": str }.\n'

user += "First write an explanation in about 3-5 sentences, then write a one sentence question.\n"
In [124]:
prompt = make_llama_3_prompt(user, system)
result = llm.generate(prompt, output_type={ "explanation": "str", "question" : "str" }, max_new_tokens=200)
print(result)
{'explanation': 'This query groups the NBA players by their college and counts the number of players from each college. It then orders the results in descending order by frequency, which means it shows the college with the most players first. Finally, it limits the results to the top 1, which gives us the college with the highest number of players in the NBA. This query is useful for identifying the most popular college among NBA players', 'question': 'What is the most common college attended by NBA players'}
In [125]:
# Wrap it all up together in a class which generates a question
# given a query

class QuestionStage(GenerationNode):
    def __init__(self):
        super().__init__(
            model_name="meta-llama/Meta-Llama-3-8B-Instruct",
            max_new_tokens=150,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        results = super().generate(
            prompt,
            output_type={
                "explanation": "str",
                "question": "str",
            },
            *args,
            **kwargs,
        )
        return results

    def preprocess(self, obj: PromptObject):
        new_prompt = make_llama_3_prompt(**self.make_question_prompt(obj.data))
        obj.prompt = new_prompt

    def make_question_prompt(self, data):
        system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
        system += (
            "Consider a table called 'nba_roster' with the following schema (columns)\n"
        )
        system += get_schema() + "\n"
        system += "Queries, and questions that they are used to answer:\n"
        for example in data.sample:
            system += "Query: " + example["sql"] + "\n"
            system += "Question: " + example["question"] + "\n"

        user = "Now consider the following query.\n"
        user += "Query: " + data.generated_sql_query + "\n"
        user += "Write a question that this query could be used to answer.\n"

        # Using Chain of Thought (CoT) again
        # This time you can do it programmatically with function calling, so you can easily extract a question out of the JSON object
        user += "Format your response as a JSON object, i.e.\n"
        user += '{ "explanation": str, "question": str }.\n'

        user += "First write an explanation in about 3-5 sentences, then write a one sentence question.\n"

        return {"system": system, "user": user}
In [126]:
class QueryGenPipeline(GenerationPipeline):
    def __init__(self):
        super().__init__()
        self.model_stage = ModelStage()
        self.question_stage = QuestionStage()

    def forward(self, x):
        x = self.model_stage(x)
        x = self.question_stage(x)
        return x
In [127]:
async def run_query_gen_pipeline(gold_queries):
    return QueryGenPipeline().call(gold_queries)
In [128]:
# Generate N samples, for every example in the gold dataset

all_examples = []

async def load_gold_queries(args):
    path = f"data/{args.gold_file_name}"

    with jsonlines.open(path) as reader:
        global all_examples

        all_examples = [obj for obj in reader]

    sample_count = args.num_to_generate
    sample_size = 3

    random.seed(42)

    for i in range(sample_count):
        example_sample = ExampleSample(random.sample(all_examples, sample_size), i)
        yield PromptObject(prompt="", data=example_sample)


class ExampleSample:
    def __init__(self, sample, index):
        self.sample = sample
        self.index = index
In [129]:
async def save_generation_results(results, args):
    path = f"data/training_data/{args.training_file_name}"

    pbar = tqdm(desc="Saving results", unit=" results")
    with jsonlines.open(path, "w") as writer:

        async for result in results:
            writer.write(
                {
                    "question": result.response["question"],
                    "sql": result.data.generated_sql_query,
                }
            )
            pbar.update()

        for example in all_examples:
            writer.write(example)
            pbar.update()
In [130]:
args = Args()
gold_queries = load_gold_queries(args)
results = await run_query_gen_pipeline(gold_queries)
await save_generation_results(results, args)
Saving results: 30 results [01:05,  2.20s/ results]
In [ ]:
# Display the queries just generated
# !cat "data/training_data/generated_queries.jsonl"

Fine-tune Llama 3.1 with Lamini¶

In [36]:
# Utils
def get_default_finetune_args():
    return {
        "learning_rate": 3e-4,
        "max_steps": 3000,             # upper bound on optimization steps for the tuning job
        "early_stopping": False,
        "load_best_model_at_end": False,
        "use_cached_model": False,
        "peft_args": {"r_value": 32},  # presumed LoRA rank for parameter-efficient fine-tuning
    }
In [38]:
import logging
import os
from datetime import datetime
from pprint import pprint
from typing import AsyncIterator, Iterator, Union
import sqlite3
from tqdm import tqdm

import pandas as pd
import jsonlines
from lamini.generation.base_prompt_object import PromptObject
from lamini.generation.generation_node import GenerationNode
from lamini.generation.generation_pipeline import GenerationPipeline


logger = logging.getLogger(__name__)
engine = sqlite3.connect("./nba_roster.db")
setup_logging()

class Args:
    def __init__(self,
                 max_examples=100,
                 sql_model_name="meta-llama/Meta-Llama-3-8B-Instruct",
                 gold_file_name="gold-test-set.jsonl",
                 training_file_name="archive/generated_queries.jsonl",
                 num_to_generate=10):
        self.sql_model_name = sql_model_name
        self.max_examples = max_examples
        self.gold_file_name = gold_file_name
        self.training_file_name = training_file_name
        self.num_to_generate = num_to_generate
In [39]:
# Take questions and queries from the training_file and embed them in the prompt
# below to form the training data
def make_question(obj):
    system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
    system += "Consider the nba_roster table with the following schema:\n"
    system += get_schema() + "\n"
    system += (
        "Write a sqlite SQL query that would help you answer the following question:\n"
    )
    user = obj["question"]
    return {"system": system, "user": user}
In [40]:
args = Args()
llm = lamini.Lamini(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
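The get_dataset helper used in the next cell comes from the course utilities and isn't defined in this notebook. A minimal sketch of what it plausibly does, assuming llm.train() accepts a list of {"input", "output"} pairs and that each completion should end with the Llama 3 end-of-turn token:

# Hypothetical reconstruction of the course's get_dataset utility.
def get_dataset(args, make_question):
    dataset = []
    with jsonlines.open(f"data/training_data/{args.training_file_name}") as reader:
        for obj in reader:
            dataset.append(
                {
                    # Prompt the model the same way it will be prompted at inference time...
                    "input": make_llama_3_prompt(**make_question(obj)),
                    # ...and train it to emit the reference SQL, then stop.
                    "output": obj["sql"] + "<|eot_id|>",
                }
            )
    return dataset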
In [131]:
dataset = get_dataset(args, make_question)
In [132]:
finetune_args = get_default_finetune_args()
In [133]:
llm.train(
   data_or_dataset_id=dataset,
   finetune_args=finetune_args,
   is_public=True,  # For sharing
)
Uploading data....
Upload to blob completed for data.
Data pairs uploaded to blob.

Your dataset id is: c34fcc36ecfa672f4afb71e9f68bf9f1b78a7abf1764b6ce3457fa58a589d05e . Consider using this in the future to train using the same data. 
Eg: llm.train(data_or_dataset_id='c34fcc36ecfa672f4afb71e9f68bf9f1b78a7abf1764b6ce3457fa58a589d05e')
Tuning job submitted! Check status of job 10791 here: https://api.lamini.ai/train/10791
Out[133]:
{'job_id': 10791,
 'status': 'CREATED',
 'dataset_id': 'c34fcc36ecfa672f4afb71e9f68bf9f1b78a7abf1764b6ce3457fa58a589d05e'}
In [134]:
# Examine precomputed fine-tuning results
llm = lamini.Lamini(model_name="a5ebf1c4879569101f32444afae5adcafbfce9c5a6ed13035fd892147f7d59bc")
In [135]:
question = """Who is the highest paid NBA player?"""
system = f"""You are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:
{get_schema()}

Write a sqlite query to answer the following question. Follow instructions exactly"""
prompt = make_llama_3_prompt(question, system)
print("Question:\n", question)
Question:
 Who is the highest paid NBA player?
In [136]:
print("Answer:")
print(llm.generate(prompt, max_new_tokens=200))
Answer:
select salary, name from nba_roster where SALARY!= '--' ORDER BY CAST(REPLACE(REPLACE(SALARY, '$', ''), ',','') AS INTEGER) DESC LIMIT 1
In [137]:
query="SELECT salary, name FROM nba_roster WHERE salary != '--' ORDER BY CAST(REPLACE(REPLACE(salary, '$', ''), ',','') AS INTEGER) DESC LIMIT 1;"
df = pd.read_sql(query, con=engine)
print(df)
        SALARY           NAME
0  $53,838,416  Stephen Curry

Run eval over evaluation dataset¶

In [138]:
# Utilities repeated from the evaluation section above (Lesson 3 lab)
class QueryStage(GenerationNode):
    def __init__(self, model_name):
        super().__init__(
            model_name=model_name,
            max_new_tokens=300,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        results = super().generate(
            prompt,
            output_type={"sqlite_query": "str"},
            *args,
            **kwargs,
        )
        return results


    def postprocess(self, obj: PromptObject):
        # Run both the generated and reference (Gold Dataset) SQL queries
        # Assessing whether the SQL queries succeeded in hitting the database (not correctness yet!)

        query_succeeded = False

        try:
            logger.info(f"Running SQL query '{obj.response['sqlite_query']}'")
            obj.data["generated_query"] = obj.response["sqlite_query"]
            df = pd.read_sql(obj.response["sqlite_query"], con=engine)
            obj.data['df'] = df
            logger.info(f"Got data: {df}")
            query_succeeded = True

        except Exception as e:
            logger.error(
                f"Failed to run SQL query: {obj.response['sqlite_query']}"
            )

        logger.info(f"Running reference SQL query '{obj.data['sql']}'")
        df = pd.read_sql(obj.data["sql"], con=engine)
        logger.info(f"Got data: {df}")
        obj.data['reference_df'] = df

        logger.info(f"For question: {obj.data['question']}")
        logger.info(f"For query: {obj.response['sqlite_query']}")

        obj.data["query_succeeded"] = query_succeeded

    def preprocess(self, obj: PromptObject):
        new_prompt = make_llama_3_prompt(**self.make_prompt(obj.data))
        obj.prompt = new_prompt

    def make_prompt(self, data: dict):
        system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
        system += "Consider the nba_roster table with the following schema:\n"
        system += get_schema() + "\n"
        system += (
            "Write a sqlite SQL query that would help you answer the following question. Make sure each query ends with a semicolon:\n"
        )
        user = data["question"]
        return {
            "user": user,
            "system": system,
        }

class ScoreStage(GenerationNode):
    def __init__(self):
        super().__init__(
            model_name="meta-llama/Meta-Llama-3-8B-Instruct",
            max_new_tokens=150,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        results = super().generate(
            prompt,
            output_type={"explanation": "str", "similar": ["true", "false"]},
            *args,
            **kwargs,
        )
        return results

    def preprocess(self, obj: PromptObject):
        obj.prompt = make_llama_3_prompt(**self.make_prompt(obj))
        logger.info(f"Scoring Stage Prompt:\n{obj.prompt}")

    def postprocess(self, obj: PromptObject):
        obj.data['is_matching'] = self.is_matching(obj.data, obj.response)
        obj.data['explanation'] = obj.response["explanation"]
        obj.data['similar'] = obj.response["similar"] == "true"

    def is_matching(self, data, response):
        return (str(data.get('df',"None")).lower() == str(data['reference_df']).lower()
                or response['similar'] == "true")

    def make_prompt(self, obj: PromptObject):
        # Your evaluation model compares SQL output from the generated and reference SQL queries, using another LLM in the pipeline
        '''
        Note:
        Prompt tuning is important!
        A previous iteration of this scoring pipeline said `Compare the following two dataframes to see if they are identical`.
        That prompt turned out to be too stringent of criteria.
        '''
        system_prompt = "Compare the following two dataframes. They are similar if they are almost identical, or if they convey the same information about the nba_roster dataset"
        system_prompt += "Respond with valid JSON {'explanation' : str, 'similar' : bool}"
        user_prompt = (
            f"========== Dataframe 1 =========\n{str(obj.data.get('df','None')).lower()}\n\n"
        )
        user_prompt += (
            f"========== Dataframe 2 =========\n{str(obj.data['reference_df']).lower()}\n\n"
        )
        user_prompt += f"Can you tell me if these dataframes are similar?"
        return {
            "system": system_prompt,
            "user": user_prompt
        }

async def run_eval(dataset, args):

    results = await run_evaluation_pipeline(dataset, args)

    print("Total results:", len(results))

    return results


async def run_evaluation_pipeline(dataset, args):
    results = EvaluationPipeline(args).call(dataset)

    result_list = []

    pbar = tqdm(desc="Saving results", unit=" results")
    async for result in results:
        result_list.append(result)
        pbar.update()
    return result_list


class EvaluationPipeline(GenerationPipeline):
    def __init__(self, args):
        super().__init__()
        self.query_stage = QueryStage(args.sql_model_name)
        self.score_stage = ScoreStage()


    def forward(self, x):
        x = self.query_stage(x)
        x = self.score_stage(x)
        return x

def load_gold_dataset(args):
    path = f"data/{args.gold_file_name}"

    with jsonlines.open(path) as reader:
        for index, obj in enumerate(reversed(list(reader))):
            if index >= args.max_examples:
                break
            yield PromptObject(prompt="", data=obj)

def save_eval_results(results, args):
    base_path = "./data/results"
    now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    experiment_name = f"nba_sql_pipeline_{now}"
    experiment_dir = os.path.join(base_path, experiment_name)
    os.makedirs(experiment_dir)

    # Write args to file
    args_file_name = f"{experiment_dir}/args.txt"
    with open(args_file_name, "w") as writer:
        pprint(args.__dict__, writer)


    def is_correct(r):
        if (
            (r.data["query_succeeded"] and r.data['is_matching']) or
            r.data["generated_query"] == r.data['sql']
        ):
            return True
        return False

    # Write sql results and errors to file
    results_file_name = f"{experiment_dir}/sql_results.jsonl"
    with jsonlines.open(results_file_name, "w") as writer:
        for result in results:
            if not is_correct(result):
                continue
            writer.write(
                {
                    "question": result.data['question'],
                    "query": result.data["generated_query"],
                    "query_succeeded": result.data["query_succeeded"],
                    "reference_sql": result.data['sql'],
                    "df": str(result.data.get('df', 'None')),
                    "reference_df": str(result.data['reference_df']),
                    'is_matching': result.data['is_matching'],
                    'similar': result.data['similar'],
                }
            )

    results_file_name = f"{experiment_dir}/sql_errors.jsonl"
    with jsonlines.open(results_file_name, "w") as writer:
        for result in results:
            if is_correct(result):
                continue
            writer.write(
                {
                    "question": result.data['question'],
                    "query": result.data["generated_query"],
                    "query_succeeded": result.data["query_succeeded"],
                    "df": str(result.data.get('df', 'None')),
                    "reference_df": str(result.data['reference_df']),
                    'is_matching': result.data['is_matching'],
                    'similar': result.data['similar'],
                }
            )

    # Write statistics to file
    average_sql_succeeded = sum(
        [result.data["query_succeeded"] for result in results]
    ) / len(results)
    average_correct = sum(
        [result.data["query_succeeded"] and result.data['is_matching'] for result in results]
    ) / len(results)

    file_name = f"{experiment_dir}/summary.txt"
    with open(file_name, "w") as writer:
        print(f"Total size of eval dataset: {len(results)}", file=writer)
        print(f"Total size of eval dataset: {len(results)}")
        print(f"Percent Valid SQL Syntax: {average_sql_succeeded*100}", file=writer)
        print(f"Percent Valid SQL Syntax: {average_sql_succeeded*100}")
        print(f"Percent Correct SQL Query: {average_correct*100}", file=writer)
        print(f"Percent Correct SQL Query: {average_correct*100}")
In [139]:
args = Args(sql_model_name="a5ebf1c4879569101f32444afae5adcafbfce9c5a6ed13035fd892147f7d59bc")
dataset = load_gold_dataset(args)
results = await run_eval(dataset, args)
save_eval_results(results, args)
Saving results: 15 results [03:10, 12.69s/ results]
Total results: 15
Total size of eval dataset: 15
Percent Valid SQL Syntax: 100.0
Percent Correct SQL Query: 73.33333333333333

Fine-tune Llama 3.1 with Lamini (Iteration 2)¶

In [ ]:
!cat sql_errors_example.jsonl
In [141]:
!cat "data/training_data/archive/generated_queries.jsonl" | grep "75th percentile"
{"question": "What is the 75th percentile salary in the NBA", "sql": "SELECT HT, AVG(WT) as avg_weight FROM nba_roster WHERE HT IS NOT NULL AND WT IS NOT NULL GROUP BY HT ORDER BY avg_weight DESC LIMIT 1"}
In [142]:
!cat "data/training_data/archive/generated_queries_large.jsonl" | grep "75th percentile"
cat: data/training_data/archive/generated_queries_large.jsonl: No such file or directory

Filter the dataset¶

Manually create functions for filtering the dataset

In [147]:
question_set = set()
sql_set = set()

def is_not_valid_sql(question, sql):
    try:
        df = pd.read_sql(sql, con=engine)
        return False
    except Exception as e:
        return True

def has_null_in_sql_or_question(question, sql):
    return "null" in sql.lower() or "null" in question.lower()

def returns_empty_dataframe(question, sql):
    try:
        df = pd.read_sql(sql, con=engine)
        return "Empty" in str(df) or "None" in str(df)
    except Exception as e:
        return False

def uses_avg_on_text_column(question, sql):
    # AVG() applied directly to the raw HT or SALARY text columns is meaningless
    # without first parsing them into numbers, so filter those queries out.
    return "avg(ht)" in sql.lower() or "avg(salary" in sql.lower()

filter_conditions = [is_not_valid_sql, has_null_in_sql_or_question, returns_empty_dataframe, uses_avg_on_text_column]

def training_semicolon(sql):
    # Ensure every training query ends with a semicolon, so the tuned model
    # learns to terminate its output consistently.
    if sql.strip()[-1] != ";":
        return sql.strip() + ";"
    return sql

with jsonlines.open("data/training_data/archive/generated_queries_large.jsonl", "r") as reader:
    with jsonlines.open("data/training_data/generated_queries_large_filtered.jsonl", "w") as writer:
        for r in reader:
            if r["question"] in question_set or r["sql"] in sql_set:
                continue
            question_set.add(r["question"])
            sql_set.add(r["sql"])

            if any(c(r['question'], r['sql']) for c in filter_conditions):
                continue

            sql = training_semicolon(r['sql'])
            writer.write(
                {
                    "question": r["question"],
                    "sql": sql,
                }
            )

Continue iterating: improve the dataset, filter it, and fine-tune again. A sketch of one such iteration follows.
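Minimal sketch of one improvement iteration, reusing the helpers defined above (the actual run also included a manual cleaning pass, hence the `_cleaned` file name in the next cell):

# One iteration (sketch; assumes the classes and helpers above are in scope).
args = Args(training_file_name="generated_queries_large_filtered.jsonl")
dataset = get_dataset(args, make_question)   # build prompt -> SQL training pairs
llm = lamini.Lamini(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
llm.train(data_or_dataset_id=dataset, finetune_args=get_default_finetune_args())
# Then set Args(sql_model_name=<new model id>) and re-run run_eval/save_eval_results.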

In [149]:
# Model tuned on `archive/generated_queries_large_filtered_cleaned.jsonl`
llm = lamini.Lamini(model_name="63fd73a775daf24216b46c680a1e963a8d1e02b21bca43fcea6c26737d2e887e")
In [150]:
question = """What is the median age of the Chicago Bulls?"""
system = f"""You are an NBA analyst with 15 years of experience writing complex SQL queries. Consider the nba_roster table with the following schema:
{get_schema()}

Write a sqlite query to answer the following question. Follow instructions exactly"""
prompt = make_llama_3_prompt(question, system)
print("Question:\n", question)

print("Answer:")
sql = llm.generate(prompt, max_new_tokens=200)
print(sql)
Question:
 What is the median age of the Chicago Bulls?
Answer:
SELECT CAST(AGE AS INTEGER) AS percentile FROM nba_roster WHERE team='Chicago Bulls' ORDER BY percentile LIMIT 1 OFFSET (SELECT COUNT(*) FROM nba_roster WHERE team='Chicago Bulls')/2;
In [ ]:
df = pd.read_sql(sql, con=engine)
print(df)