Tutorial: Classification Fine-tuning
Let's walk through a quick example of fine-tuning the LM weights within a DSPy program. We'll apply it to a simple 77-way classification task.
Our finetuned program will use a tiny Llama-3.2-1B language model, hosted locally on your GPU. To make this more interesting, we'll assume that (i) we don't have any training labels but (ii) we have 500 unlabeled training examples.
Install dependencies and download data
Install the latest DSPy via pip install -U --pre dspy and follow along. This tutorial depends on DSPy 2.6.0 (pre-release).
At the moment, this tutorial requires a local GPU for inference, though we plan to support Ollama serving for finetuned models as well.
You will also need the following dependencies:
> pip install "sglang[all]"; pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
> pip install -U torch transformers accelerate trl peft
Dataset
For this tutorial, we will use the Banking77 dataset.
import dspy
import random
from dspy.datasets import DataLoader
from datasets import load_dataset
# Load the Banking77 dataset.
CLASSES = load_dataset("PolyAI/banking77", split="train", trust_remote_code=True).features['label'].names
kwargs = dict(fields=("text", "label"), input_keys=("text",), split="train", trust_remote_code=True)
# Load the first 1000 examples from the dataset, replacing each integer label with its class name.
raw_data = [
    dspy.Example(x, label=CLASSES[x.label]).with_inputs("text")
    for x in DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **kwargs)[:1000]
]
random.Random(0).shuffle(raw_data)
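The cell above converts integer labels to class names, shuffles with a fixed seed, and later slices the shuffled list into a training pool and a dev set. The following is a minimal stdlib-only sketch of that same pipeline on toy data; the three class names and the ten fake rows are assumptions standing in for the real Banking77 data, and the 6/2 split is a scaled-down stand-in for the tutorial's 500/100 split.

```python
import random

# Toy stand-ins (assumptions): a few Banking77 class names and integer-labelled rows,
# mimicking the dataset's layout where "label" is an index into CLASSES.
CLASSES = ["activate_my_card", "age_limit", "card_arrival"]
rows = [{"text": f"query {i}", "label": i % len(CLASSES)} for i in range(10)]

# Replace each integer label with its class name, as CLASSES[x.label] does above.
named = [{"text": r["text"], "label": CLASSES[r["label"]]} for r in rows]

# A seeded Random instance shuffles in place reproducibly, without touching the global RNG.
random.Random(0).shuffle(named)

# Drop labels from the first slice (the "unlabeled" training pool) and keep a
# disjoint labelled slice for evaluation.
train_pool = [{"text": ex["text"]} for ex in named[:6]]
dev = named[6:8]
```

Because the slices come from one shuffled list, the training pool and the dev set never share an example.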
This dataset has 77 different categories for classification. Let's review some of them.
len(CLASSES), CLASSES[:10]
(77, ['activate_my_card', 'age_limit', 'apple_pay_or_google_pay', 'atm_support', 'automatic_top_up', 'balance_not_updated_after_bank_transfer', 'balance_not_updated_after_cheque_or_cash_deposit', 'beneficiary_not_allowed', 'cancel_transfer', 'card_about_to_expire'])
Let us sample 500 (unlabeled) queries from Banking77. We'll use these for our bootstrapped finetuning.
unlabeled_trainset = [dspy.Example(text=x.text).with_inputs("text") for x in raw_data[:500]]
unlabeled_trainset[0]
Example({'text': 'What if there is an error on the exchange rate?'}) (input_keys={'text'})
DSPy program
Let's say that we want a program that takes the text, reasons step by step, and then selects one of the classes from Banking77.
Note that this is meant mainly for illustration, or for cases where you want to inspect the model's reasoning, e.g. for a small degree of explainability. This type of task is not necessarily likely to benefit much from explicit reasoning.
from typing import Literal
classify = dspy.ChainOfThought(f"text -> label: Literal{CLASSES}")
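The signature string above is an ordinary f-string: it splices the Python repr of CLASSES after the word Literal, so the output field is restricted to the 77 class names. The sketch below shows what that string looks like and how a Literal type built from a class list can be used to check a predicted label. It uses only the standard library; the four-class excerpt and the is_valid_label helper are illustrative assumptions, not DSPy internals.

```python
from typing import Literal, get_args

# Assumption: a short excerpt of the Banking77 class list, for illustration only.
CLASSES = ["activate_my_card", "age_limit", "card_arrival", "card_delivery_estimate"]

# Same construction as the tutorial: the list's repr is pasted after "Literal".
signature = f"text -> label: Literal{CLASSES}"

# Subscripting Literal with a tuple yields a type whose allowed values are the class names.
Label = Literal[tuple(CLASSES)]


def is_valid_label(value: str) -> bool:
    """Hypothetical helper: True if value is one of the Literal's allowed class names."""
    return value in get_args(Label)
```

For example, is_valid_label("card_arrival") is True, while a near-miss such as "card_arrivals" is rejected.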
Bootstrapped finetuning
There are many ways to go about this, e.g. allowing the model to teach itself or using inference-time compute (e.g., ensembling) to identify cases of high confidence without labels.
Perhaps the simplest is to use a model that we'd expect to do a reasonable job at this task as a teacher of reasoning and classification, and to distill its behavior into our small model. All of these patterns can be expressed in a handful of lines of code.
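The ensembling alternative mentioned above can be sketched as a majority vote: ask several samplers for a label and keep the example only when enough of them agree. This is a minimal, hypothetical illustration, not a DSPy API; each sampler is assumed to be any callable mapping a text to a label (for instance, repeated LM calls at nonzero temperature, or different models), and the 0.75 agreement threshold is an arbitrary choice.

```python
from collections import Counter
from typing import Callable, List, Optional


def confident_label(
    text: str,
    samplers: List[Callable[[str], str]],
    min_agreement: float = 0.75,
) -> Optional[str]:
    """Return the majority label if its vote share reaches min_agreement, else None."""
    votes = Counter(sample(text) for sample in samplers)
    label, count = votes.most_common(1)[0]
    return label if count / len(samplers) >= min_agreement else None
```

Examples that return None would simply be left out of the bootstrapped training data, trading coverage for label quality.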
Let's set up the tiny Llama-3.2-1B-Instruct as a student LM. We'll use GPT-4o-mini as a teacher LM.
from dspy.clients.lm_local import LocalProvider
student_lm_name = "meta-llama/Llama-3.2-1B-Instruct"
student_lm = dspy.LM(model=f"openai/local:{student_lm_name}", provider=LocalProvider(), max_tokens=2000)
teacher_lm = dspy.LM('openai/gpt-4o-mini', max_tokens=3000)
Now, let's assign a copy of the classifier to each LM.
student_classify = classify.deepcopy()
student_classify.set_lm(student_lm)
teacher_classify = classify.deepcopy()
teacher_classify.set_lm(teacher_lm)
Let's now launch the bootstrapped finetuning. The word "bootstrapped" here means that the program itself will be invoked on the training inputs, and the resulting traces across all modules will be recorded and used for finetuning. This is the weight-optimizing variant of the various BootstrapFewShot methods in DSPy.
On every question in the (unlabeled) training set, this will invoke the teacher program, which will produce reasoning and select a class. These calls are traced and then constitute a training set for all modules (in this case, just the one CoT module) in the student program.
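Conceptually, turning teacher traces into finetuning data looks like the sketch below: run the teacher on each unlabeled input and store its reasoning and label as a target completion for the student. This is a simplified, hypothetical illustration; the teacher callable and the chat-message layout are assumptions, and DSPy's actual data format is produced by its own adapter rather than by this code.

```python
from typing import Callable, Dict, List

# Hypothetical teacher: maps an input text to {"reasoning": ..., "label": ...}.
# In the tutorial, this role is played by teacher_classify (GPT-4o-mini).
Teacher = Callable[[str], Dict[str, str]]


def build_finetuning_data(texts: List[str], teacher: Teacher) -> List[Dict]:
    """Run the teacher on each unlabeled input and record the trace as a chat-style example."""
    data = []
    for text in texts:
        out = teacher(text)  # one call of the single CoT module
        data.append({
            "messages": [
                {"role": "user", "content": text},
                {
                    "role": "assistant",
                    "content": f"reasoning: {out['reasoning']}\nlabel: {out['label']}",
                },
            ]
        })
    return data
```

No gold labels are involved: whatever the teacher outputs becomes the student's training target, which is why teacher quality bounds the result in this setting.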
Note: If you have labels, you can pass metric to the constructor of BootstrapFinetune. If you want to apply this in practice, you can pass train_kwargs to the constructor to control local LM training settings: device, use_peft, num_train_epochs, per_device_train_batch_size, gradient_accumulation_steps, learning_rate, max_seq_length, packing, bf16, and output_dir.
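One way to keep train_kwargs tidy is to assemble it with a small helper that rejects keys outside the list above. This is a hypothetical sketch: make_train_kwargs is not part of DSPy, and every default value below (including the output path) is an illustrative assumption rather than a DSPy default or a tuned recommendation.

```python
# Keys listed in the note above; anything else is treated as a typo.
ALLOWED_KEYS = {
    "device", "use_peft", "num_train_epochs", "per_device_train_batch_size",
    "gradient_accumulation_steps", "learning_rate", "max_seq_length",
    "packing", "bf16", "output_dir",
}


def make_train_kwargs(**overrides):
    """Hypothetical helper: merge overrides into illustrative defaults, rejecting unknown keys."""
    unknown = set(overrides) - ALLOWED_KEYS
    if unknown:
        raise ValueError(f"unsupported train_kwargs keys: {sorted(unknown)}")
    defaults = {
        "use_peft": False,
        "num_train_epochs": 5,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 8,  # effective batch size = 1 * 8
        "learning_rate": 1e-5,
        "max_seq_length": 1024,
        "bf16": True,
        "output_dir": "./finetune_output",  # hypothetical path
    }
    return {**defaults, **overrides}


train_kwargs = make_train_kwargs(use_peft=True, learning_rate=1e-4)
# Then, for example: dspy.BootstrapFinetune(num_threads=16, train_kwargs=train_kwargs)
```

Catching a misspelled key early is cheaper than discovering after a long training run that a setting was silently ignored.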
optimizer = dspy.BootstrapFinetune(num_threads=16) # if you *do* have labels, pass metric=your_metric here!
classify_ft = optimizer.compile(student_classify, teacher=teacher_classify, trainset=unlabeled_trainset)
Since this is a local model, we need to explicitly launch it.
classify_ft.get_lm().launch()
Validating the finetuned program
Let's now figure out if this was successful. We can ask the system one question and inspect its behavior.
classify_ft(text="I didn't receive my money earlier and it says the transaction is still in progress. Can you fix it?")
Prediction( reasoning='The user is inquiring about a specific issue, which they did not receive and is still showing as a pending transaction. This situation typically indicates a problem with the cash withdrawal process, as the user is not receiving the money they attempted to withdraw. The appropriate label for this scenario is "pending_cash_withdrawal," as it directly relates to the status of the cash withdrawal transaction.', label='pending_cash_withdrawal' )
We could also get a small set of gold labels and see if the system can generalize to unseen queries.
devset = raw_data[500:600]
devset[0]
Example({'text': 'Which fiat currencies do you currently support? Will this change in this future?', 'label': 'fiat_currency_support'}) (input_keys={'text'})
Let's define an evaluator on this small dev set, where the metric ignores the reasoning and checks that the label is exactly correct.
metric = (lambda x, y, trace=None: x.label == y.label)
evaluate = dspy.Evaluate(devset=devset, metric=metric, display_progress=True, display_table=5, num_threads=16)
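The metric above only compares the label fields, so the reasoning text never affects the score. The sketch below shows that behavior on stand-in objects and computes a percentage the way the evaluation summary reports it. SimpleNamespace objects stand in for DSPy examples and predictions, and average_metric is a simplified, hypothetical version of the reported average, not dspy.Evaluate's implementation.

```python
from types import SimpleNamespace


def exact_label_match(example, pred, trace=None):
    """Equivalent to the lambda above: exact match on .label, reasoning ignored."""
    return example.label == pred.label


def average_metric(pairs):
    """Percentage of (example, prediction) pairs scored as correct."""
    scores = [exact_label_match(ex, pred) for ex, pred in pairs]
    return 100.0 * sum(scores) / len(scores)


pairs = [
    (SimpleNamespace(label="card_arrival"),
     SimpleNamespace(label="card_arrival", reasoning="Asks about card delivery.")),
    (SimpleNamespace(label="card_arrival"),
     SimpleNamespace(label="card_delivery_estimate", reasoning="Asks about card delivery.")),
]
```

Here average_metric(pairs) is 50.0: both predictions carry identical reasoning, but only the one with the exact label counts.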
Now, let's evaluate the finetuned 1B classifier.
evaluate(classify_ft)
Average Metric: 51.00 / 99 (51.5%): 100%|██████████| 100/100 [00:35<00:00, 2.79it/s]
| | text | example_label | reasoning | pred_label | <lambda> | label |
|---|---|---|---|---|---|---|
| 0 | Which fiat currencies do you currently support? Will this change i... | fiat_currency_support | The user is inquiring about the current support for fiat currencie... | fiat_currency_support | ✔️ [True] | NaN |
| 1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is inquiring about a specific issue, which they did not r... | pending_cash_withdrawal | ✔️ [True] | NaN |
| 2 | what currencies do you accept? | fiat_currency_support | The user is inquiring about the currencies that are accepted, whic... | fiat_currency_support | ✔️ [True] | NaN |
| 3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which re... | exchange_rate | ✔️ [True] | NaN |
| 4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the status of their card, which sugges... | card_arrival | ✔️ [True] | NaN |
51.0
Not bad, given that we started with no labels for the task. Even without labels, you can use various strategies to boost the quality of the bootstrapped training data.
To try that next, let's free our GPU memory by killing the finetuned LM.
classify_ft.get_lm().kill()
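Because a launched local model holds GPU memory until it is killed, it can help to pair launch() and kill() so the model is released even if evaluation fails partway. The sketch below is a hypothetical context manager built only on the launch() and kill() methods used in this tutorial; FakeLM is a stand-in included just to demonstrate the call order without a GPU.

```python
from contextlib import contextmanager


@contextmanager
def launched(lm):
    """Launch a local LM for the duration of a with-block and always kill it afterwards."""
    lm.launch()
    try:
        yield lm
    finally:
        lm.kill()  # runs even if the with-block raises, freeing GPU memory


class FakeLM:
    """Stand-in exposing the same launch()/kill() methods, for demonstration only."""

    def __init__(self):
        self.events = []

    def launch(self):
        self.events.append("launch")

    def kill(self):
        self.events.append("kill")


fake = FakeLM()
with launched(fake):
    fake.events.append("evaluate")
# fake.events is now ["launch", "evaluate", "kill"]
```

With the real program, the same pattern would read: with launched(classify_ft.get_lm()): evaluate(classify_ft).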
Bootstrapped finetuning against a metric
If you have labels, you can generally boost quality by a large margin. To do so, pass a metric to BootstrapFinetune, which will use it to filter the trajectories over your program before it builds the finetuning data.
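The filtering step can be pictured as follows: run the teacher on each labelled example, score its output with the metric, and keep only the traces that pass. This is a simplified, hypothetical sketch of the idea, not BootstrapFinetune's implementation; examples and predictions are plain dicts, and the teacher below is a toy that always answers card_arrival.

```python
from typing import Callable, Dict, List


def filter_traces(
    examples: List[Dict[str, str]],
    teacher: Callable[[str], Dict[str, str]],
    metric: Callable[[Dict[str, str], Dict[str, str]], bool],
) -> List[Dict[str, str]]:
    """Keep (input, teacher output) pairs only where the metric accepts the teacher's output."""
    kept = []
    for ex in examples:
        pred = teacher(ex["text"])
        if metric(ex, pred):
            kept.append({"text": ex["text"], **pred})
    return kept


# Toy data and teacher (assumptions) to show the effect of filtering.
examples = [
    {"text": "why hasnt my card come in yet?", "label": "card_arrival"},
    {"text": "my cash withdrawal is still pending", "label": "pending_cash_withdrawal"},
]
toy_teacher = lambda text: {"reasoning": "Asks about a card.", "label": "card_arrival"}
label_match = lambda ex, pred: ex["label"] == pred["label"]

training_data = filter_traces(examples, toy_teacher, label_match)
```

Only the first example survives, so the student never trains on the teacher's wrong answer for the second.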
optimizer = dspy.BootstrapFinetune(num_threads=16, metric=metric)
classify_ft = optimizer.compile(student_classify, teacher=teacher_classify, trainset=raw_data[:500])
Let's now launch and evaluate this.
classify_ft.get_lm().launch()
evaluate(classify_ft)
Average Metric: 85.00 / 98 (86.7%): 100%|██████████| 100/100 [00:46<00:00, 2.14it/s]
| | text | example_label | reasoning | pred_label | <lambda> | label |
|---|---|---|---|---|---|---|
| 0 | Which fiat currencies do you currently support? Will this change i... | fiat_currency_support | The user is inquiring about the fiat currencies currently supporte... | fiat_currency_support | ✔️ [True] | NaN |
| 1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is inquiring about an unexpected fee on their account, wh... | extra_charge_on_statement | | NaN |
| 2 | what currencies do you accept? | fiat_currency_support | The user is inquiring about the types of currencies that are accep... | fiat_currency_support | ✔️ [True] | NaN |
| 3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which re... | exchange_rate | ✔️ [True] | NaN |
| 4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the status of their card delivery, whi... | card_arrival | ✔️ [True] | NaN |
85.0
That's quite a bit better, given just 500 labels. In fact, the finetuned 1B model appears to be a lot stronger than the teacher LM is out of the box!
evaluate(teacher_classify)
Average Metric: 55.00 / 100 (55.0%): 100%|██████████| 100/100 [00:11<00:00, 8.88it/s]
2025/01/08 12:38:35 INFO dspy.evaluate.evaluate: Average Metric: 55 / 100 (55.0%)
| | text | example_label | reasoning | pred_label | <lambda> |
|---|---|---|---|---|---|
| 0 | Which fiat currencies do you currently support? Will this change i... | fiat_currency_support | The user is inquiring about the fiat currencies supported by the s... | fiat_currency_support | ✔️ [True] |
| 1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is experiencing an issue with a transaction that is still... | pending_transfer | |
| 2 | what currencies do you accept? | fiat_currency_support | The question is asking about the types of currencies accepted, whi... | fiat_currency_support | ✔️ [True] |
| 3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which re... | exchange_rate | ✔️ [True] |
| 4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the status of their card delivery, whi... | card_delivery_estimate | |
55.0
Thanks to bootstrapping, the model has learned to apply our modules to get the right label, in this case by reasoning explicitly:
classify_ft(text="why hasnt my card come in yet?")
dspy.inspect_history()
[2025-01-08T12:39:42.143798]
System message:
Your input fields are:
1. `text` (str)
Your output fields are:
1. `reasoning` (str)
2. `label` (Literal[activate_my_card, age_limit, apple_pay_or_google_pay, atm_support, automatic_top_up, balance_not_updated_after_bank_transfer, balance_not_updated_after_cheque_or_cash_deposit, beneficiary_not_allowed, cancel_transfer, card_about_to_expire, card_acceptance, card_arrival, card_delivery_estimate, card_linking, card_not_working, card_payment_fee_charged, card_payment_not_recognised, card_payment_wrong_exchange_rate, card_swallowed, cash_withdrawal_charge, cash_withdrawal_not_recognised, change_pin, compromised_card, contactless_not_working, country_support, declined_card_payment, declined_cash_withdrawal, declined_transfer, direct_debit_payment_not_recognised, disposable_card_limits, edit_personal_details, exchange_charge, exchange_rate, exchange_via_app, extra_charge_on_statement, failed_transfer, fiat_currency_support, get_disposable_virtual_card, get_physical_card, getting_spare_card, getting_virtual_card, lost_or_stolen_card, lost_or_stolen_phone, order_physical_card, passcode_forgotten, pending_card_payment, pending_cash_withdrawal, pending_top_up, pending_transfer, pin_blocked, receiving_money, Refund_not_showing_up, request_refund, reverted_card_payment?, supported_cards_and_currencies, terminate_account, top_up_by_bank_transfer_charge, top_up_by_card_charge, top_up_by_cash_or_cheque, top_up_failed, top_up_limits, top_up_reverted, topping_up_by_card, transaction_charged_twice, transfer_fee_charged, transfer_into_account, transfer_not_received_by_recipient, transfer_timing, unable_to_verify_identity, verify_my_identity, verify_source_of_funds, verify_top_up, virtual_card_not_working, visa_or_mastercard, why_verify_identity, wrong_amount_of_cash_received, wrong_exchange_rate_for_cash_withdrawal])
All interactions will be structured in the following way, with the appropriate values filled in.
[[ ## text ## ]]
{text}
[[ ## reasoning ## ]]
{reasoning}
[[ ## label ## ]]
{label} # note: the value you produce must be one of: activate_my_card; age_limit; apple_pay_or_google_pay; atm_support; automatic_top_up; balance_not_updated_after_bank_transfer; balance_not_updated_after_cheque_or_cash_deposit; beneficiary_not_allowed; cancel_transfer; card_about_to_expire; card_acceptance; card_arrival; card_delivery_estimate; card_linking; card_not_working; card_payment_fee_charged; card_payment_not_recognised; card_payment_wrong_exchange_rate; card_swallowed; cash_withdrawal_charge; cash_withdrawal_not_recognised; change_pin; compromised_card; contactless_not_working; country_support; declined_card_payment; declined_cash_withdrawal; declined_transfer; direct_debit_payment_not_recognised; disposable_card_limits; edit_personal_details; exchange_charge; exchange_rate; exchange_via_app; extra_charge_on_statement; failed_transfer; fiat_currency_support; get_disposable_virtual_card; get_physical_card; getting_spare_card; getting_virtual_card; lost_or_stolen_card; lost_or_stolen_phone; order_physical_card; passcode_forgotten; pending_card_payment; pending_cash_withdrawal; pending_top_up; pending_transfer; pin_blocked; receiving_money; Refund_not_showing_up; request_refund; reverted_card_payment?; supported_cards_and_currencies; terminate_account; top_up_by_bank_transfer_charge; top_up_by_card_charge; top_up_by_cash_or_cheque; top_up_failed; top_up_limits; top_up_reverted; topping_up_by_card; transaction_charged_twice; transfer_fee_charged; transfer_into_account; transfer_not_received_by_recipient; transfer_timing; unable_to_verify_identity; verify_my_identity; verify_source_of_funds; verify_top_up; virtual_card_not_working; visa_or_mastercard; why_verify_identity; wrong_amount_of_cash_received; wrong_exchange_rate_for_cash_withdrawal
[[ ## completed ## ]]
In adhering to this structure, your objective is: Given the fields `text`, produce the fields `label`.
User message:
[[ ## text ## ]]
why hasnt my card come in yet?
Respond with the corresponding output fields, starting with the field `[[ ## reasoning ## ]]`, then `[[ ## label ## ]]` (must be formatted as a valid Python Literal[activate_my_card, age_limit, apple_pay_or_google_pay, atm_support, automatic_top_up, balance_not_updated_after_bank_transfer, balance_not_updated_after_cheque_or_cash_deposit, beneficiary_not_allowed, cancel_transfer, card_about_to_expire, card_acceptance, card_arrival, card_delivery_estimate, card_linking, card_not_working, card_payment_fee_charged, card_payment_not_recognised, card_payment_wrong_exchange_rate, card_swallowed, cash_withdrawal_charge, cash_withdrawal_not_recognised, change_pin, compromised_card, contactless_not_working, country_support, declined_card_payment, declined_cash_withdrawal, declined_transfer, direct_debit_payment_not_recognised, disposable_card_limits, edit_personal_details, exchange_charge, exchange_rate, exchange_via_app, extra_charge_on_statement, failed_transfer, fiat_currency_support, get_disposable_virtual_card, get_physical_card, getting_spare_card, getting_virtual_card, lost_or_stolen_card, lost_or_stolen_phone, order_physical_card, passcode_forgotten, pending_card_payment, pending_cash_withdrawal, pending_top_up, pending_transfer, pin_blocked, receiving_money, Refund_not_showing_up, request_refund, reverted_card_payment?, supported_cards_and_currencies, terminate_account, top_up_by_bank_transfer_charge, top_up_by_card_charge, top_up_by_cash_or_cheque, top_up_failed, top_up_limits, top_up_reverted, topping_up_by_card, transaction_charged_twice, transfer_fee_charged, transfer_into_account, transfer_not_received_by_recipient, transfer_timing, unable_to_verify_identity, verify_my_identity, verify_source_of_funds, verify_top_up, virtual_card_not_working, visa_or_mastercard, why_verify_identity, wrong_amount_of_cash_received, wrong_exchange_rate_for_cash_withdrawal]), and then ending with the marker for `[[ ## completed ## ]]`.
Response:
[[ ## reasoning ## ]]
The user is inquiring about the status of their card delivery, which suggests they are concerned about when they will receive their card. This aligns with the topic of card arrival and delivery estimates.
[[ ## label ## ]]
card_arrival
[[ ## completed ## ]]
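The history shows that each output field in the response is introduced by a [[ ## name ## ]] marker and the response ends with [[ ## completed ## ]]. The sketch below parses a completion in that shape back into a field dictionary. It is a hypothetical, simplified parser written for illustration; it is not DSPy's adapter code and ignores details such as type coercion of the label.

```python
import re
from typing import Dict

# Matches a "[[ ## name ## ]]" marker and captures the field name.
MARKER = re.compile(r"\[\[ ## (\w+) ## \]\]")


def parse_fields(completion: str) -> Dict[str, str]:
    """Split a completion on its field markers and return {field: value}, stopping at 'completed'."""
    # With one capturing group, re.split yields [prefix, name1, value1, name2, value2, ...].
    parts = MARKER.split(completion)
    fields = {}
    for name, value in zip(parts[1::2], parts[2::2]):
        if name == "completed":
            break
        fields[name] = value.strip()
    return fields


response = (
    "[[ ## reasoning ## ]]\nThe user asks about card delivery.\n\n"
    "[[ ## label ## ]]\ncard_arrival\n\n"
    "[[ ## completed ## ]]"
)
parsed = parse_fields(response)
```

Here parsed maps "reasoning" to the explanation sentence and "label" to "card_arrival", mirroring the response in the history above.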