Generate SQL from Text
Welcome to the Prem AI cookbook section. In this recipe, we will explore how to use the PremAI SDK and DSPy to generate SQL from text. This example shows one of the ways you can extend your development with Prem AI. So, without further ado, let’s get started. You can find the complete code here.
Objective
This tutorial aims to introduce you to DSPy and show you how to use it with the Prem SDK. Finally, in the later part of the tutorial, we will demonstrate how to use the Prem Platform to debug, see what DSPy has been optimizing, and how it can be further improved. In this tutorial, we are going to:
- Write simple prompts using `dspy.Signature` to give instructions to the LLM.
- Use `dspy.Module` to write a simple Text2SQL pipeline.
- Use `dspy.teleprompt` to automatically optimize the prompt for better results.
- Use `dspy.Evaluate` to evaluate the results.
Note
For those unfamiliar with DSPy, it is an LLM orchestration tool whose API is very similar to PyTorch. The main focus of DSPy is to help developers write clean and modular code, rather than writing very big prompts. You can learn more about it in their documentation or get a quick overview in our documentation. If you are new here, no problem. We have described what each part of the code does. All you need to be familiar with is Python.
Setting up the project
Let’s start by creating a virtual environment and installing dependencies.
Before getting started, make sure you have an account on the Prem AI Platform, a valid project ID, and an API key. Additionally, you will need at least one repository ID.
python3 -m venv .venv
source .venv/bin/activate
Up next, we need to install some dependencies. You can check out all of them in this requirements file. Once that’s done, let’s import the required packages.
import dspy
from dspy import PremAI
from dspy.evaluate import Evaluate
from dspy.datasets import DataLoader
After this, we define some constants and instantiate the DSPy PremAI LM. We will be using CodeLlama as our base LLM to generate SQL queries. Here is how our launchpad looks:
If you are unfamiliar with setting up a project on Prem, we recommend taking a quick look at this guide. It is super intuitive, and you can even get started for free.
Prem AI offers a variety of models (see the list here), so you can experiment with all of them. We conducted some initial experiments in the Prem Playground and found that CodeLlama performed consistently better than other models like Claude 3, GPT-4, and Mistral. Therefore, we are using Code Llama Instruct 70B by Meta AI for this cookbook experiment.
PREMAI_API_KEY = "G91-xxxxx-xxxxxx"
PROJECT_ID = "1234"
generation_kwargs = {
    "temperature": 0.1,
    "max_tokens": 1000
}
Additionally, we set the temperature to 0.1 and max_tokens to 1000. You can copy these configurations to your experiment to reproduce the results or start your own implementation.
Now, we instantiate our `lm` object and test it before moving forward.
lm = PremAI(
    project_id=PROJECT_ID, api_key=PREMAI_API_KEY, **generation_kwargs
)
dspy.configure(lm=lm)
lm("hello")
Loading Datasets
We start by loading a dataset. We will use the gretelai/synthetic_text_to_sql dataset for our example and split it into train, validation, and test splits. The code below shows how we load and split the dataset using DSPy.
data_loader = DataLoader()
# Load the dataset from huggingface
trainset = data_loader.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql",
    fields=("sql_prompt", "sql_context", "sql"),
    input_keys=("sql_prompt", "sql_context"),
    split="train"
)

testset = data_loader.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql",
    fields=("sql_prompt", "sql_context", "sql"),
    input_keys=("sql_prompt", "sql_context"),
    split="test"
)
We sample only a small amount of data because we will not be doing any kind of fine-tuning here. DSPy will optimize our prompt by tweaking it in several ways (for example, by adding optimal few-shot examples). DSPy does also provide options for fine-tuning the weights, but that is outside the scope of this tutorial.
trainset = data_loader.sample(dataset=trainset, n=100)
testset = data_loader.sample(dataset=testset, n=75)
_trainval = data_loader.train_test_split(
    dataset=trainset,
    test_size=0.25,
    random_state=1399
)
trainset, valset = _trainval["train"], _trainval["test"]
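As a quick optional sanity check, you can print the sizes of the three splits. Given the sampling above, we would expect 75 training, 25 validation, and 75 test examples:

```python
# Verify the split sizes after sampling and splitting
print(f"Train: {len(trainset)}, Val: {len(valset)}, Test: {len(testset)}")
```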
Let’s also take one sample, which will help us run initial checks on our implementation.
sample = data_loader.sample(dataset=trainset, n=1)[0]
for k, v in sample.items():
    print(f"\n{k.upper()}:\n")
    print(v)
Creating a DSPy Signature
Most LLM orchestration frameworks, like LangChain and LlamaIndex, advise you to prompt the language models explicitly. Eventually, we need to tweak or optimize the prompts several times to achieve better results. However, writing and managing large prompts can be pretty messy.
The primary purpose of DSPy is to shift the prompting process to a more programmatic paradigm. Signatures in DSPy allow you to specify what the language model’s input and output behaviour should be.
The field names inside a Signature have semantic significance. Each field that represents a role (for example, `question`, `answer`, or `sql_query`) defines that prompt variable and what the variable is about. This is explained further in the example below.
For this example, we will try to understand class-based DSPy Signatures. You can learn more about Signatures in the official documentation.
Class-based DSPy Signatures
In a class-based signature, you define a class where:

- The class docstring is used to express the nature of the overall task.
- You provide the input and output variables using `InputField` and `OutputField`, which also describe the nature of those variables.

Here is a simple example of how we define a simple class-based DSPy signature.
class Emotion(dspy.Signature):
    # Define your overall task description here
    """Classify emotion among sadness, joy, love, anger, fear, surprise."""

    sentence = dspy.InputField(
        desc="A sentence which needs to be classified"
    )
    sentiment = dspy.OutputField(
        desc="classify into exactly one class: sadness / joy / love / anger / fear / surprise. Do not write anything else, just the class. Sentiment:"
    )
classify = dspy.Predict(Emotion)
classify(sentence="The day is super gloomy today ughhh")
To inspect the actual prompt that was passed, you can use the `lm.inspect_history` method. Here is how it looks for the `Emotion` class.
lm.inspect_history(n=1)
The output here shows the actual prompt that DSPy passed from the signature object to our LLM.
Now, let’s make a signature for our SQL generator. In this signature, we will:
- Define the overall task in the class docstring.
- Provide input and output prompt variables.
- Describe the input and output variables so the model knows what is expected.
You can also think of it as writing a prompt very similar to the one shown below, but in a more programmable way:
prompt = """
Transform a natural language query into an SQL query.
Do not output anything other than SQL Query. You will be given the
sql_prompt which will tell what you need to do and sql_context
which will give some additional context to generate the right SQL. Here are the inputs:
Natural language query: {sql_prompt}
The context of the query: {sql_context}
Write the SQL here:
"""
Here is the class-based signature for Text to SQL generator:
class Text2SQLSignature(dspy.Signature):
    """Transform a natural language query into a SQL query.
    You will be given the sql_prompt which will tell what you need to do
    and a sql_context which will give some additional context to generate the right SQL.
    Only generate the SQL query nothing else. You should give one correct answer.
    starting and ending with ```
    """

    sql_prompt = dspy.InputField(desc="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query")
    sql = dspy.OutputField(desc="SQL Query")
Let’s generate a single sample with the signature we just wrote. Please note that the signature is just a blueprint of our goal; you are not required to tweak the prompt any further. It is DSPy’s job to optimize the process and construct the best prompt.
generate_sql_from_query = dspy.Predict(signature=Text2SQLSignature)
result = generate_sql_from_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)
Define a DSPy Module
Now we know how signatures work in DSPy. Multiple signatures or prompting techniques can each act as a function, and you can compose those techniques into a single module or program.
If you come from a deep learning background, you have probably heard of `torch.nn.Module`, which helps compose multiple layers into a single program. Similarly, in DSPy, various modules can be composed into more extensive modules (programs). Now, let’s create a simple module for our Text2SQL generator.
class Text2SQLProgram(dspy.Module):
    def __init__(self, signature: dspy.Signature):
        super().__init__()
        self.program = dspy.Predict(signature=signature)

    def forward(self, sql_prompt, sql_context):
        return self.program(
            sql_prompt=sql_prompt,
            sql_context=sql_context
        )
You might wonder why we are doing the same thing but now wrapping it inside another class. The reason is composition: you might have multiple such signatures, each with a different purpose, and you can chain them inside a single module to produce the final output. A full multi-signature pipeline is outside the scope of this tutorial, but the sketch below gives a flavour of what that could look like.
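Both `SQLReviewSignature` and `Text2SQLWithReview` in this sketch are hypothetical names invented purely for illustration; they are not used anywhere else in the tutorial:

```python
# Hypothetical second signature that reviews a draft SQL query
class SQLReviewSignature(dspy.Signature):
    """Review a draft SQL query against the question and context, and return a corrected query."""

    sql_prompt = dspy.InputField(desc="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query")
    draft_sql = dspy.InputField(desc="Draft SQL query to review")
    sql = dspy.OutputField(desc="Final SQL query")


class Text2SQLWithReview(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate = dspy.Predict(signature=Text2SQLSignature)
        self.review = dspy.Predict(signature=SQLReviewSignature)

    def forward(self, sql_prompt, sql_context):
        # Chain the two predictors: draft a query first, then review it
        draft = self.generate(sql_prompt=sql_prompt, sql_context=sql_context)
        return self.review(
            sql_prompt=sql_prompt,
            sql_context=sql_context,
            draft_sql=draft.sql
        )
```

With that aside, let’s run our single-signature Text2SQLProgram for a sanity check.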
text2sql = Text2SQLProgram(signature=Text2SQLSignature)
result = text2sql(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)
Awesome, now that things are working, let’s create simple metrics to quantify how well the SQL queries are being generated.
Metrics in DSPy
A metric is a function that quantifies how well a prediction matches the ground truth and how sound the predicted output is. A simple example is accuracy. So, a straightforward metric could be a direct string match: whether our predicted SQL string is identical to the actual SQL string. Let’s create a custom metric where we:
- First, normalize the SQL string (this means removing all the ``` and "\n" characters so the two strings are directly comparable).
- Compare both strings and return 0 if they do not match and 1 if they do.
- Run the above two steps for all validation examples to calculate an average score.
import re
import sqlparse
from tqdm import tqdm
def normalise_sql_string(sql_string):
    normalized = re.sub(r'```', '', sql_string).strip()
    return re.sub(r'\s+', ' ', normalized)


def compare_sqls(ground_truth, prediction, trace=None):
    ground_truth = ground_truth.sql
    prediction = prediction.sql

    ground_truth = normalise_sql_string(sql_string=ground_truth)
    prediction = normalise_sql_string(sql_string=prediction)

    ground_truth_parsed = sqlparse.format(
        ground_truth,
        reindent=True,
        keyword_case="upper"
    ).strip()
    prediction_parsed = sqlparse.format(
        prediction,
        reindent=True,
        keyword_case="upper"
    ).strip()
    return ground_truth_parsed == prediction_parsed
# Now use this function for the validation set
scores = []
for x in tqdm(valset, total=len(valset)):
    prediction = text2sql(
        sql_prompt=x.sql_prompt, sql_context=x.sql_context
    )
    ground_truth = x
    score = compare_sqls(ground_truth=ground_truth, prediction=prediction)
    scores.append(score)
print("Score: {score}/{total}".format(score=sum(scores), total=len(scores)))
Output:
100%|██████████| 25/25 [01:17<00:00, 3.09s/it]
Score: 9 / 25
It is important to note that metrics can sometimes be very unreliable. For example, a metric like `compare_sqls` compares two SQL statements through string manipulation, but there are better ways to judge the quality of the LLM’s output. A stronger metric would be to run the generated SQL statements against an actual database table and check whether the results match those of the ground truth. However, that is out of the scope of this tutorial.
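That said, here is a minimal sketch of what such an execution-based metric could look like. It assumes that `sql_context` contains SQLite-compatible `CREATE`/`INSERT` statements for each example, which may not hold for every row of the dataset, so treat it as illustrative rather than a drop-in replacement:

```python
import sqlite3


def execution_match(example, prediction, trace=None):
    """Hypothetical execution-based metric: build an in-memory database from
    the example's sql_context, run both the ground-truth and generated SQL,
    and compare the result sets."""
    conn = sqlite3.connect(":memory:")
    try:
        # Assumes sql_context holds valid SQLite DDL/DML for this example
        conn.executescript(example.sql_context)
        expected = conn.execute(example.sql).fetchall()
        generated = conn.execute(normalise_sql_string(prediction.sql)).fetchall()
        # Order-insensitive comparison of the two result sets
        return sorted(map(repr, expected)) == sorted(map(repr, generated))
    except sqlite3.Error:
        # If either query fails to execute, count the prediction as a miss
        return False
    finally:
        conn.close()
```

Since it has the same `(example, prediction, trace)` shape as `compare_sqls`, it could be passed to the evaluation loop in exactly the same way.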
Optimizing a DSPy program
An optimizer (previously known as a teleprompter) in DSPy optimizes the overall prompt workflow by tuning the prompt and/or the LM weights to maximize a target metric, such as accuracy.
Optimizers take three things as input:

- The DSPy program: this may be a single module (e.g., `dspy.Predict`) or a complex multi-module program. We have already defined it above.
- A metric function: this function evaluates the program and assigns it a score. We have already defined it above.
- A few training inputs: this may be a very small set (i.e., only 5 or 10 examples) and incomplete (only inputs to your program, without any labels).
You can learn more about DSPy Optimizers in their official documentation. For this program, we are NOT optimizing any weights; hence, we are not fine-tuning here.
To keep things simple, we use the `LabeledFewShot` optimizer. It simply constructs few-shot examples (which we call demos) from the provided labelled input and output data points. Let’s see it in action.
from dspy.teleprompt import LabeledFewShot
# k = number of few shot examples
optimizer = LabeledFewShot(k=4)
optimized_text2sql = optimizer.compile(
    student=text2sql,
    trainset=trainset
)
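Before evaluating, it can be instructive to peek at what the optimizer actually attached to our program. Assuming the standard DSPy module API, where each predictor stores its few-shot examples in a `demos` attribute, something like this should show the selected demos:

```python
# Each predictor inside the compiled program now carries the k selected demos
for name, predictor in optimized_text2sql.named_predictors():
    print(f"Predictor: {name} | number of demos: {len(predictor.demos)}")
    for demo in predictor.demos:
        print(demo.sql_prompt)
```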
Awesome. Finally, we use `dspy.Evaluate` to evaluate our overall system. Here is how we do it:
evaluate = Evaluate(
    devset=valset,
    metric=compare_sqls,
    num_threads=3,
    display_progress=True,
    display_table=0
)
evaluate(optimized_text2sql)
Output:
Average Metric: 11 / 25 (44.0): 100%|██████████| 25/25 [00:22<00:00, 1.09it/s]
There are several other types of optimizers, and all of them follow the same flow, which allows you to quickly plug each one in and out to find the best fit. You can learn more about DSPy optimizers here.
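For example, swapping in a different optimizer typically only changes a couple of lines. Here is a sketch using `BootstrapFewShot`, another built-in DSPy optimizer that bootstraps demos with the help of your metric (the exact keyword arguments may vary between DSPy versions):

```python
from dspy.teleprompt import BootstrapFewShot

# Bootstraps demonstrations by running the program on the trainset
# and keeping only the ones that pass the compare_sqls metric
bootstrap_optimizer = BootstrapFewShot(
    metric=compare_sqls,
    max_bootstrapped_demos=4
)
bootstrapped_text2sql = bootstrap_optimizer.compile(
    student=text2sql,
    trainset=trainset
)
evaluate(bootstrapped_text2sql)
```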
Conclusion
In this tutorial, we created a straightforward example of how to generate SQL from text. Forty-four percent is fair accuracy for a starting point. In this example, we used CodeLlama; however, we can try different models and see which one works best for us.
Additionally, when you use Prem, you can see all the traces and runs of the model captured in the Traces section. Inside Traces, you can monitor each LLM run and see how DSPy optimized your prompt. Here is an example from our case:
As you can see from the picture above, DSPy added some in-context examples to optimize the initial prompt. You can do the same for different settings (such as a different LLM or a different optimizer), and all of those runs will be captured here.
Additionally, if you are on a Business plan, you can use these traces to fine-tune your model so it performs even better.
Congratulations on making it this far. Check out our cookbook for the full code, and sign up for Prem AI if you haven’t yet.