Error-handling prompts are crucial for refining model performance, especially in complex tasks like Text-to-SQL generation. Each prompt pairs a previous attempt with the error it produced, which gives the model additional context about what went wrong. By training on these prompts, the model learns to self-correct during inference, which improves the quality of its output.

Requirements

To create an error-handling dataset or prompt, the data entity must include:

  • db_id: Identifier for the database.
  • db_path: Path to the database.
  • prompt: The original prompt used to generate the initial results.
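For illustration, a data entity can be pictured as a plain Python dictionary with these three keys. The values below and the helper missing_fields are hypothetical and not part of PremSQL. The helper only checks that each required key is present and non-empty.

```python
# Hypothetical sketch: the keys mirror the requirements above,
# but the helper and the example values are not part of PremSQL.
from typing import List

REQUIRED_FIELDS = ("db_id", "db_path", "prompt")


def missing_fields(entity: dict) -> List[str]:
    """Return the required keys that are absent or empty in `entity`."""
    return [key for key in REQUIRED_FIELDS if not entity.get(key)]


# Hypothetical example entity.
entity = {
    "db_id": "example_db",
    "db_path": "/path/to/example_db.sqlite",
    "prompt": "-- schema and question used to generate the initial SQL",
}
```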

In this guide, we’ll use the BirdBench dataset as an example, but any text-to-SQL dataset compatible with PremSQL will work. We’ll also define the generators (models) and executors (evaluation components) that drive this pipeline.

References

If you are unfamiliar with generators, executors, or datasets, see the corresponding PremSQL tutorials.

Overview

For Training

  1. Start with an existing dataset that is compatible with PremSQL datasets.
  2. Run a generator on the dataset. The executor runs each generated query and collects the errors from incorrect generations.
  3. Combine the existing response, the initial prompt, and the error into new data points using an error-handling prompt.
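The training steps above can be sketched in plain Python. This is a toy illustration, not PremSQL's implementation: fake_generate stands in for a model, an in-memory SQLite database stands in for the entity's database, and all function names are hypothetical. For simplicity, the sketch treats only execution failures as errors.

```python
import sqlite3
from typing import Dict, List, Optional


def fake_generate(prompt: str) -> str:
    # Hypothetical stand-in for a model: it misspells a column for one question.
    if "names" in prompt:
        return "SELECT nme FROM users"
    return "SELECT COUNT(*) FROM users"


def execute(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Run `sql`; return None on success or the error message on failure."""
    try:
        conn.execute(sql).fetchall()
        return None
    except sqlite3.Error as exc:
        return str(exc)


def collect_error_points(conn: sqlite3.Connection, dataset: List[Dict]) -> List[Dict]:
    error_points = []
    for entity in dataset:
        sql = fake_generate(entity["prompt"])
        error = execute(conn, sql)
        if error is not None:
            # Keep the original entity plus the failed SQL and its error message.
            error_points.append({**entity, "sql": sql, "error_msg": error})
    return error_points


# Toy database and dataset; a real run would open each entity's db_path.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
dataset = [
    {"db_id": "toy", "db_path": ":memory:", "prompt": "List all user names."},
    {"db_id": "toy", "db_path": ":memory:", "prompt": "How many users are there?"},
]
error_points = collect_error_points(conn, dataset)
```

Only the failing question becomes an error data point, which is why the error dataset is usually smaller than the training set.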

For Inference

PremSQL handles error correction automatically in its simple pipeline and in execution-guided decoding, so no manual intervention is needed at inference time.
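Conceptually, execution-guided correction is a retry loop: generate SQL, execute it, and if execution fails, prompt the model again with the failed query and its error. The sketch below uses a stubbed model and hypothetical names (correct_with_retries, stub_generate, max_retries). It illustrates the idea and is not how PremSQL implements it internally.

```python
import sqlite3
from typing import Callable, Optional


def run(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Execute `sql`; return None on success or the error message on failure."""
    try:
        conn.execute(sql).fetchall()
        return None
    except sqlite3.Error as exc:
        return str(exc)


def correct_with_retries(generate: Callable[[str], str], conn: sqlite3.Connection,
                         prompt: str, max_retries: int = 2) -> str:
    sql = generate(prompt)
    for _ in range(max_retries):
        error = run(conn, sql)
        if error is None:
            break
        # Re-prompt with the failed SQL and its error appended.
        prompt = f"{prompt}\n# Generated SQL: {sql}\n## Error Message\n{error}\n# SQL:"
        sql = generate(prompt)
    # If the retries run out, the last attempt is returned even if it still fails.
    return sql


def stub_generate(prompt: str) -> str:
    # Hypothetical model: it fixes the column name once it sees an error message.
    return "SELECT name FROM users" if "## Error Message" in prompt else "SELECT nme FROM users"


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
final_sql = correct_with_retries(stub_generate, conn, "List all user names.")
```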

Defining Generators and Executors

Let’s start by defining the necessary components:

from premsql.datasets.error_dataset import ErrorDatasetGenerator
from premsql.generators.huggingface import Text2SQLGeneratorHF
from premsql.executors.from_langchain import ExecutorUsingLangChain


generator = Text2SQLGeneratorHF(
    model_or_name_or_path="premai-io/prem-1B-SQL",
    experiment_name="testing_error_gen",
    type="train",  # use "train", not "test": this generator produces training data
    device="cuda:0"
)

executor = ExecutorUsingLangChain()

Setting Up the Dataset

We define our existing training dataset using BirdBench. For demonstration, we’re using num_rows=10, but we recommend using the full dataset when training. Because only failed generations become error data points, the error dataset is typically smaller than the training set when the model is already well trained.

from premsql.datasets import BirdDataset

bird_train = BirdDataset(
    split="train",
    dataset_folder="/path/to/dataset"
).setup_dataset(num_rows=10)

Creating the Error Handling Dataset

Creating the error-handling dataset is simple: feed in the generator and executor of your choice.

error_dataset_gen = ErrorDatasetGenerator(
    generator=generator,
    executor=executor
)

Generating and Saving Error Dataset

Run the error-handling generation and save the results. The results are saved to ./experiments/train/<generator-experiment-name>/error_dataset.json. Use force=True to regenerate the dataset even if a saved copy already exists.

error_dataset = error_dataset_gen.generate_and_save(
    datasets=bird_train,
    force=True
)
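To inspect the saved file outside PremSQL, a minimal loader only needs the path described above. This sketch assumes nothing about the file's schema beyond it being valid JSON. error_dataset_path and load_error_dataset are hypothetical helpers, and from_existing remains the supported way to load the dataset.

```python
import json
from pathlib import Path
from typing import Any


def error_dataset_path(experiment_name: str, root: str = "./experiments") -> Path:
    # Follows the documented layout: <root>/train/<experiment-name>/error_dataset.json
    return Path(root) / "train" / experiment_name / "error_dataset.json"


def load_error_dataset(experiment_name: str, root: str = "./experiments") -> Any:
    path = error_dataset_path(experiment_name, root)
    if not path.exists():
        raise FileNotFoundError(f"No error dataset at {path}; run generate_and_save first.")
    with path.open(encoding="utf-8") as f:
        return json.load(f)


# Usage, matching the experiment name used above:
# records = load_error_dataset("testing_error_gen")
```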

Error Prompt Template

Each data point in the error dataset will include an error-handling prompt, which looks like this:

ERROR_HANDLING_PROMPT = """
{existing_prompt}

# Generated SQL: {sql}

## Error Message

{error_msg}

Carefully review the original question and error message, then rewrite the SQL query to address the identified issues. 
Ensure your corrected query uses correct column names, 
follows proper SQL syntax, and accurately answers the original question 
without introducing new errors.

# SQL: 
"""

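The {existing_prompt}, {sql}, and {error_msg} placeholders suggest the template is filled with Python's str.format. Assuming that, building one error-handling prompt looks like the sketch below. The template is an abbreviated copy of the one above, and build_error_prompt is a hypothetical helper.

```python
# Abbreviated copy of the default template (instructions omitted for brevity).
TEMPLATE = """
{existing_prompt}

# Generated SQL: {sql}

## Error Message

{error_msg}

# SQL:
"""


def build_error_prompt(existing_prompt: str, sql: str, error_msg: str,
                       template: str = TEMPLATE) -> str:
    # Assumption: placeholders are filled by name with str.format.
    return template.format(existing_prompt=existing_prompt, sql=sql, error_msg=error_msg)


prompt = build_error_prompt(
    existing_prompt="-- Question: list all user names",
    sql="SELECT nme FROM users",
    error_msg="no such column: nme",
)
```

Because str.format parses only the template, braces inside the inserted prompt, SQL, or error message are safe. Literal braces in a custom template itself would need to be doubled ({{ and }}).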
To use a custom prompt, provide it like this:

error_dataset = error_dataset_gen.generate_and_save(
    datasets=bird_train,
    force=True,
    prompt_template=your_prompt_template
)

Ensure your custom template includes at least the placeholders used by the default template: {existing_prompt}, {sql}, and {error_msg}.
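Before passing a custom template, you can check that it contains the required placeholders with the standard library's string.Formatter. The helper missing_placeholders is a hypothetical sketch, not a PremSQL function.

```python
from string import Formatter
from typing import Set

# Placeholders used by the default error-handling template.
REQUIRED_KEYS = {"existing_prompt", "sql", "error_msg"}


def missing_placeholders(template: str) -> Set[str]:
    """Return the required placeholders that `template` does not reference."""
    # Formatter().parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for trailing literal text, so it is filtered out.
    found = {field for _, field, _, _ in Formatter().parse(template) if field}
    return REQUIRED_KEYS - found


your_prompt_template = "{existing_prompt}\n-- Failed SQL: {sql}\n-- Error: {error_msg}\n# SQL:"
```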

Loading an Existing Error Handling Dataset

You don’t need to rerun the error-handling pipeline once the dataset is generated. Use from_existing to load it during fine-tuning.

existing_error_dataset = ErrorDatasetGenerator.from_existing(
    experiment_name="testing_error_gen"
)

print(existing_error_dataset[0])

Tokenizing the Dataset

We also support tokenizing the error dataset while loading it, which is particularly useful when fine-tuning with Hugging Face Transformers.

existing_error_dataset = ErrorDatasetGenerator.from_existing(
    experiment_name="testing_error_gen",
    tokenize_model_name_or_path="premai-io/prem-1B-SQL",
)

print(existing_error_dataset[0])
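This tutorial does not spell out what the tokenized records contain. A common convention when fine-tuning causal language models with Hugging Face Transformers is to concatenate the prompt and the target SQL, then set the prompt's label positions to -100 so the loss ignores them. The sketch below illustrates that convention with a toy whitespace tokenizer. Whether PremSQL masks labels this way, and which SQL serves as the target, are assumptions, and every name here is hypothetical.

```python
from typing import Dict, List

# PyTorch's CrossEntropyLoss ignores targets equal to -100 by default.
IGNORE_INDEX = -100


def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    # Toy whitespace tokenizer: assigns the next free id to unseen tokens.
    return [vocab.setdefault(tok, len(vocab)) for tok in text.split()]


def build_example(prompt: str, target_sql: str, vocab: Dict[str, int]) -> Dict[str, List[int]]:
    prompt_ids = toy_tokenize(prompt, vocab)
    target_ids = toy_tokenize(target_sql, vocab)
    input_ids = prompt_ids + target_ids
    # Only the target SQL tokens contribute to the loss.
    labels = [IGNORE_INDEX] * len(prompt_ids) + target_ids
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids), "labels": labels}


vocab: Dict[str, int] = {}
example = build_example("fix: SELECT nme", "SELECT name", vocab)
```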

Example with SQLite Executor

Let’s look at another example that uses a different executor, SQLiteExecutor.

from premsql.executors import SQLiteExecutor

generator = Text2SQLGeneratorHF(
    model_or_name_or_path="premai-io/prem-1B-SQL",
    experiment_name="testing_error_sqlite",
    type="train",
    device="cuda:0"
)
sqlite_executor = SQLiteExecutor()

error_dataset_gen = ErrorDatasetGenerator(
    generator=generator,
    executor=sqlite_executor
)
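The internals of SQLiteExecutor are not covered here. As a rough sketch of what an SQLite-backed executor needs to do with a data entity's db_path, the stdlib-only function below runs a query and returns either the rows or the error message. Opening the file read-only is a design choice of this sketch, since it keeps generated SQL from modifying the benchmark database. execute_readonly is a hypothetical name.

```python
import sqlite3
from pathlib import Path
from typing import List, Optional, Tuple


def execute_readonly(db_path: str, sql: str) -> Tuple[Optional[List[tuple]], Optional[str]]:
    """Run `sql` against the SQLite file at `db_path`, opened read-only.

    Returns (rows, None) on success or (None, error_message) on failure.
    """
    # Build a file: URI so SQLite honours mode=ro; as_uri() needs an absolute path.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    try:
        conn = sqlite3.connect(uri, uri=True)
    except sqlite3.Error as exc:
        return None, str(exc)
    try:
        return conn.execute(sql).fetchall(), None
    except sqlite3.Error as exc:
        return None, str(exc)
    finally:
        conn.close()
```

With this design, an INSERT or UPDATE produced by the model fails with a read-only error instead of altering the database, and that error text can itself become an error data point.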

Generating a Tokenized Dataset On-the-Fly

You can generate and save a tokenized dataset directly during the error generation process.

error_dataset_from_sqlite = error_dataset_gen.generate_and_save(
    datasets=bird_train,
    tokenize=True,
    force=True
)

print(error_dataset_from_sqlite[0])

Conclusion

That’s how you generate an error-handling dataset. It is compatible with the other PremSQL datasets, so you can mix them together for fine-tuning. Training on error-handling prompts teaches the model to correct its own mistakes, which can make it more robust and reliable in production.