Introduction

premsql provides a simple API for working with various pre-processed datasets for Text-to-SQL tasks. Text-to-SQL is complex because it depends on databases and table schemas at every step. The premsql datasets streamline this by providing easy access to standard datasets and by letting you create your own datasets from private databases.

Available Datasets

Currently, the following datasets are readily available:

  1. BirdBench Dataset
  2. Spider Unified Datasets
  3. Domains Dataset
  4. Gretel AI Dataset

Let’s see how to use these datasets with premsql.

Loading Datasets

To get started, you’ll use the Text2SQLDataset class from the premsql package. Here’s an example using the BirdBench dataset:

from premsql.datasets import Text2SQLDataset
from premsql.utils import print_data

bird_dataset = Text2SQLDataset(
    dataset_name='bird', split="train", force_download=False,
    dataset_folder="/path/to/your/data" # change this to the path where you want to store the dataset
)

BirdBench is a popular, recent Text-to-SQL benchmark. It contains around 9K training examples and 1,534 validation examples, composed of complex and diverse queries over real-world databases. It also ships a private test set, which is used to evaluate the Text-to-SQL models and approaches submitted by the community.

Accessing Raw Data

The dataset object provides two key methods:

  1. raw_dataset: Returns the raw data loaded from the JSON file as a list of dictionaries.
  2. filters_available: Lists the filters available for the dataset.

Example:

raw_bird_training_dataset = bird_dataset.raw_dataset
print(raw_bird_training_dataset[0])

Available Filters

You can view the available filters for the dataset:

print(bird_dataset.filters_available)

Filters are keys in the dataset that can be used to narrow it down by specific criteria. For the BirdBench training split, the only available filter is db_id. PremSQL automatically detects which filters a dataset supports and exposes them as a list.

Setting Up the Dataset

To load the processed dataset, use the setup_dataset method. This will process and return the dataset object. The method has several optional parameters:

  • filter_by: Filters the dataset based on the provided filter.
  • num_rows: Limits the number of rows to return.
  • num_fewshot: Defines how many few-shot examples to include in the prompt.
  • model_name_or_path: Applies the model-specific prompt template and tokenizes the dataset if provided.
  • prompt_template: Custom prompt template; defaults are available in the premsql prompts module.

Example:

# Set up the BirdBench dataset
bird_dataset = bird_dataset.setup_dataset(
    model_name_or_path="premai-io/prem-1B-SQL",
    num_fewshot=3,
    num_rows=3
)

print_data(bird_dataset[0])

The difference between loading and setting up a dataset: loading downloads the dataset (or reads the raw files from a folder), while setting up processes it (inserting schemas, adding few-shot prompts, applying customizations, filtering, and so on) according to your requirements.
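
To make the distinction concrete, here is a minimal sketch using the same API from above:

# Loading: download the dataset or read the raw files from disk
dataset = Text2SQLDataset(
    dataset_name="bird", split="train", force_download=False,
    dataset_folder="../data"
)

# Setting up: process the raw data into prompts (and tokenize if a model is given)
processed = dataset.setup_dataset(model_name_or_path=None, num_rows=5)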

Preview Without Tokenization

Tokenization can sometimes be computationally heavy. To preview the dataset without tokenizing, set model_name_or_path to None.

bird_dataset_without_tokenization = Text2SQLDataset(
    dataset_name='bird', split="train", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None, num_fewshot=3, num_rows=3
)

print_data(bird_dataset_without_tokenization[0])

Filtering Datasets

We now load the BirdBench validation dataset. This split has two filters available: db_id and difficulty. You can filter the dataset by either criterion.

Example of filtering by difficulty:

bird_validation = Text2SQLDataset(
    dataset_name='bird', split="validation", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None,
    num_fewshot=3,
    num_rows=100,
    filter_by=("difficulty", "simple")
)

# Count the number of simple difficulty examples
simple_count = len([example for example in bird_validation if example["difficulty"] == "simple"])
print(simple_count)
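
You can filter by db_id in the same way; the database id below (california_schools, one of the BirdBench databases) is just illustrative:

bird_california = Text2SQLDataset(
    dataset_name='bird', split="validation", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None,
    filter_by=("db_id", "california_schools")  # illustrative db_id
)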

Using Other Available Datasets

As mentioned earlier, PremSQL provides access to some of the best readily available open-source datasets:

  1. BirdBench Dataset
  2. Spider Unified Datasets
  3. Domains Dataset
  4. Gretel AI Dataset

Let’s load the other datasets and see what they look like:

# Loading Spider Dataset
spider_dataset = Text2SQLDataset(
    dataset_name="spider",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

# Loading Domains Dataset
domains = Text2SQLDataset(
    dataset_name="domains",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

# Loading Gretel AI Dataset (Synthetic)
gretel_dataset = Text2SQLDataset(
    dataset_name="gretel",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

print_data(gretel_dataset[0]["raw"])

Dataset Merging

premsql datasets support merging, allowing you to pack multiple datasets together for combined use. This is useful when training on multiple datasets, or when you want to validate on a combination of datasets.

# Merge datasets
merged_dataset = [*bird_dataset, *spider_dataset, *domains, *gretel_dataset]

print(f"Length of bird dataset: {len(bird_dataset)}")
print(f"Length of spider dataset: {len(spider_dataset)}")
print(f"Length of domains dataset: {len(domains)}")
print(f"Length of gretel dataset: {len(gretel_dataset)}")
print(f"Length of merged dataset: {len(merged_dataset)}")

print_data(merged_dataset[0])

Understanding Prompts in PremSQL

Here’s what a prompt looks like once it is wrapped in a model’s prompt template:

print(gretel_dataset[0]["raw"]["prompt"])

These prompts are built dynamically, as long as the raw dataset contains the required keys: db_id, question, and SQL (all case-sensitive) and is structured in the standard format PremSQL follows. Let’s look more closely at that standard and at how you can create your own dataset.
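
Before relying on prompt generation, you can check that your raw entries expose these keys; for example, using the raw BirdBench data loaded earlier:

# The required keys are case-sensitive: 'db_id', 'question', 'SQL'
print(raw_bird_training_dataset[0].keys())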

Creating Your Own Dataset

If you have annotations for a Text-to-SQL task on your private data, you can create your own dataset. Tasks like Text-to-SQL depend on databases at every step: to build a dataset, you need a module that can connect to the databases, fetch their schemas, and insert them into the prompt. That’s why PremSQL uses a standard input format across the whole library. You can extend and customize it for your own data and use case, but you strictly need to organize your files in the following structure:

├── databases
│   ├── california_schools
│   │   ├── california_schools.sqlite
├── train.json
├── validation.json  # Optional

Each sub-folder in databases should contain a .sqlite file matching the sub-folder name. Your train or validation JSON file should contain a list of dictionaries with these required keys:

  • db_id: Corresponds to the folder and .sqlite file name.
  • question: User question.
  • SQL: The ground truth SQL query.

Example entry:

{
  "question_id": 0,
  "db_id": "california_schools",
  "question": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?",
  "evidence": "Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`",
  "SQL": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
  "difficulty": "simple"
}
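
A minimal script to lay out this structure could look like the following sketch (the folder name, table schema, and SQL shown are illustrative):

import json
from pathlib import Path

root = Path("my_text2sql_dataset")
(root / "databases" / "california_schools").mkdir(parents=True, exist_ok=True)

# Copy california_schools.sqlite into the sub-folder above, then write the annotations:
entries = [
    {
        "db_id": "california_schools",
        "question": "How many schools are listed in Alameda County?",
        "SQL": "SELECT COUNT(*) FROM schools WHERE County = 'Alameda'",  # hypothetical schema
    }
]
with open(root / "train.json", "w") as f:
    json.dump(entries, f, indent=2)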

You can load your dataset using:

from premsql.datasets import StandardDataset

validation_dataset = StandardDataset(
    split="validation",  # train / validation / test, matching the name of your JSON file
    dataset_path="/path/to/the/folder/containing/the/files/",
    database_folder_name="databases",  # must match the name of your databases folder
    json_file_name="validation.json",
)

train_dataset = StandardDataset(
    split="train",  # train / validation / test, matching the name of your JSON file
    dataset_path="/path/to/the/folder/containing/the/files/",
    database_folder_name="databases",  # must match the name of your databases folder
    json_file_name="train.json",
)
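
Once loaded, these datasets behave like the built-in ones. Assuming StandardDataset inherits setup_dataset from the base dataset class, you can process them the same way:

train_dataset = train_dataset.setup_dataset(
    model_name_or_path=None,
    num_fewshot=2,
)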

Advanced Customization for Text-to-SQL Datasets

So far, all the examples shown are tailored to .sqlite databases. However, you might encounter scenarios where:

  1. You are working with different databases, such as PostgreSQL or other cloud-based database instances.
  2. You want to implement custom logic before generating prompts.
  3. You need to extend the utility of premsql to fit specific requirements.

This section will guide you through achieving these customizations.

Note: For scenario 1, you could also migrate a subset of your dataset to SQLite format. This allows you to annotate the dataset in SQLite and follow the standard approach to create a Text2SQL-compatible dataset for fine-tuning and inference.
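
If you go that route, one lightweight migration sketch (assuming pandas and SQLAlchemy are installed; the connection string and table names are hypothetical):

import sqlite3
import pandas as pd
from sqlalchemy import create_engine

# Hypothetical source database; replace with your own credentials
pg_engine = create_engine("postgresql://user:password@host:5432/mydb")

# The databases/mydb folder must already exist (see the structure above)
with sqlite3.connect("databases/mydb/mydb.sqlite") as conn:
    for table in ["customers", "orders"]:  # the tables you want to annotate
        df = pd.read_sql_table(table, pg_engine)
        df.to_sql(table, conn, index=False, if_exists="replace")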

If you require full customization beyond this, you can achieve it in two steps. A detailed tutorial will be available in future releases, but here’s a concise overview of the process:

Step 1: Define a Custom Dataset Instance

A dataset instance manages operations on individual data points. To create a custom dataset instance, extend the premsql.datasets.base.Text2SQLBaseInstance class. Below is a blueprint of how to define your custom instance:

from premsql.datasets.base import Text2SQLBaseInstance

class CustomDataInstance(Text2SQLBaseInstance):
    def __init__(self, dataset: list[dict]) -> None:
        super().__init__(dataset=dataset)

    def schema_prompt(self, db_path: str) -> str:
        # Define your custom schema prompt logic here.
        # Fetch the schema from your database and format it as needed.
        # For SQLite databases, you could use:
        # SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'
        # For other databases, adapt the query accordingly.
        return "Custom schema prompt logic here."

In addition to schema_prompt, this class also includes other methods like additional_prompt and apply_prompt. These methods have default implementations that are database-agnostic, but you can customize them as needed.

Step 2: Define a Custom Dataset Class

Once you have your custom instance defined, you can create a custom dataset class by extending the premsql.datasets.base.Text2SQLBaseDataset class. This class handles the overall dataset setup and management. Here’s an example:

from premsql.datasets.base import Text2SQLBaseDataset
from typing import Optional, Union
from pathlib import Path
import logging

logger = logging.getLogger(__name__)

class CustomText2SQLDataset(Text2SQLBaseDataset):
    def __init__(
        self,
        split: str,
        dataset_folder: Optional[Union[str, Path]] = "./data",
        hf_token: Optional[str] = None,
        force_download: Optional[bool] = False,
    ):
        # Define your custom initialization logic here
        super().__init__(split=split, dataset_folder=dataset_folder)
        # Additional custom initialization if needed

    def setup_dataset(
        self,
        filter_by: Optional[tuple] = None,
        num_rows: Optional[int] = None,
        num_fewshot: Optional[int] = None,
        model_name_or_path: Optional[str] = None,
        prompt_template: Optional[str] = None,
    ):
        logger.info("Setting up the custom Text2SQL dataset.")
        # Custom setup logic can be added here if needed
        return super().setup_dataset(
            filter_by=filter_by,
            num_rows=num_rows,
            num_fewshot=num_fewshot,
            model_name_or_path=model_name_or_path,
            prompt_template=prompt_template,
        )

In the __init__ method, you can define any specific initialization logic needed for your dataset. Similarly, the setup_dataset method can be modified to include any custom setup steps or logic.
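
Using the custom class then mirrors the built-in datasets (the paths and parameters below are illustrative):

custom_dataset = CustomText2SQLDataset(
    split="train",
    dataset_folder="./data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path=None,
)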

Summary

By following these steps, you can extend premsql to support different databases, implement custom preprocessing logic, or tailor the dataset setup to your specific needs. For a complete understanding, review the Text2SQLBaseDataset and Text2SQLBaseInstance classes in the premsql source code. We will be releasing a detailed tutorial soon on how to create datasets for different types of databases beyond SQLite.