def compile(
    self, student: Program, trainset: List[Example], teacher: Optional[Union[Program, List[Program]]] = None
) -> Program:
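    """Fine-tune the LMs used by the student's predictors and return the updated student.

    Traces are bootstrapped by running the prepared teacher program(s) over ``trainset``,
    converted into per-LM fine-tuning datasets, and used to fine-tune each predictor's LM.
    """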
    # TODO: Print statements can be converted to logger.info if we ensure
    # that the default DSPy logger logs info level messages in notebook
    # environments.
    logger.info("Preparing the student and teacher programs...")
    set_missing_predictor_lms(student)
    logger.info("Bootstrapping data...")
    trace_data = []
    teachers = teacher if isinstance(teacher, list) else [teacher]
    teachers = [prepare_teacher(student, t) for t in teachers]
    for t in teachers:
        set_missing_predictor_lms(t)
        trace_data += bootstrap_trace_data(
            program=t, dataset=trainset, metric=self.metric, num_threads=self.num_threads
        )
logger.info("Preparing the train data...")
key_to_data = {}
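    # Group fine-tuning data by (LM, predictor index). With `multitask=True`, the predictor index
    # is dropped, so all predictors sharing an LM are fine-tuned together on a single dataset.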
    for pred_ind, pred in enumerate(student.predictors()):
        data_pred_ind = None if self.multitask else pred_ind
        training_key = (pred.lm, data_pred_ind)
        if training_key not in key_to_data:
            train_data, data_format = self._prepare_finetune_data(
                trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
            )
            logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
            finetune_kwargs = dict(
                lm=pred.lm,
                train_data=train_data,
                train_data_format=data_format,
                train_kwargs=self.train_kwargs[pred.lm],
            )
            key_to_data[training_key] = finetune_kwargs
logger.info("Starting LM fine-tuning...")
# TODO(feature): We could run batches of fine-tuning jobs in sequence
# to avoid exceeding the number of threads.
if len(key_to_data) > self.num_threads:
        raise ValueError(
            "BootstrapFinetune requires `num_threads` to be greater than or equal to the number of fine-tuning "
            f"jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is "
            f"{self.num_threads}. If the `multitask` flag is set to False, the number of fine-tuning jobs equals "
            "the number of predictors in the student program. If the `multitask` flag is set to True, it equals "
            "1 if there is only a context LM, or the number of unique LMs attached to the predictors in the "
            "student program. In either case, the number of fine-tuning jobs is at most the number of predictors."
        )
logger.info(f"{len(key_to_data)} fine-tuning job(s) to start")
    key_to_lm = self.finetune_lms(key_to_data)
    logger.info("Updating the student program with the fine-tuned LMs...")
    for pred_ind, pred in enumerate(student.predictors()):
        data_pred_ind = None if self.multitask else pred_ind
        training_key = (pred.lm, data_pred_ind)
        pred.lm = key_to_lm[training_key]
        # TODO: What should the correct behavior be here? Should
        # BootstrapFinetune modify the prompt demos according to the
        # train data?
        pred.demos = [] if self.exclude_demos else pred.demos
logger.info("BootstrapFinetune has finished compiling the student program")
student._compiled = True
return student