Tuner
Fine-tune open-source models for Text-to-SQL tasks using readily available datasets or your own custom datasets.
The premsql tuner module is designed to fine-tune models specifically for text-to-SQL tasks. It offers multiple ways of fine-tuning, providing flexibility based on the needs of your project.
Supported Fine-Tuning Methods
- Full Fine-Tuning: Standard fine-tuning of the model with all its parameters.
- PEFT using LoRA: Parameter-Efficient Fine-Tuning with LoRA (Low-Rank Adaptation) for faster and more efficient training.
- PEFT using QLoRA: Another PEFT approach using Quantized LoRA, optimizing resource use during training.
In addition to these methods, you can create custom fine-tuning pipelines using the components and tools provided by premsql.
Prerequisites
This tutorial assumes that you are familiar with the following concepts and resources:
Datasets Tutorial
Learn how to utilize pre-processed datasets for Text-to-SQL tasks. This tutorial covers dataset evaluation, fine-tuning, and creation of custom datasets.
Generators Tutorial
A step-by-step guide on how to use Text-to-SQL generators to create SQL queries from user input and specified database sources.
Executors Tutorial
Learn how to connect to databases and execute SQL queries generated by models. This tutorial covers execution, troubleshooting, and best practices.
Evaluators Tutorial
Understand the evaluation of Text-to-SQL models with metrics like execution accuracy and Valid Efficiency Score (VES).
PremSQL Error Handling
Learn how to create error-handling prompts and datasets for error-free inference, and how to build fine-tuning datasets that enforce the self-correction property.
Additionally, knowledge of Hugging Face’s TRL library will help you better understand the internal workings of PEFT using LoRA.
Let’s start by importing the required packages:
from premsql.datasets import (
    BirdDataset,
    SpiderUnifiedDataset,
    DomainsDataset,
    GretelAIDataset,
)
from premsql.executors.from_sqlite import SQLiteExecutor
from premsql.datasets import Text2SQLDataset
from premsql.tuner.peft import Text2SQLPeftTuner
from premsql.datasets.error_dataset import ErrorDatasetGenerator
Defining Paths and Model
Define the path to the dataset and specify the model that will be fine-tuned:
path = "/root/anindya/text2sql/data"
model_name_or_path = "premai-io/prem-1B-SQL"
Setting Up Datasets
BirdBench Training Dataset
First, let’s set up the BirdBench training dataset. Here, we use a subset for demonstration, but you should use the full dataset during actual fine-tuning.
bird_train = BirdDataset(split="train", dataset_folder=path).setup_dataset(
    num_rows=100,
)
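Before moving on, it can help to confirm what setup_dataset returned. The check below is only a sketch: it assumes the processed split supports plain Python iteration, which the merge step later in this tutorial also relies on when unpacking it with *.
# Optional sanity check: count the processed examples and peek at one.
# The exact fields inside each example may vary by dataset.
bird_examples = list(bird_train)
print(f"Loaded {len(bird_examples)} BirdBench training examples")
print(bird_examples[0])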
Spider Dataset
Next, we load the Spider dataset:
spider_train = SpiderUnifiedDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100
)
Domains Dataset
Load the domains dataset:
domains_dataset = DomainsDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100,
)
Gretel AI Synthetic Dataset
We also use a synthetic Text-to-SQL dataset from Gretel AI:
gretelai_dataset = GretelAIDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100,
)
Error Dataset
Finally, we load an error dataset, which is essential for handling edge cases and improving model robustness. More information on error datasets can be found here.
existing_error_dataset = ErrorDatasetGenerator.from_existing(
    experiment_name="testing_error_gen",
)
Merging Datasets
Now, we merge all the loaded datasets into a single entity. This consolidated dataset will be used for fine-tuning.
merged_dataset = [
    *spider_train,
    *bird_train,
    *domains_dataset,
    *gretelai_dataset,
    *existing_error_dataset,
]
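Since the merge is just a plain Python list, you can inspect, shuffle, or subsample it before training. The snippet below is an optional sketch; shuffling simply interleaves the different sources so a single epoch does not consume them in contiguous blocks.
import random

random.seed(42)                 # keep the shuffle reproducible
random.shuffle(merged_dataset)  # interleave Bird, Spider, Domains, Gretel AI, and error examples
print(f"Total training examples: {len(merged_dataset)}")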
Validation Dataset
To validate the fine-tuning process, we initialize the BirdBench validation dataset. Validation for Text-to-SQL differs from typical LLM validation, as it involves executing generated SQL queries against a database and comparing results with ground truth tables.
The premsql module provides a custom callback to facilitate this evaluation process. Simply define your validation datasets, and the callback handles the rest.
bird_dev = Text2SQLDataset(dataset_name="bird", split="validation", dataset_folder=path).setup_dataset(
    num_rows=10,
    filter_by=("difficulty", "challenging")
)
Initializing the Tuner
The Text2SQLPeftTuner class manages the fine-tuning process. Initialize it with the model path and an experiment name under which the logs will be stored.
tuner = Text2SQLPeftTuner(
    model_name_or_path=model_name_or_path,
    experiment_name="lora_tuning"
)
Training the Model
Use the train function of the tuner class to start the fine-tuning process. You need to provide the following arguments:
- train_datasets: The merged datasets used for training.
- output_dir: Directory to save the fine-tuned model weights.
- num_train_epochs: Number of training epochs.
- per_device_train_batch_size: Training batch size per device.
- gradient_accumulation_steps: Number of steps for gradient accumulation.
- evaluation_dataset: The dataset used for validation; set to None if no evaluation is required.
- eval_steps: Number of steps between evaluations.
- max_seq_length: Maximum sequence length permissible by the model.
- executor: Use an executor if you have defined an evaluation dataset.
- filter_eval_results_by: Filters evaluation results based on specified criteria (e.g., difficulty).
You can also pass additional parameters compatible with transformers TrainingArguments using **kwargs, which will override any default settings.
Here’s how to initiate the training:
tuner.train(
    train_datasets=merged_dataset,
    output_dir="./output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    evaluation_dataset=bird_dev,
    eval_steps=100,
    max_seq_length=1024,
    executor=SQLiteExecutor(),
    filter_eval_results_by=("difficulty", "challenging")
)
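Because extra keyword arguments are forwarded to transformers TrainingArguments, the same call can also override defaults such as the learning rate or logging cadence. The variant below is only a sketch: learning_rate, lr_scheduler_type, and logging_steps are standard TrainingArguments fields, not premsql-specific parameters, and any value you pass overrides the tuner's defaults.
tuner.train(
    train_datasets=merged_dataset,
    output_dir="./output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    evaluation_dataset=bird_dev,
    eval_steps=100,
    max_seq_length=1024,
    executor=SQLiteExecutor(),
    filter_eval_results_by=("difficulty", "challenging"),
    # Forwarded to transformers TrainingArguments via **kwargs:
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    logging_steps=10,
)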
Conclusion
Calling tuner.train as shown above starts the training process. All model weights will be stored in the ./output directory, and logs will be saved in ./experiments/train/<experiment-name>.
For an end-to-end example, you can check out our fine-tuning using LoRA script.
With this guide, you should now be able to fine-tune models using premsql efficiently. Note that nothing here is tightly coupled to our datasets: as long as your data is ultimately converted into the standard input and output token tensors, you can plug your own datasets into the same fine-tuning pipeline.
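As a rough illustration, here is a minimal sketch of a single hand-rolled example. The field names (input_ids, labels) and the tokenization details are assumptions for illustration, not the exact schema premsql expects, so print an example from the datasets above and match its keys before wiring in your own data.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("premai-io/prem-1B-SQL")

prompt = (
    "-- Schema: CREATE TABLE users (id INTEGER, name TEXT)\n"
    "-- Question: How many users are there?\n"
)
sql = "SELECT COUNT(*) FROM users;"

# One hand-rolled training example: tokenized prompt plus target SQL.
encoded = tokenizer(prompt + sql, truncation=True, max_length=1024)
custom_example = {
    "input_ids": encoded["input_ids"],
    "labels": list(encoded["input_ids"]),  # causal-LM style: predict the full sequence
}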