Build error-handling prompts for error-free inference, and error-handling datasets for fine-tuning models to self-correct.
Error-handling prompts help refine model performance, especially on complex tasks such as Text-to-SQL generation. Each prompt gives the model extra context from a past mistake: the original prompt, the SQL the model generated, and the error message produced when that SQL was executed. By training on these prompts, the model learns to self-correct during inference, which improves the quality of its output.
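At inference time the same idea becomes a retry loop: generate SQL, execute it, and if execution fails, prompt again with the failed query and its error message. The sketch below assumes generic generate and execute callables and an abbreviated prompt template; self_correct, ERROR_TEMPLATE, fake_generate, and fake_execute are hypothetical names for illustration, not part of premsql.

```python
from typing import Callable, Optional

# Hypothetical, abbreviated error-handling template (not premsql's exact text).
ERROR_TEMPLATE = (
    "{existing_prompt}\n\n"
    "# Generated SQL: {sql}\n\n"
    "## Error Message\n\n{error_msg}\n\n"
    "Rewrite the SQL query so that it fixes the error above.\n\n"
    "# SQL: "
)


def self_correct(
    prompt: str,
    generate: Callable[[str], str],
    execute: Callable[[str], Optional[str]],
    max_retries: int = 2,
) -> str:
    """Re-prompt the model with execution errors until the SQL runs or retries run out."""
    sql = generate(prompt)
    for _ in range(max_retries):
        error = execute(sql)  # None means the query executed without error
        if error is None:
            break
        retry_prompt = ERROR_TEMPLATE.format(
            existing_prompt=prompt, sql=sql, error_msg=error
        )
        sql = generate(retry_prompt)
    return sql  # the last attempt (not re-executed if the final retry was used)


# Usage with stubs: the first answer uses a non-existent column, the retry fixes it.
def fake_generate(p: str) -> str:
    return "SELECT movie_title FROM movies" if "## Error Message" in p else "SELECT title FROM movies"


def fake_execute(sql: str) -> Optional[str]:
    return None if "movie_title" in sql else "no such column: title"


final_sql = self_correct("# Question: list all movie titles", fake_generate, fake_execute)
print(final_sql)  # SELECT movie_title FROM movies
```

Capping the number of retries keeps latency bounded even when the model keeps producing failing SQL.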
To create an error-handling dataset or prompt, each data entity must include:
db_id: Identifier for the database.
db_path: Path to the database.
prompt: The original prompt used to generate the initial results.
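Before generating an error dataset from your own records, a small field check can catch entities that lack these keys. This is a hypothetical helper (missing_fields is not a premsql function), and the example record's path and prompt are placeholders.

```python
# Hypothetical pre-flight check; not part of premsql.
REQUIRED_FIELDS = ("db_id", "db_path", "prompt")


def missing_fields(entity: dict) -> list:
    """Return the required keys that are absent or empty in one data entity."""
    return [key for key in REQUIRED_FIELDS if not entity.get(key)]


record = {
    "db_id": "movie_platform",
    "db_path": "/path/to/train_databases/movie_platform/movie_platform.sqlite",  # placeholder
    "prompt": "# Question: Name movie titles released in year 1945.",  # placeholder prompt
    "question": "Name movie titles released in year 1945.",  # extra keys are ignored by the check
}

print(missing_fields(record))  # []
print(missing_fields({"db_id": "movie_platform"}))  # ['db_path', 'prompt']
```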
In this guide we use the BirdBench dataset as an example, but any Text-to-SQL-compatible dataset works. First, define the two components this pipeline relies on: the generator (the model that writes SQL) and the executor (the evaluation component that runs the generated SQL against the database).
```python
from premsql.datasets.error_dataset import ErrorDatasetGenerator
from premsql.generators.huggingface import Text2SQLGeneratorHF
from premsql.executors.from_langchain import ExecutorUsingLangChain

generator = Text2SQLGeneratorHF(
    model_or_name_or_path="premai-io/prem-1B-SQL",
    experiment_name="testing_error_gen",
    type="train",  # use "train", not "test": the generated data will be used for training
    device="cuda:0",
)
executor = ExecutorUsingLangChain()
```
We define our existing training dataset using BirdBench. For demonstration, we’re using num_rows=10, but it’s recommended to use the full dataset when training. Typically, the error dataset will be smaller than the training set if you’re using a well-trained model.
```python
from premsql.datasets import BirdDataset

bird_train = BirdDataset(
    split="train",
    dataset_folder="/path/to/dataset",
).setup_dataset(num_rows=10)
```
Next, run error-dataset generation and save the results. They are saved to ./experiments/train/<generator-experiment-name>/error_dataset.json; pass force=True to regenerate the dataset even if a saved copy already exists. A single entry of the generated error dataset looks like this:
```python
{'db_id':'movie_platform','question':'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.','SQL':'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1','prompt':'\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write correct\nerror free SQL query based on the question asked. Please make sure:\n\n1. Do not add ``` at start / end of the query. It should be a single line query in a single line (string format)\n2. Make sure the column names are correct and exists in the table\n3. For column names which has a space with it, make sure you have put `` in that column name\n4. Think step by step and always check schema and question and the column names before writing the\nquery. \n\n# Database and Table Schema:\nCREATE TABLE "lists"\n(\n user_id INTEGER\n references lists_users (user_id),\n list_id INTEGER not null\n primary key,\n list_title TEXT,\n list_movie_number INTEGER,\n list_update_timestamp_utc TEXT,\n list_creation_timestamp_utc TEXT,\n list_followers INTEGER,\n list_url TEXT,\n list_comments INTEGER,\n list_description TEXT,\n list_cover_image_url TEXT,\n list_first_image_url TEXT,\n list_second_image_url TEXT,\n list_third_image_url TEXT\n)\nCREATE TABLE "movies"\n(\n movie_id INTEGER not null\n primary key,\n movie_title TEXT,\n movie_release_year INTEGER,\n movie_url TEXT,\n movie_title_language TEXT,\n movie_popularity INTEGER,\n movie_image_url TEXT,\n director_id TEXT,\n director_name TEXT,\n director_url TEXT\n)\nCREATE TABLE "ratings_users"\n(\n user_id INTEGER\n references lists_users (user_id),\n rating_date_utc TEXT,\n user_trialist INTEGER,\n user_subscriber INTEGER,\n user_avatar_image_url TEXT,\n user_cover_image_url TEXT,\n user_eligible_for_trial INTEGER,\n user_has_payment_method INTEGER\n)\nCREATE TABLE lists_users\n(\n user_id INTEGER not null ,\n list_id INTEGER not null ,\n list_update_date_utc TEXT,\n list_creation_date_utc TEXT,\n user_trialist INTEGER,\n user_subscriber INTEGER,\n user_avatar_image_url TEXT,\n user_cover_image_url TEXT,\n user_eligible_for_trial TEXT,\n user_has_payment_method TEXT,\n primary key (user_id, list_id),\n foreign key (list_id) references lists(list_id),\n foreign key (user_id) references lists(user_id)\n)\nCREATE TABLE ratings\n(\n movie_id INTEGER,\n rating_id INTEGER,\n rating_url TEXT,\n rating_score INTEGER,\n rating_timestamp_utc TEXT,\n critic TEXT,\n critic_likes INTEGER,\n critic_comments INTEGER,\n user_id INTEGER,\n user_trialist INTEGER,\n user_subscriber INTEGER,\n user_eligible_for_trial INTEGER,\n user_has_payment_method INTEGER,\n foreign key (movie_id) references movies(movie_id),\n foreign key (user_id) references lists_users(user_id),\n foreign key (rating_id) references ratings(rating_id),\n foreign key (user_id) references ratings_users(user_id)\n)\n\n\n\n# Here are some Examples on how to generate SQL statements and use column names:\n\n\n# Question: Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.\n\n# Generated SQL: SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1;\n\n## Error Message\n\nNone\n\nCarefully review the original question and error message, then rewrite the SQL query to address the identified issues. \nEnsure your corrected query uses correct column names, \nfollows proper SQL syntax, and accurately answers the original question \nwithout introducing new errors.\n\n# SQL: \n','db_path':'/root/anindya/text2sql/data/bird/train/train_databases/movie_platform/movie_platform.sqlite'}
```
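To inspect the saved file outside premsql, you can read it with the standard json module. The sketch below rebuilds the path layout described above and assumes, without confirmation from this guide, that error_dataset.json holds a JSON list of records like the one shown; error_dataset_path and load_error_records are hypothetical helpers.

```python
import json
import tempfile
from pathlib import Path


def error_dataset_path(experiment_name: str, base_dir: str = "./experiments") -> Path:
    """Build <base_dir>/train/<experiment-name>/error_dataset.json."""
    return Path(base_dir) / "train" / experiment_name / "error_dataset.json"


def load_error_records(experiment_name: str, base_dir: str = "./experiments") -> list:
    """Load the saved error dataset, assuming it is a JSON list of dicts."""
    with open(error_dataset_path(experiment_name, base_dir), encoding="utf-8") as f:
        return json.load(f)


# Usage against a temporary directory that mimics the saved layout.
demo_base = tempfile.mkdtemp()
demo_file = error_dataset_path("testing_error_gen", demo_base)
demo_file.parent.mkdir(parents=True)
demo_record = {"db_id": "movie_platform", "db_path": "placeholder.sqlite", "prompt": "placeholder"}
demo_file.write_text(json.dumps([demo_record]), encoding="utf-8")

records = load_error_records("testing_error_gen", demo_base)
print(records[0]["db_id"])  # movie_platform
```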
Each data point in the error dataset will include an error-handling prompt, which looks like this:
```python
ERROR_HANDLING_PROMPT = """{existing_prompt}

# Generated SQL: {sql}

## Error Message

{error_msg}

Carefully review the original question and error message, then rewrite the SQL query to address the identified issues.
Ensure your corrected query uses correct column names,
follows proper SQL syntax, and accurately answers the original question
without introducing new errors.

# SQL: 
"""
```
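Filling the template is plain string formatting. A minimal sketch, using an abbreviated copy of the template (same section markers, shortened instructions) and a hypothetical build_error_prompt helper. A Python None passed as the error renders as the text None, which matches the sample entry shown earlier.

```python
# Abbreviated copy of ERROR_HANDLING_PROMPT: same section markers, shorter instructions.
ERROR_HANDLING_PROMPT = """{existing_prompt}

# Generated SQL: {sql}

## Error Message

{error_msg}

Carefully review the original question and error message, then rewrite the SQL query.

# SQL: """


def build_error_prompt(existing_prompt: str, sql: str, error_msg) -> str:
    """Insert the original prompt, the generated SQL, and its error into the template."""
    return ERROR_HANDLING_PROMPT.format(
        existing_prompt=existing_prompt, sql=sql, error_msg=error_msg
    )


prompt = build_error_prompt(
    "# Question: Name movie titles released in year 1945.",
    "SELECT title FROM movies WHERE movie_release_year = 1945",
    "no such column: title",
)
print(prompt)
```

str.format does not re-interpret braces inside the substituted values, so schema or SQL text containing { or } is safe to pass in; only the template itself must avoid stray braces.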
When the error dataset is tokenized for training, each entry holds input_ids and labels tensors, with the original record kept under raw. The prompt in raw is wrapped in the model's instruction template (DeepSeek Coder style in this example):

```python
{'input_ids': tensor([32013,32013,2042,...,207,16,32021]),'labels': tensor([-100,-100,-100,...,207,16,32021]),'raw':{'db_id':'movie_platform','question':'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.','SQL':'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1','prompt':'<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\n\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write correct\nerror free SQL query based on the question asked. Please make sure:\n\n1. Do not add ``` at start / end of the query. It should be a single line query in a single line (string format)\n2. Make sure the column names are correct and exists in the table\n3. For column names which has a space with it, make sure you have put `` in that column name\n4. Think step by step and always check schema and question and the column names before writing the\nquery. \n\n# Database and Table Schema:\nCREATE TABLE "lists"\n(\n user_id INTEGER\n references lists_users (user_id),\n list_id INTEGER not null\n primary key,\n list_title TEXT,\n list_movie_number INTEGER,\n list_update_timestamp_utc TEXT,\n list_creation_timestamp_utc TEXT,\n list_followers INTEGER,\n list_url TEXT,\n list_comments INTEGER,\n list_description TEXT,\n list_cover_image_url TEXT,\n list_first_image_url TEXT,\n list_second_image_url TEXT,\n list_third_image_url TEXT\n)\nCREATE TABLE "movies"\n(\n movie_id INTEGER not null\n primary key,\n movie_title TEXT,\n movie_release_year INTEGER,\n movie_url TEXT,\n movie_title_language TEXT,\n movie_popularity INTEGER,\n movie_image_url TEXT,\n director_id TEXT,\n director_name TEXT,\n director_url TEXT\n)\nCREATE TABLE "ratings_users"\n(\n user_id INTEGER\n references lists_users (user_id),\n rating_date_utc TEXT,\n user_trialist INTEGER,\n user_subscriber INTEGER,\n user_avatar_image_url TEXT,\n user_cover_image_url TEXT,\n user_eligible_for_trial INTEGER,\n user_has_payment_method INTEGER\n)\nCREATE TABLE lists_users\n(\n user_id INTEGER not null ,\n list_id INTEGER not null ,\n list_update_date_utc TEXT,\n list_creation_date_utc TEXT,\n user_trialist INTEGER,\n user_subscriber INTEGER,\n user_avatar_image_url TEXT,\n user_cover_image_url TEXT,\n user_eligible_for_trial TEXT,\n user_has_payment_method TEXT,\n primary key (user_id, list_id),\n foreign key (list_id) references lists(list_id),\n foreign key (user_id) references lists(user_id)\n)\nCREATE TABLE ratings\n(\n movie_id INTEGER,\n rating_id INTEGER,\n rating_url TEXT,\n rating_score INTEGER,\n rating_timestamp_utc TEXT,\n critic TEXT,\n critic_likes INTEGER,\n critic_comments INTEGER,\n user_id INTEGER,\n user_trialist INTEGER,\n user_subscriber INTEGER,\n user_eligible_for_trial INTEGER,\n user_has_payment_method INTEGER,\n foreign key (movie_id) references movies(movie_id),\n foreign key (user_id) references lists_users(user_id),\n foreign key (rating_id) references ratings(rating_id),\n foreign key (user_id) references ratings_users(user_id)\n)\n\n\n\n# Here are some Examples on how to generate SQL statements and use column names:\n\n\n# Question: Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.\n\n# Generated SQL: SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1;\n\n## Error Message\n\nNone\n\nCarefully review the original question and error message, then rewrite the SQL query to address the identified issues. \nEnsure your corrected query uses correct column names, \nfollows proper SQL syntax, and accurately answers the original question \nwithout introducing new errors.\n\n# SQL: \n\n','db_path':'/root/anindya/text2sql/data/bird/train/train_databases/movie_platform/movie_platform.sqlite'}}
```
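In the tokenized entry, labels begin with -100 while their tail matches input_ids. -100 is the default ignore_index of PyTorch's CrossEntropyLoss, so this pattern is consistent with masking the prompt so that the loss covers only the target SQL. A minimal sketch of that masking on plain token-id lists, assuming this is the intent; build_training_example is hypothetical and the token ids are made up.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_training_example(prompt_ids: list, target_ids: list) -> dict:
    """Concatenate prompt and target ids; mask prompt positions out of the loss."""
    return {
        "input_ids": prompt_ids + target_ids,
        "labels": [IGNORE_INDEX] * len(prompt_ids) + target_ids,
    }


example = build_training_example([101, 102, 103], [7, 8, 9])
print(example["labels"])  # [-100, -100, -100, 7, 8, 9]
```

Masking the prompt keeps the model from being trained to reproduce the long schema text, so the gradient signal comes only from the corrected SQL.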
That's how you generate an error-handling dataset. The dataset is compatible with other premsql datasets, so you can mix them and use them together for fine-tuning. Training on error-handling prompts teaches the model to repair its own SQL using execution feedback, which can make it more robust and reliable in production.
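Mechanically, mixing datasets for fine-tuning amounts to concatenating their records and shuffling them. A plain-Python sketch with hypothetical record lists and a hypothetical mix_datasets helper; premsql's own way of combining dataset objects is not shown in this guide.

```python
import random


def mix_datasets(*datasets, seed: int = 42) -> list:
    """Concatenate several lists of records and shuffle them reproducibly."""
    mixed = [record for dataset in datasets for record in dataset]
    random.Random(seed).shuffle(mixed)  # a dedicated RNG leaves global random state untouched
    return mixed


train_records = [{"db_id": "movie_platform", "kind": "train", "idx": i} for i in range(3)]
error_records = [{"db_id": "movie_platform", "kind": "error", "idx": i} for i in range(2)]
mixed = mix_datasets(train_records, error_records)
print(len(mixed))  # 5
```

Shuffling spreads the error-handling examples across batches instead of leaving them in one block at the end of training.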