premsql tuner is a module designed to fine-tune models specifically for text-to-SQL tasks. The module offers multiple ways of fine-tuning, providing flexibility based on the needs of your project.

Supported Fine-Tuning Methods

  1. Full Fine-Tuning: Standard fine-tuning that updates all of the model's parameters.
  2. PEFT using LoRA: Parameter-Efficient Fine-Tuning with LoRA (Low-Rank Adaptation), which trains a small set of added low-rank weights for faster, more memory-efficient training.
  3. PEFT using QLoRA: A PEFT approach that applies LoRA on top of a quantized base model, further reducing memory use during training.

In addition to these methods, you can create custom fine-tuning pipelines using the components and tools provided by premsql.
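To give a feel for why the PEFT methods are cheaper, here is a minimal back-of-the-envelope sketch. It only counts parameters and bytes; it does not use premsql. The matrix sizes, the rank, and the 1B parameter count are illustrative assumptions, not the real dimensions of any particular model.

```python
def lora_trainable_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds for one d_out x d_in weight matrix.

    LoRA freezes the original weight W and learns an update B @ A,
    with B of shape (d_out, rank) and A of shape (rank, d_in).
    """
    return rank * (d_out + d_in)


def weight_memory_bytes(num_params: int, bits_per_param: int) -> int:
    """Memory needed just to store the weights at a given precision."""
    return num_params * bits_per_param // 8


# Illustrative sizes only (assumptions, not taken from a real model config).
d_out, d_in, rank = 2048, 2048, 16

full = d_out * d_in
lora = lora_trainable_params(d_out, d_in, rank)
print(f"full fine-tuning: {full:,} trainable params per matrix")
print(f"LoRA (r={rank}): {lora:,} trainable params per matrix")
print(f"ratio: {lora / full:.2%}")

# QLoRA keeps the frozen base weights quantized (for example 4-bit instead of 16-bit).
base_params = 1_000_000_000  # hypothetical 1B-parameter model
print(weight_memory_bytes(base_params, 16) / 1e9, "GB of weights at 16-bit")
print(weight_memory_bytes(base_params, 4) / 1e9, "GB of weights at 4-bit")
```

These numbers cover weights only; optimizer state and activations add to the real memory footprint.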

Prerequisites

This tutorial assumes that you are familiar with the following concepts and resources:

Additionally, knowledge of Hugging Face’s TRL library will help you better understand the internal workings of PEFT using LoRA.

Let’s start by importing the required packages:

from premsql.datasets import (
    BirdDataset,
    SpiderUnifiedDataset,
    DomainsDataset,
    GretelAIDataset,
)

from premsql.executors.from_sqlite import SQLiteExecutor
from premsql.datasets import Text2SQLDataset
from premsql.tuner.peft import Text2SQLPeftTuner
from premsql.datasets.error_dataset import ErrorDatasetGenerator

Defining Paths and Model

Define the path to the dataset and specify the model that will be fine-tuned:

path = "/root/anindya/text2sql/data"
model_name_or_path = "premai-io/prem-1B-SQL"

Setting Up Datasets

BirdBench Training Dataset

First, let’s set up the BirdBench training dataset. Here, we use a subset for demonstration, but you should use the full dataset during actual fine-tuning.

bird_train = BirdDataset(split="train", dataset_folder=path).setup_dataset(
    num_rows=100,
)

Spider Dataset

Next, we load the Spider dataset:

spider_train = SpiderUnifiedDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100
)

Domains Dataset

Load the domains dataset:

domains_dataset = DomainsDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100,
)

Gretel AI Synthetic Dataset

We also use a synthetic Text-to-SQL dataset from Gretel AI:

gertelai_dataset = GretelAIDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100,
)

Error Dataset

Finally, we load an error dataset, which is essential for handling edge cases and improving model robustness. More information on error datasets is available in the premsql documentation.

existing_error_dataset = ErrorDatasetGenerator.from_existing(
    experiment_name="testing_error_gen",
)

Merging Datasets

Now, we merge all the loaded datasets into a single entity. This consolidated dataset will be used for fine-tuning.

merged_dataset = [
    *spider_train,
    *bird_train,
    *domains_dataset,
    *gertelai_dataset,
    *existing_error_dataset
]
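The merge above relies only on Python's star-unpacking, which works on any iterable. As a small sketch of what happens, the stand-in class below (hypothetical, not a premsql class) implements just `__len__` and `__getitem__`, which is enough for `*` to walk through it. The row fields are made up for illustration.

```python
class TinyDataset:
    """Minimal sequence-style dataset: __len__ plus __getitem__."""

    def __init__(self, rows):
        self._rows = list(rows)

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, idx):
        # Raises IndexError past the end, which is what ends iteration.
        return self._rows[idx]


spider_like = TinyDataset([
    {"source": "spider", "question": "How many singers are there?"},
])
bird_like = TinyDataset([
    {"source": "bird", "question": "List all schools in Alameda."},
    {"source": "bird", "question": "Which school has the most students?"},
])

# Each * unpacks one dataset, item by item, into the new list.
merged = [*spider_like, *bird_like]

print(len(merged))
print([row["source"] for row in merged])
```

The merged list keeps the datasets in the order they were listed, and its length is the sum of the individual lengths.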

Validation Dataset

To validate the fine-tuning process, we initialize the BirdBench validation dataset. Validation for Text-to-SQL differs from typical LLM validation, as it involves executing generated SQL queries against a database and comparing results with ground truth tables.

The premsql module provides a custom callback to facilitate this evaluation process. Simply define your validation datasets, and the callback handles the rest.

bird_dev = Text2SQLDataset(dataset_name="bird", split="validation", dataset_folder=path).setup_dataset(
    num_rows=10,
    filter_by=("difficulty", "challenging")
)
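To make the idea of execution-based validation concrete, here is a minimal sketch using only Python's built-in sqlite3 module. It is not premsql's callback or executor; the function name `execution_match`, the toy table, and the choice to ignore row order are all assumptions made for illustration.

```python
import sqlite3


def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    """Return True when both queries run and yield the same rows.

    Rows are compared after sorting, so row order is ignored
    (an assumption; order can matter for queries with ORDER BY).
    """
    try:
        predicted = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return False  # SQL that fails to execute counts as a miss
    gold = conn.execute(gold_sql).fetchall()
    return sorted(predicted, key=repr) == sorted(gold, key=repr)


# Toy in-memory database standing in for a benchmark database.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE singer (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);
    INSERT INTO singer (name, age) VALUES ('Ana', 31), ('Bo', 24), ('Cy', 45);
    """
)

gold = "SELECT name FROM singer WHERE age > 30"

# Different SQL text, same result rows: counted as correct.
same_rows = execution_match(conn, "SELECT name FROM singer WHERE age >= 31", gold)
# Valid SQL, different result rows: counted as incorrect.
different_rows = execution_match(conn, "SELECT name FROM singer WHERE age > 40", gold)
# Invalid SQL (misspelled column): counted as incorrect.
invalid_sql = execution_match(conn, "SELECT nme FROM singer", gold)

print(same_rows, different_rows, invalid_sql)
```

This is why a text-level comparison of the generated query against the reference query is not enough: two differently written queries can be equally correct, and a near-identical query can fail to run at all.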

Initializing the Tuner

The Text2SQLPeftTuner class manages the fine-tuning process. Initialize it with the model path and an experiment name, under which the logs are stored.

tuner = Text2SQLPeftTuner(
    model_name_or_path=model_name_or_path,
    experiment_name="lora_tuning"
)

Training the Model

Use the train method of the tuner to start the fine-tuning process. You need to provide the following arguments:

  • train_datasets: The merged datasets used for training.
  • output_dir: Directory to save the fine-tuned model weights.
  • num_train_epochs: Number of training epochs.
  • per_device_train_batch_size: Training batch size per device.
  • gradient_accumulation_steps: Number of steps for gradient accumulation.
  • evaluation_dataset: The dataset used for validation; set to None if no evaluation is required.
  • eval_steps: Number of steps between evaluations.
  • max_seq_length: Maximum sequence length permissible by the model.
  • executor: The executor that runs the generated SQL during evaluation; provide one if you have defined an evaluation dataset.
  • filter_eval_results_by: Filters evaluation results based on specified criteria (e.g., difficulty).

You can also pass additional parameters compatible with transformers TrainingArguments using **kwargs, which will override any default settings.
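Among the arguments listed above, per_device_train_batch_size, gradient_accumulation_steps and eval_steps interact. The sketch below shows the usual arithmetic; the helper names and the dataset size of 400 are hypothetical, and the exact rounding of steps per epoch may differ between trainer versions, so treat the results as approximate.

```python
import math


def effective_batch_size(per_device: int, accumulation_steps: int, num_devices: int = 1) -> int:
    """Examples that contribute to a single optimizer update."""
    return per_device * accumulation_steps * num_devices


def optimizer_steps_per_epoch(num_examples: int, batch: int) -> int:
    """Approximate number of optimizer updates in one pass over the data."""
    return math.ceil(num_examples / batch)


num_examples = 400  # hypothetical size of the merged training set
eval_steps = 100

for accumulation in (1, 8):
    batch = effective_batch_size(per_device=1, accumulation_steps=accumulation)
    steps = optimizer_steps_per_epoch(num_examples, batch)
    evaluations = steps // eval_steps
    print(
        f"accumulation={accumulation}: effective batch {batch}, "
        f"about {steps} steps per epoch, {evaluations} step-based evaluations"
    )
```

Under these assumptions, raising gradient_accumulation_steps shrinks the number of optimizer steps, so an eval_steps value larger than the total step count would mean the step-based evaluation never fires during a short run.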

Here’s how to initiate the training:

tuner.train(
    train_datasets=merged_dataset,
    output_dir="./output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    evaluation_dataset=bird_dev,
    eval_steps=100,
    max_seq_length=1024,
    executor=SQLiteExecutor(),
    filter_eval_results_by=("difficulty", "challenging")
)

Conclusion

This command will start the training process. All model weights will be stored in the ./output directory, and logs will be saved in ./experiments/train/<experiment-name>. For an end-to-end example, you can check out our fine-tuning using LoRA script.

With this guide, you should now be able to fine-tune models using premsql efficiently! Also, nothing is tightly coupled to our datasets: in the end, a dataset is just the standard input and output token tensors, so datasets you have built yourself should work as well.
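As a rough picture of what "standard input and output tokens" can look like, here is a minimal sketch of one training example in the common causal-language-model convention: the prompt and the target SQL are concatenated into input_ids, and the prompt positions in labels are set to -100 so the loss ignores them. The token IDs are made up, and the key names and the prompt masking are assumptions based on that convention, not a statement of what premsql's datasets emit.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def build_example(prompt_ids: list[int], completion_ids: list[int]) -> dict[str, list[int]]:
    """Concatenate prompt and completion; mask the prompt in the labels."""
    input_ids = prompt_ids + completion_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + completion_ids
    return {"input_ids": input_ids, "labels": labels}


# Made-up token IDs standing in for a tokenized question and its SQL answer.
prompt_ids = [101, 2054, 2003, 1996]
completion_ids = [7276, 4175, 102]

example = build_example(prompt_ids, completion_ids)
print(example["input_ids"])
print(example["labels"])
```

In practice the lists would be converted to tensors and produced by a real tokenizer; the point is only that each example is a pair of aligned token sequences.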