Error Handling
Build error-handling prompts and datasets: the prompts support error-free inference, and the datasets fine-tune models to self-correct.
Error-handling prompts are crucial for refining model performance, especially in complex tasks like Text-to-SQL generation. The prompts help the model learn how to handle errors by providing additional context and guidance based on past mistakes. By training on these prompts, the model can self-correct during inference, improving the quality of its output.
Requirements
To create an error-handling dataset or prompt, the data entity must include:
- db_id: Identifier for the database.
- db_path: Path to the database.
- prompt: The original prompt used to generate the initial results.
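As a minimal illustration, one data entity can be pictured as a plain Python dictionary with these three keys. The values and the `validate_entity` helper below are hypothetical placeholders, not premsql code; real dataset entries may carry additional fields.

```python
# Hypothetical data entity; the values are placeholders, not real BirdBench data.
REQUIRED_KEYS = ("db_id", "db_path", "prompt")

entity = {
    "db_id": "my_database",
    "db_path": "/path/to/my_database.sqlite",
    "prompt": "-- schema ...\n# Question: How many users are there?\n# SQL:",
}


def validate_entity(item: dict) -> list:
    """Return the required keys missing from a data entity (hypothetical helper)."""
    return [key for key in REQUIRED_KEYS if key not in item]


print(validate_entity(entity))           # []
print(validate_entity({"db_id": "x"}))   # ['db_path', 'prompt']
```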
In this guide, we use the BirdBench dataset as an example, but any text-to-SQL compatible dataset works. We also define the generators (models that produce SQL) and executors (components that run the generated SQL against a database) that play crucial roles in this pipeline.
References
If you are unfamiliar with generators, executors, evaluators, or datasets, check out the following tutorials:
Datasets Tutorial
Learn how to utilize pre-processed datasets for Text-to-SQL tasks. This tutorial covers dataset evaluation, fine-tuning, and creation of custom datasets.
Generators Tutorial
A step-by-step guide on how to use Text-to-SQL generators to create SQL queries from user input and specified database sources.
Executors Tutorial
Learn how to connect to databases and execute SQL queries generated by models. This tutorial covers execution, troubleshooting, and best practices.
Evaluators Tutorial
Understand the evaluation of Text-to-SQL models with metrics like execution accuracy and Valid Efficiency Score (VES).
Overview
For Training
- Start with an existing dataset compatible with premsql datasets.
- Run a generator over the dataset. The executor runs each generated SQL query and collects the errors from the ones that fail.
- Use the existing response, initial prompt, and the error to create new data points using an error-handling prompt.
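The three training steps can be sketched end to end with Python's standard-library sqlite3 module. This is a simplified illustration under stated assumptions, not premsql's implementation: `fake_generate` stands in for a generator, `execute_sql` for an executor, `build_error_dataset` is a hypothetical name, and the prompt is a shortened version of the error-handling template.

```python
import sqlite3

# Shortened version of the error-handling template shown later in this guide.
ERROR_PROMPT = (
    "{existing_prompt}\n# Generated SQL: {sql}\n"
    "## Error Message\n{error_msg}\n# SQL:\n"
)


def execute_sql(conn: sqlite3.Connection, sql: str):
    """Stand-in executor: return None on success, else the database error message."""
    try:
        conn.execute(sql).fetchall()
        return None
    except sqlite3.Error as exc:
        return str(exc)


def fake_generate(prompt: str) -> str:
    """Stand-in generator that always misspells the `name` column."""
    return "SELECT nme FROM users;"


def build_error_dataset(dataset, conn):
    error_data = []
    for item in dataset:
        sql = fake_generate(item["prompt"])  # step 2: generate SQL
        error = execute_sql(conn, sql)       # step 2: execute and collect the error
        if error is None:
            continue                         # successful queries yield no data point
        # Step 3: wrap the original prompt, failed SQL, and error into a new prompt.
        # Other fields on the item are carried over unchanged.
        error_data.append({
            **item,
            "prompt": ERROR_PROMPT.format(
                existing_prompt=item["prompt"], sql=sql, error_msg=error
            ),
        })
    return error_data


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
dataset = [{"db_id": "demo", "db_path": ":memory:",
            "prompt": "# Question: List all user names.\n# SQL:"}]
error_data = build_error_dataset(dataset, conn)
print(error_data[0]["prompt"])
```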
For Inference
PremSQL automatically handles error correction in simple-pipeline and execution-guided decoding, so manual intervention is not needed.
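For intuition only, the idea behind execution-guided self-correction can be sketched as a retry loop: generate SQL, execute it, and on failure regenerate with the error message in the prompt. This is a hypothetical illustration, not premsql's implementation; `generate_with_correction`, its callables, and the `max_retries` parameter are assumptions.

```python
from typing import Callable, Optional


def generate_with_correction(
    prompt: str,
    generate: Callable[[str], str],
    execute: Callable[[str], Optional[str]],
    max_retries: int = 3,
) -> str:
    """Retry generation with error feedback, up to max_retries corrections."""
    sql = generate(prompt)
    for _ in range(max_retries):
        error = execute(sql)  # None means the query ran without errors
        if error is None:
            break
        # Show the model its failed query and the error, then ask again.
        retry_prompt = (
            f"{prompt}\n# Generated SQL: {sql}\n"
            f"## Error Message\n{error}\n# SQL:\n"
        )
        sql = generate(retry_prompt)
    return sql


# Toy stand-ins: the "model" fixes its query once it sees an error message.
def toy_generate(p: str) -> str:
    return "SELECT name FROM users;" if "## Error Message" in p else "SELECT nme FROM users;"


def toy_execute(sql: str) -> Optional[str]:
    return None if sql == "SELECT name FROM users;" else "no such column: nme"


result = generate_with_correction(
    "# Question: List all user names.\n# SQL:", toy_generate, toy_execute
)
print(result)  # SELECT name FROM users;
```

If every attempt fails, this sketch returns the last generated query without executing it again; a production pipeline might instead surface the final error.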
Defining Generators and Executors
Let’s start by defining the necessary components:
from premsql.datasets.error_dataset import ErrorDatasetGenerator
from premsql.generators.huggingface import Text2SQLGeneratorHF
from premsql.executors.from_langchain import ExecutorUsingLangChain
generator = Text2SQLGeneratorHF(
model_or_name_or_path="premai-io/prem-1B-SQL",
experiment_name="testing_error_gen",
type="train", # use "train", not "test": this output is used for training
device="cuda:0"
)
executor = ExecutorUsingLangChain()
Setting Up the Dataset
We define our existing training dataset using BirdBench. For demonstration, we use num_rows=10, but we recommend using the full dataset for training. Because only failed generations produce error data points, the error dataset is typically smaller than the training set, especially with a well-trained model.
from premsql.datasets import BirdDataset
bird_train = BirdDataset(
split="train",
dataset_folder="/path/to/dataset"
).setup_dataset(num_rows=10)
Creating the Error Handling Dataset
Creating the error-handling dataset is simple: feed in the generator and executor of your choice.
error_dataset_gen = ErrorDatasetGenerator(
generator=generator,
executor=executor
)
Generating and Saving Error Dataset
Run the error-handling generation and save the results. The results are saved to ./experiments/train/<generator-experiment-name>/error_dataset.json. Set force=True to regenerate the dataset if one has already been saved.
error_dataset = error_dataset_gen.generate_and_save(
datasets=bird_train,
force=True
)
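The force flag suggests a familiar cache-or-regenerate pattern. The sketch below illustrates that pattern with the output path described above; `load_or_generate` and `expensive_generation` are hypothetical names, and premsql's actual caching logic may differ.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, List


def load_or_generate(
    experiment_name: str,
    generate_fn: Callable[[], List[dict]],
    force: bool = False,
    root: str = "./experiments",
) -> List[dict]:
    """Reuse a saved error dataset unless force=True (hypothetical helper)."""
    path = Path(root) / "train" / experiment_name / "error_dataset.json"
    if path.exists() and not force:
        return json.loads(path.read_text())  # cached: skip the expensive generation
    data = generate_fn()
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, indent=2))
    return data


calls = []


def expensive_generation() -> List[dict]:
    calls.append(1)  # count how often generation actually runs
    return [{"db_id": "demo"}]


with tempfile.TemporaryDirectory() as tmp:
    load_or_generate("testing_error_gen", expensive_generation, root=tmp)  # generates
    cached = load_or_generate("testing_error_gen", expensive_generation, root=tmp)  # loads file
    load_or_generate("testing_error_gen", expensive_generation, force=True, root=tmp)  # regenerates

print(len(calls))  # 2
```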
Error Prompt Template
Each data point in the error dataset will include an error-handling prompt, which looks like this:
ERROR_HANDLING_PROMPT = """
{existing_prompt}
# Generated SQL: {sql}
## Error Message
{error_msg}
Carefully review the original question and error message, then rewrite the SQL query to address the identified issues.
Ensure your corrected query uses correct column names,
follows proper SQL syntax, and accurately answers the original question
without introducing new errors.
# SQL:
"""
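Assuming the template is filled with Python's str.format (its {placeholder} syntax suggests so), one error-handling prompt can be rendered as follows. The question, query, and error message are made-up examples.

```python
ERROR_HANDLING_PROMPT = """
{existing_prompt}
# Generated SQL: {sql}
## Error Message
{error_msg}
Carefully review the original question and error message, then rewrite the SQL query to address the identified issues.
Ensure your corrected query uses correct column names,
follows proper SQL syntax, and accurately answers the original question
without introducing new errors.
# SQL:
"""

# Made-up values for one failed generation.
filled = ERROR_HANDLING_PROMPT.format(
    existing_prompt="# Question: How many users are there?",
    sql="SELECT COUNT(*) FROM user;",
    error_msg="no such table: user",
)
print(filled)
```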
To use a custom prompt, provide it like this:
error_dataset = error_dataset_gen.generate_and_save(
datasets=bird_train,
force=True,
prompt_template=your_prompt_template
)
Ensure your custom template includes at least the placeholders used by the default template: {existing_prompt}, {sql}, and {error_msg}.
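One way to check a custom template before passing it in is to list its placeholders with the standard library's string.Formatter. `missing_fields` is a hypothetical helper, not a premsql function, and it assumes the required placeholders are exactly the three in the default template.

```python
from string import Formatter

REQUIRED_FIELDS = {"existing_prompt", "sql", "error_msg"}


def missing_fields(template: str) -> set:
    """Return default-template placeholders that a custom template lacks."""
    # Formatter().parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for trailing literal text, so filter it out.
    found = {field for _, field, _, _ in Formatter().parse(template) if field}
    return REQUIRED_FIELDS - found


your_prompt_template = (
    "{existing_prompt}\n"
    "-- Previous attempt: {sql}\n"
    "-- Database error: {error_msg}\n"
    "-- Fix the query above.\n"
    "# SQL:\n"
)
print(missing_fields(your_prompt_template))                  # set()
print(sorted(missing_fields("{existing_prompt}\n# SQL:")))   # ['error_msg', 'sql']
```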
Loading an Existing Error Handling Dataset
You don’t need to rerun the error-handling pipeline once the dataset has been generated. Use from_existing to load it during fine-tuning.
existing_error_dataset = ErrorDatasetGenerator.from_existing(
experiment_name="testing_error_gen"
)
print(existing_error_dataset[0])
Tokenizing the Dataset
We also support tokenizing the error dataset during loading, which is particularly useful when integrating with Hugging Face Transformers.
existing_error_dataset = ErrorDatasetGenerator.from_existing(
experiment_name="testing_error_gen",
tokenize_model_name_or_path="premai-io/prem-1B-SQL",
)
print(existing_error_dataset[0])
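What tokenization produces depends on premsql's implementation, which this guide does not spell out. For orientation, a common convention in Hugging Face-style causal language model fine-tuning is to concatenate the tokenized prompt and target SQL and set the prompt positions in the labels to -100, so the loss only covers the SQL the model should write. The sketch below shows that general convention with made-up token IDs; `build_training_example` is a hypothetical helper, and this is an assumption about the technique, not a description of what from_existing returns. With a real tokenizer, the IDs would come from tokenizer(text)["input_ids"].

```python
from typing import Dict, List

IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips this label value by default


def build_training_example(
    prompt_ids: List[int], target_ids: List[int], eos_id: int
) -> Dict[str, List[int]]:
    """Concatenate prompt and target; mask prompt positions out of the loss."""
    input_ids = prompt_ids + target_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + target_ids + [eos_id]
    return {"input_ids": input_ids, "labels": labels}


# Toy token IDs: 3 prompt tokens, 2 target (corrected SQL) tokens, EOS id 0.
example = build_training_example([11, 12, 13], [21, 22], eos_id=0)
print(example["input_ids"])  # [11, 12, 13, 21, 22, 0]
print(example["labels"])     # [-100, -100, -100, 21, 22, 0]
```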
Example with SQLite Executor
Here is another example, this time using a different executor: SQLiteExecutor.
from premsql.executors import SQLiteExecutor
generator = Text2SQLGeneratorHF(
model_or_name_or_path="premai-io/prem-1B-SQL",
experiment_name="testing_error_sqlite",
type="train",
device="cuda:0"
)
sqlite_executor = SQLiteExecutor()
error_dataset_gen = ErrorDatasetGenerator(
generator=generator,
executor=sqlite_executor
)
Generating a Tokenized Dataset On-the-Fly
You can generate and save a tokenized dataset directly during the error generation process.
error_dataset_from_sqlite = error_dataset_gen.generate_and_save(
datasets=bird_train,
tokenize=True,
force=True
)
print(error_dataset_from_sqlite[0])
Conclusion
That’s how you generate an error-handling dataset. The resulting dataset is compatible with other premsql datasets, so you can mix them and use them together for fine-tuning. Training on error-handling prompts can improve your model’s ability to recover from its own errors, leading to more robust and reliable behavior in production.