Learn how to use and customize premsql datasets for Text-to-SQL tasks, including working with available datasets, creating your own, and extending functionalities.
premsql provides a simple API for working with various pre-processed Text-to-SQL datasets. Text-to-SQL is complex because every example depends on underlying databases and tables. The premsql datasets streamline this by providing easy access to standard datasets and by letting you create your own datasets from private databases.
To get started, you’ll use the Text2SQLDataset class from the premsql package. Here’s an example using the BirdBench dataset:
```python
from premsql.datasets import Text2SQLDataset
from premsql.utils import print_data

bird_dataset = Text2SQLDataset(
    dataset_name='bird',
    split="train",
    force_download=False,
    dataset_folder="/path/to/your/data"  # change this to the path where you want to store the dataset
)
```
An entry of the raw dataset looks like this:

```
{
  'db_id': 'movie_platform',
  'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
  'evidence': 'released in the year 1945 refers to movie_release_year = 1945;',
  'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'
}
```
You can view the available filters for the dataset:
```python
print(bird_dataset.filters_available)
```
```
['db_id']
```
Filters are keys in the dataset that can be used to select examples matching specific criteria. For the BirdBench training split, we can only filter by db_id. PremSQL automatically detects which filters are available for a dataset and exposes them as a list.
To process the dataset, use the setup_dataset method, which processes and returns the dataset object. The method has several optional parameters:
- `filter_by`: Filters the dataset based on the provided filter.
- `num_rows`: Limits the number of rows to return.
- `num_fewshot`: Defines how many few-shot examples to include in the prompt.
- `model_name_or_path`: If provided, applies the model-specific prompt template and tokenizes the dataset.
- `prompt_template`: Custom prompt template; defaults are available in the premsql prompts module.
Example:
```python
# Set up the BirdBench dataset
bird_dataset = bird_dataset.setup_dataset(
    model_name_or_path="premai-io/prem-1B-SQL",
    num_fewshot=3,
    num_rows=3
)
print_data(bird_dataset[0])
```
```
{
  'input_ids': tensor([32013, 32013, 2042, ..., 207, 16, 32021]),
  'labels': tensor([-100, -100, -100, ..., 207, 16, 32021]),
  'raw': {
    'db_id': 'movie_platform',
    'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
    'evidence': 'released in the year 1945 refers to movie_release_year = 1945;',
    'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1',
    'db_path': '/root/anindya/text2sql/data/bird/train/train_databases/movie_platform/movie_platform.sqlite',
    'prompt': '<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, develo....tles released in year 1945. Sort the listing by the descending order of movie popularity.\n\n# SQL: \n\n'
  }
}
```
The difference between loading and setting up a dataset: loading downloads the dataset or reads the raw data from a folder, while setting up processes it (inserting schemas, adding few-shot prompts, applying customizations, filtering, etc.) according to your requirements.
```
2024-09-09 04:27:12,344 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-09 04:27:12,345 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|██████████| 3/3 [00:00<00:00, 1908.24it/s]
```
When model_name_or_path is not provided, setup returns the raw entries with the default prompt template applied instead of tokenized tensors:

```
{
  'db_id': 'movie_platform',
  'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
  'evidence': 'released in the year 1945 refers to movie_release_year = 1945;',
  'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1',
  'db_path': '/root/anindya/text2sql/data/bird/train/train_databases/movie_platform/movie_platform.sqlite',
  'prompt': '\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write....itles released in year 1945. Sort the listing by the descending order of movie popularity.\n\n# SQL: \n'
}
```
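The training split only exposes db_id, so you can, for example, restrict setup to a single database. A minimal sketch, reusing the same API shown above (the value movie_platform comes from the example entries in this section):

```python
# Keep only questions asked over the movie_platform database
movie_subset = Text2SQLDataset(
    dataset_name='bird',
    split="train",
    force_download=False,
    dataset_folder="/path/to/your/data"
).setup_dataset(
    model_name_or_path=None,
    filter_by=("db_id", "movie_platform")
)
```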
Next, load the BirdBench validation dataset. This split exposes two filters: db_id and difficulty, and you can filter by either. Here is an example of filtering by difficulty:
```python
bird_validation = Text2SQLDataset(
    dataset_name='bird',
    split="validation",
    force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None,
    num_fewshot=3,
    num_rows=100,
    filter_by=("difficulty", "simple")
)

# Count the number of simple difficulty examples
simple_count = len([example for example in bird_validation if example["difficulty"] == "simple"])
print(simple_count)
```
```
2024-09-09 04:27:20,270 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-09 04:27:20,271 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|██████████| 100/100 [00:00<00:00, 2101.37it/s]
100
```
Other premsql datasets, such as spider, domains, and gretel, follow the same structure. For example, here is an entry from the gretel dataset after setup:

```
{
  'id': 5097,
  'question': 'What is the total volume of timber sold by each salesperson, sorted by salesperson?',
  'schema': "CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');",
  'SQL': 'SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;',
  'context': "CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');",
  'task_type': 'analytics and reporting',
  'complexity': 'single join',
  'db_id': 'forestry',
  'db_path': None,
  'prompt': '<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, develo....tion: What is the total volume of timber sold by each salesperson, sorted by salesperson?\n\n# SQL: \n\n'
}
```
premsql datasets support merging, allowing you to pack multiple datasets together for combined use. This is useful when training on multiple datasets or when you want to run validation on a combination of datasets, as shown in the sketch and merge example below.
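The merge below assumes each dataset was loaded and set up the same way as BirdBench. A minimal sketch; the dataset_name values ('spider', 'domains', 'gretel') are assumed to be accepted by Text2SQLDataset, matching the dataset names used in this section:

```python
# Assumed: these dataset names follow the same Text2SQLDataset interface as 'bird'
spider_dataset = Text2SQLDataset(
    dataset_name='spider', split="train", dataset_folder="/path/to/your/data"
).setup_dataset(model_name_or_path="premai-io/prem-1B-SQL", num_rows=3)

domains = Text2SQLDataset(
    dataset_name='domains', split="train", dataset_folder="/path/to/your/data"
).setup_dataset(model_name_or_path="premai-io/prem-1B-SQL", num_rows=3)

gretel_dataset = Text2SQLDataset(
    dataset_name='gretel', split="train", dataset_folder="/path/to/your/data"
).setup_dataset(model_name_or_path="premai-io/prem-1B-SQL", num_rows=3)
```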
```python
# Merge datasets
merged_dataset = [*bird_dataset, *spider_dataset, *domains, *gretel_dataset]

print(f"Length of bird dataset: {len(bird_dataset)}")
print(f"Length of spider dataset: {len(spider_dataset)}")
print(f"Length of domains dataset: {len(domains)}")
print(f"Length of gretel dataset: {len(gretel_dataset)}")
print(f"Length of merged dataset: {len(merged_dataset)}")

print_data(merged_dataset[0])
```
```
Length of bird dataset: 3
Length of spider dataset: 3
Length of domains dataset: 3
Length of gretel dataset: 3
Length of merged dataset: 12
```
```
{
  'input_ids': tensor([32013, 32013, 2042, ..., 207, 16, 32021]),
  'labels': tensor([-100, -100, -100, ..., 207, 16, 32021]),
  'raw': {
    'db_id': 'movie_platform',
    'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
    'evidence': 'released in the year 1945 refers to movie_release_year = 1945;',
    'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1',
    'db_path': '/root/anindya/text2sql/data/bird/train/train_databases/movie_platform/movie_platform.sqlite',
    'prompt': '<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, develo....tles released in year 1945. Sort the listing by the descending order of movie popularity.\n\n# SQL: \n\n'
  }
}
```
Here’s how a prompt looks when wrapped around a model’s prompt template:
```python
print(gretel_dataset[0]["raw"]["prompt"])
```
````
<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer
### Instruction:

# Follow these instruction:
You will be given schemas of tables of a database. Your job is to write correct
error free SQL query based on the question asked. Please make sure:

1. Do not add ``` at start / end of the query. It should be a single line query in a single line (string format)
2. Make sure the column names are correct and exists in the table
3. For column names which has a space with it, make sure you have put `` in that column name
4. Think step by step and always check schema and question and the column names before writing the
query.

# Database and Table Schema:
CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');

# Here are some Examples on how to generate SQL statements and use column names:

Question: What is the total volume of timber sold by each salesperson, sorted by salesperson?
SQL: SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;

# Question: What is the total volume of timber sold by each salesperson, sorted by salesperson?

# SQL:
````
These prompts are generated dynamically as long as the raw dataset contains the necessary keys: db_id, question, and SQL (all case-sensitive), and the dataset follows PremSQL's standard structure. Let's look at that standardization and how you can create your own dataset.
If you have annotations for your own Text-to-SQL task on private data, you can create your own dataset. Tasks like Text-to-SQL carry database dependencies at every step: to build a dataset, you need a module that can connect to the databases, fetch their schemas, and insert them into the prompt. That's why PremSQL uses a standard approach to providing inputs across the whole library. If you use it with your private databases, you can extend and customize it for your data and use case. However, you strictly need to organize your files in the following structure:
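A representative layout, reconstructed from the loader arguments shown below; folder and file names other than databases, train.json, and validation.json are illustrative:

```
dataset_folder/
├── databases/
│   ├── california_schools/
│   │   └── california_schools.sqlite
│   └── movie_platform/
│       └── movie_platform.sqlite
├── train.json
└── validation.json
```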
Each sub-folder in databases should contain a .sqlite file matching the sub-folder name. Your train or validation JSON file should contain a list of dictionaries with these required keys:

- `db_id`: Corresponds to the folder and .sqlite file name.
- `question`: The user question.
- `SQL`: The ground-truth SQL query.
Example entry:
{"question_id":0,"db_id":"california_schools","question":"What is the highest eligible free rate for K-12 students in the schools in Alameda County?","evidence":"Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`","SQL":"SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1","difficulty":"simple"}
You can load your dataset using:
```python
from premsql.datasets import StandardDataset

validation_dataset = StandardDataset(
    split="validation",  # can be train / validation / test, matching the name of the JSON file
    dataset_path="/path/to/the/folder/containing/the/files/",
    database_folder_name="databases",  # the name of the databases folder
    json_file_name="validation.json",
)

train_dataset = StandardDataset(
    split="train",  # can be train / validation / test, matching the name of the JSON file
    dataset_path="/path/to/the/folder/containing/the/files/",
    database_folder_name="databases",  # the name of the databases folder
    json_file_name="train.json",
)
```
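Since StandardDataset is a premsql dataset like any other, setting it up should work the same way. A sketch, assuming it exposes the same setup_dataset interface shown earlier:

```python
# Assumed: StandardDataset supports setup_dataset like Text2SQLDataset
train_dataset = train_dataset.setup_dataset(
    num_fewshot=2,            # few-shot examples drawn from your own data
    model_name_or_path=None,  # or a model id to apply its prompt template and tokenize
)
```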
So far, all the examples shown are tailored to .sqlite databases. However, you might encounter scenarios where:

1. You are working with different databases, such as PostgreSQL or other cloud-based database instances.
2. You want to implement custom logic before generating prompts.
3. You need to extend the utility of premsql to fit specific requirements.
This section will guide you through achieving these customizations.
Note: For scenario 1, you could also migrate a subset of your dataset to SQLite format. This allows you to annotate the dataset in SQLite and follow the standard approach to create a Text2SQL-compatible dataset for fine-tuning and inference.
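For example, one way to do such a migration is with pandas and SQLAlchemy. A rough sketch; the connection string, table names, and output path are illustrative, not part of premsql:

```python
import sqlite3

import pandas as pd
from sqlalchemy import create_engine

# Illustrative source database and tables
pg_engine = create_engine("postgresql://user:password@localhost:5432/mydb")
tables_to_migrate = ["salesperson", "timber_sales"]

# Target .sqlite file, following the folder layout described above
sqlite_conn = sqlite3.connect("databases/mydb/mydb.sqlite")

for table in tables_to_migrate:
    # Read each table from PostgreSQL and write it into the SQLite file
    df = pd.read_sql_table(table, pg_engine)
    df.to_sql(table, sqlite_conn, index=False, if_exists="replace")

sqlite_conn.close()
```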
If you require full customization beyond this, you can achieve it in three steps. A detailed tutorial will be available in future releases, but here’s a concise overview of the process:
A dataset instance manages operations on individual data points. To create a custom dataset instance, extend the premsql.datasets.base.Text2SQLBaseInstance class. Below is a blueprint of how to define your custom instance:
```python
from premsql.datasets.base import Text2SQLBaseInstance


class CustomDataInstance(Text2SQLBaseInstance):
    def __init__(self, dataset: list[dict]) -> None:
        super().__init__(dataset=dataset)

    def schema_prompt(self, db_path: str) -> str:
        # Define your custom schema prompt logic here.
        # Fetch the schema from your database and format it as needed.
        # For SQLite databases, you could use:
        #   SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'
        # For other databases, adapt the query accordingly.
        return "Custom schema prompt logic here."
```
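As a concrete illustration, here is what a SQLite-backed schema_prompt might look like. This is a sketch, not premsql's built-in implementation; the class name SQLiteDataInstance is hypothetical, and it only elaborates on the sqlite_master query mentioned in the comments above:

```python
import sqlite3

from premsql.datasets.base import Text2SQLBaseInstance


class SQLiteDataInstance(Text2SQLBaseInstance):  # hypothetical example class
    def schema_prompt(self, db_path: str) -> str:
        # Collect every CREATE TABLE statement stored in sqlite_master
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND sql IS NOT NULL"
            )
            schemas = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        # Join the statements into a single schema section for the prompt
        return "\n\n".join(schemas)
```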
In addition to schema_prompt, this class also includes other methods like additional_prompt and apply_prompt. These methods have default implementations that are database-agnostic, but you can customize them as needed.
Once you have your custom instance defined, you can create a custom dataset class by extending the premsql.datasets.base.Text2SQLBaseDataset class. This class handles the overall dataset setup and management. Here’s an example:
```python
import logging
from pathlib import Path
from typing import Optional, Union

from premsql.datasets.base import Text2SQLBaseDataset

logger = logging.getLogger(__name__)


class CustomText2SQLDataset(Text2SQLBaseDataset):
    def __init__(
        self,
        split: str,
        dataset_folder: Optional[Union[str, Path]] = "./data",
        hf_token: Optional[str] = None,
        force_download: Optional[bool] = False,
    ):
        # Define your custom initialization logic here
        super().__init__(split=split, dataset_folder=dataset_folder)
        # Additional custom initialization if needed

    def setup_dataset(
        self,
        filter_by: Optional[tuple] = None,
        num_rows: Optional[int] = None,
        num_fewshot: Optional[int] = None,
        model_name_or_path: Optional[str] = None,
        prompt_template: Optional[str] = None,
    ):
        logger.info("Setting up the custom Text2SQL dataset.")
        # Custom setup logic can be added here if needed
        return super().setup_dataset(
            filter_by=filter_by,
            num_rows=num_rows,
            num_fewshot=num_fewshot,
            model_name_or_path=model_name_or_path,
            prompt_template=prompt_template,
        )
```
In the __init__ method, you can define any specific initialization logic needed for your dataset. Similarly, the setup_dataset method can be modified to include any custom setup steps or logic.
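Putting it together, usage might look like this. A sketch based on the blueprint above; the folder path and parameter values are illustrative:

```python
# Instantiate the custom dataset and process it like any other premsql dataset
dataset = CustomText2SQLDataset(split="train", dataset_folder="./my_data")
processed = dataset.setup_dataset(num_rows=10, num_fewshot=2)
```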
By following these steps, you can extend premsql to support different databases, implement custom preprocessing logic, or tailor the dataset setup to your specific needs. For a complete understanding, review the Text2SQLBaseDataset and Text2SQLBaseInstance classes in the premsql source code. We will be releasing a detailed tutorial soon on how to create datasets for different types of databases beyond SQLite.