Datasets
Learn how to use and customize premsql datasets for Text-to-SQL tasks, including working with available datasets, creating your own, and extending functionality.
Introduction
premsql provides a simple API to use various pre-processed datasets for Text-to-SQL tasks. Text-to-SQL is complex because it depends on databases and tables at every step. The premsql datasets streamline this by providing easy access to curated datasets and by enabling you to create your own datasets from private databases.
Available Datasets
Currently, the following datasets are readily available:
- BirdBench
- Spider
- Domains
- Gretel AI (synthetic)
Let’s see how to use these datasets with premsql.
Loading Datasets
To get started, you’ll use the Text2SQLDataset class from the premsql package. Here’s an example using the BirdBench dataset:
from premsql.datasets import Text2SQLDataset
from premsql.utils import print_data
bird_dataset = Text2SQLDataset(
    dataset_name='bird', split="train", force_download=False,
    dataset_folder="/path/to/your/data"  # change this to the path where you want to store the dataset
)
BirdBench is a popular and recent Text-to-SQL benchmark. It comes with around 9K training and 1,534 validation examples composed of complex, diverse queries over real-world databases. It also has a private test set, which is used to evaluate Text-to-SQL models and approaches from the community.
Accessing Raw Data
The dataset object provides two key methods:
- raw_dataset: Returns the raw data loaded from the JSON file.
- filters_available: Lists the filters available for the dataset.
Example:
raw_bird_training_dataset = bird_dataset.raw_dataset
print(raw_bird_training_dataset[0])
Available Filters
You can view the available filters for the dataset:
print(bird_dataset.filters_available)
Filters are keys in the dataset that can be used to narrow it down by specific criteria. For the BirdBench training split, we can only filter by db_id. PremSQL automatically detects which filters are available for a dataset and exposes them as a list.
Setting Up the Dataset
To load the processed dataset, use the setup_dataset
method. This will process and return the dataset object. The method has several optional parameters:
- filter_by: Filters the dataset based on the provided filter.
- num_rows: Limits the number of rows to return.
- num_fewshot: Defines how many few-shot examples to include in the prompt.
- model_name_or_path: Applies the model-specific prompt template and tokenizes the dataset if provided.
- prompt_template: Custom prompt template; defaults are available in the premsql prompts module.
Example:
# Set up the BirdBench dataset
bird_dataset = bird_dataset.setup_dataset(
    model_name_or_path="premai-io/prem-1B-SQL",
    num_fewshot=3,
    num_rows=3
)
print_data(bird_dataset[0])
The difference between loading and setting up a dataset: loading downloads the dataset or reads the raw data from a folder, while setting up processes the dataset (inserting schemas, building few-shot prompts, applying customizations, filtering, and so on) according to your requirements.
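To see the difference side by side, here is a quick sketch comparing a raw entry with its processed counterpart, reusing the objects created above:

# Raw entry: exactly what is stored in the dataset's JSON file
print(sorted(raw_bird_training_dataset[0].keys()))

# Processed entry: schema, few-shot examples and prompt template applied
print_data(bird_dataset[0])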
Preview Without Tokenization
Tokenization can be computationally heavy. To preview the dataset without tokenizing it, set model_name_or_path to None.
bird_dataset_without_tokenization = Text2SQLDataset(
    dataset_name='bird', split="train", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None, num_fewshot=3, num_rows=3
)
print_data(bird_dataset_without_tokenization[0])
Filtering Datasets
We now load the BirdBench validation dataset, which exposes two filters: db_id and difficulty. You can filter a dataset by any of its available criteria.
Example of filtering by difficulty:
bird_validation = Text2SQLDataset(
    dataset_name='bird', split="validation", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None,
    num_fewshot=3,
    num_rows=100,
    filter_by=("difficulty", "simple")
)
# Count the number of simple difficulty examples
simple_count = len([example for example in bird_validation if example["difficulty"] == "simple"])
print(simple_count)
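Filtering by db_id works the same way. Here is a sketch using california_schools, a BirdBench database that also appears in the example entry later on this page:

# Keep only the questions that target the california_schools database
school_questions = Text2SQLDataset(
    dataset_name='bird', split="validation", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None,
    filter_by=("db_id", "california_schools")
)
print(len(school_questions))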
Using Other Available Datasets
As mentioned earlier, PremSQL provides ready access to some of the best open-source Text-to-SQL datasets, such as Spider, Domains, and Gretel AI’s synthetic dataset. Let’s load them and see what they look like:
# Loading Spider Dataset
spider_dataset = Text2SQLDataset(
    dataset_name="spider",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

# Loading Domains Dataset
domains = Text2SQLDataset(
    dataset_name="domains",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

# Loading Gretel AI Dataset (Synthetic)
gretel_dataset = Text2SQLDataset(
    dataset_name="gretel",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

print_data(gretel_dataset[0]["raw"])
Dataset Merging
premsql datasets support merging, allowing you to pack multiple datasets together for combined use. This is useful when training on multiple datasets or validating on a combination of them.
# Merge datasets
merged_dataset = [*bird_dataset, *spider_dataset, *domains, *gretel_dataset]
print(f"Length of bird dataset: {len(bird_dataset)}")
print(f"Length of spider dataset: {len(spider_dataset)}")
print(f"Length of domains dataset: {len(domains)}")
print(f"Length of gretel dataset: {len(gretel_dataset)}")
print(f"Length of merged dataset: {len(merged_dataset)}")
print_data(merged_dataset[0])
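Since merging is plain Python list unpacking, the result is an ordinary list. For example, you can shuffle it so examples from the four sources are interleaved before training:

import random

random.seed(42)                 # make the shuffle reproducible
random.shuffle(merged_dataset)  # interleave examples from all four datasets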
Understanding Prompts in PremSQL
Here’s how a prompt looks when wrapped around a model’s prompt template:
print(gretel_dataset[0]["raw"]["prompt"])
These prompts are built dynamically as long as the raw dataset contains the required keys: db_id, question, and SQL (all case-sensitive), and the dataset follows PremSQL’s standard structure. The next section describes that standard and how to create your own dataset.
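As a quick sanity check, here is a small sketch that verifies a raw BirdBench entry (loaded earlier on this page) carries the required keys:

raw_entry = raw_bird_training_dataset[0]
# These three keys are what PremSQL needs to build a prompt dynamically
assert {"db_id", "question", "SQL"} <= set(raw_entry.keys())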
Creating Your Own Dataset
If you have annotations for your Text-to-SQL task on private data, you can create your own dataset. Tasks like Text-to-SQL carry database dependencies at every step: to build a dataset, you need a module that can connect to databases, fetch the schemas, and apply them to the prompt. That is why PremSQL uses one standard approach to providing inputs across the whole library. If you adopt it for your private databases, you can extend and customize it for your data and your use case. However, you strictly need to organize your files in the following structure:
├── databases
│   ├── california_schools
│   │   ├── california_schools.sqlite
├── train.json
├── validation.json   # Optional
Each sub-folder in databases should contain a .sqlite file matching the sub-folder name. Your train or validation JSON file should contain a list of dictionaries with these required keys:
- db_id: Corresponds to the folder and .sqlite file name.
- question: User question.
- SQL: The ground truth SQL query.
Example entry:
{
    "question_id": 0,
    "db_id": "california_schools",
    "question": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?",
    "evidence": "Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`",
    "SQL": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
    "difficulty": "simple"
}
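If your annotations live in memory or in another format, here is a minimal sketch of writing this layout to disk (all paths, table contents, and annotations below are hypothetical):

import json
import sqlite3
from pathlib import Path

root = Path("my_text2sql_data")  # hypothetical dataset root
db_dir = root / "databases" / "california_schools"
db_dir.mkdir(parents=True, exist_ok=True)

# The .sqlite file must share its name with the enclosing folder
with sqlite3.connect(db_dir / "california_schools.sqlite") as conn:
    conn.execute("CREATE TABLE IF NOT EXISTS frpm (name TEXT)")  # placeholder schema

entries = [
    {
        "db_id": "california_schools",               # must match the folder name
        "question": "How many schools are listed?",  # hypothetical annotation
        "SQL": "SELECT COUNT(*) FROM frpm",
    }
]
(root / "train.json").write_text(json.dumps(entries, indent=2))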
You can load your dataset using:
from premsql.datasets import StandardDataset

validation_dataset = StandardDataset(
    split="validation",  # can be train / validation / test, matching the name of the JSON file
    dataset_path="/path/to/the/folder/containing/the/files/",
    database_folder_name="databases",  # the name of the databases folder
    json_file_name="validation.json",
)

train_dataset = StandardDataset(
    split="train",  # can be train / validation / test, matching the name of the JSON file
    dataset_path="/path/to/the/folder/containing/the/files/",
    database_folder_name="databases",  # the name of the databases folder
    json_file_name="train.json",
)
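A StandardDataset can then be set up like any other PremSQL dataset. A sketch, assuming it exposes the same setup_dataset interface shown earlier:

train_dataset = train_dataset.setup_dataset(
    model_name_or_path=None,  # set to a model id to apply its prompt template and tokenize
    num_fewshot=3,
)
print_data(train_dataset[0])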
Advanced Customization for Text-to-SQL Datasets
So far, all the examples shown are tailored to .sqlite databases. However, you might encounter scenarios where:
1. You are working with different databases, such as PostgreSQL or other cloud-based database instances.
2. You want to implement custom logic before generating prompts.
3. You need to extend the utility of premsql to fit specific requirements.
This section will guide you through achieving these customizations.
Note: For scenario 1, you could also migrate a subset of your dataset to SQLite format. This allows you to annotate the dataset in SQLite and follow the standard approach to create a Text2SQL-compatible dataset for fine-tuning and inference.
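For instance, here is a minimal sketch of such a migration using pandas and SQLAlchemy (neither is required by premsql; the connection string and table names are hypothetical):

import sqlite3

import pandas as pd
from sqlalchemy import create_engine

pg_engine = create_engine("postgresql://user:password@localhost:5432/mydb")  # hypothetical DSN
tables_to_copy = ["frpm", "schools"]  # hypothetical tables you want to annotate against

with sqlite3.connect("databases/california_schools/california_schools.sqlite") as lite:
    for table in tables_to_copy:
        # Copy each table wholesale; filter upstream if you only need a subset
        pd.read_sql_table(table, pg_engine).to_sql(table, lite, index=False, if_exists="replace")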
If you require full customization beyond this, you can achieve it in two steps. A detailed tutorial will be available in future releases, but here’s a concise overview of the process:
Step 1: Define a Custom Dataset Instance
A dataset instance manages operations on individual data points. To create a custom dataset instance, extend the premsql.datasets.base.Text2SQLBaseInstance class. Below is a blueprint of how to define your custom instance:
from premsql.datasets.base import Text2SQLBaseInstance

class CustomDataInstance(Text2SQLBaseInstance):
    def __init__(self, dataset: list[dict]) -> None:
        super().__init__(dataset=dataset)

    def schema_prompt(self, db_path: str) -> str:
        # Define your custom schema prompt logic here.
        # Fetch the schema from your database and format it as needed.
        # For SQLite databases, you could use:
        #   SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'
        # For other databases, adapt the query accordingly.
        return "Custom schema prompt logic here."
In addition to schema_prompt, this class also includes other methods like additional_prompt and apply_prompt. These methods have default implementations that are database-agnostic, but you can customize them as needed.
Step 2: Define a Custom Dataset Class
Once you have your custom instance defined, you can create a custom dataset class by extending the premsql.datasets.base.Text2SQLBaseDataset class. This class handles the overall dataset setup and management. Here’s an example:
from premsql.datasets.base import Text2SQLBaseDataset
from typing import Optional, Union
from pathlib import Path
import logging

logger = logging.getLogger(__name__)

class CustomText2SQLDataset(Text2SQLBaseDataset):
    def __init__(
        self,
        split: str,
        dataset_folder: Optional[Union[str, Path]] = "./data",
        hf_token: Optional[str] = None,
        force_download: Optional[bool] = False,
    ):
        # Define your custom initialization logic here
        super().__init__(split=split, dataset_folder=dataset_folder)
        # Additional custom initialization if needed

    def setup_dataset(
        self,
        filter_by: Optional[tuple] = None,
        num_rows: Optional[int] = None,
        num_fewshot: Optional[int] = None,
        model_name_or_path: Optional[str] = None,
        prompt_template: Optional[str] = None,
    ):
        logger.info("Setting up the custom Text2SQL dataset.")
        # Custom setup logic can be added here if needed
        return super().setup_dataset(
            filter_by=filter_by,
            num_rows=num_rows,
            num_fewshot=num_fewshot,
            model_name_or_path=model_name_or_path,
            prompt_template=prompt_template,
        )
In the __init__ method, you can define any initialization logic specific to your dataset. Similarly, the setup_dataset method can be modified to include custom setup steps or logic.
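Once defined, the custom class is used just like the built-in datasets. A sketch:

custom_dataset = CustomText2SQLDataset(
    split="train",
    dataset_folder="./data",
).setup_dataset(
    num_fewshot=3,
    model_name_or_path=None,
)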
Summary
By following these steps, you can extend premsql to support different databases, implement custom preprocessing logic, or tailor the dataset setup to your specific needs. For a complete understanding, review the Text2SQLBaseDataset and Text2SQLBaseInstance classes in the premsql source code. We will be releasing a detailed tutorial soon on how to create datasets for different types of databases beyond SQLite.