dspy.Refine

dspy.Refine(module: Module, N: int, reward_fn: Callable[[dict, Prediction], float], threshold: float, fail_count: Optional[int] = None)

Bases: Module

Refines a module by running it up to N times with different temperatures and returning the best prediction.

This module runs the provided module multiple times with varying temperature settings and selects either the first prediction whose reward meets or exceeds the specified threshold or, failing that, the one with the highest reward. After each attempt that falls short of the threshold, it automatically generates feedback that is passed to the module as a hint on the next attempt.
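
Conceptually, the selection rule can be sketched as follows. This is an illustrative simplification only, not the library implementation (the actual `forward()`, shown further below, also varies the sampling temperature per attempt and injects feedback hints between attempts); the helper names are hypothetical:

```python
# Illustrative sketch of Refine's selection rule, not the library code.
from typing import Any, Callable

def refine_sketch(run_module: Callable[[], Any],
                  reward_fn: Callable[[Any], float],
                  N: int,
                  threshold: float) -> Any:
    best_pred, best_reward = None, float("-inf")
    for _ in range(N):
        pred = run_module()          # one attempt of the wrapped module
        reward = reward_fn(pred)     # score the attempt
        if reward > best_reward:
            best_pred, best_reward = pred, reward
        if reward >= threshold:      # stop at the first attempt meeting the threshold
            break
    return best_pred
```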

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `Module` | The module to refine. | *required* |
| `N` | `int` | The number of times to run the module. | *required* |
| `reward_fn` | `Callable` | The reward function. | *required* |
| `threshold` | `float` | The threshold for the reward function. | *required* |
| `fail_count` | `Optional[int]` | The number of times the module can fail before raising an error. Defaults to `N` if not provided. | `None` |
Example
import dspy

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

# Define a QA module with chain of thought
qa = dspy.ChainOfThought("question -> answer")

# Define a reward function that checks for one-word answers
def one_word_answer(args, pred):
    return 1.0 if len(pred.answer.split()) == 1 else 0.0

# Create a refined module that tries up to 3 times
best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)

# Use the refined module
result = best_of_3(question="What is the capital of Belgium?").answer
# Returns: Brussels
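
As a further, hypothetical variation, `fail_count` can be set lower than `N` to bound how many attempts may raise exceptions (for example, transient LM errors) before `Refine` re-raises; when omitted, it defaults to `N`:

```python
# Hypothetical variation: allow at most 2 failed attempts before the error propagates.
robust_qa = dspy.Refine(
    module=qa,
    N=5,
    reward_fn=one_word_answer,
    threshold=1.0,
    fail_count=2,
)
result = robust_qa(question="What is the capital of France?").answer
```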
Source code in dspy/predict/refine.py
def __init__(
    self,
    module: Module,
    N: int,
    reward_fn: Callable[[dict, Prediction], float],
    threshold: float,
    fail_count: Optional[int] = None,
):
    """
    Refines a module by running it up to N times with different temperatures and returning the best prediction.

    This module runs the provided module multiple times with varying temperature settings and selects
    either the first prediction that exceeds the specified threshold or the one with the highest reward.
    If no prediction meets the threshold, it automatically generates feedback to improve future predictions.


    Args:
        module (Module): The module to refine.
        N (int): The number of times to run the module.
        reward_fn (Callable): The reward function.
        threshold (float): The threshold for the reward function.
        fail_count (Optional[int], optional): The number of times the module can fail before raising an error.

    Example:
        ```python
        import dspy

        dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

        # Define a QA module with chain of thought
        qa = dspy.ChainOfThought("question -> answer")

        # Define a reward function that checks for one-word answers
        def one_word_answer(args, pred):
            return 1.0 if len(pred.answer.split()) == 1 else 0.0

        # Create a refined module that tries up to 3 times
        best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)

        # Use the refined module
        result = best_of_3(question="What is the capital of Belgium?").answer
        # Returns: Brussels
        ```
    """
    self.module = module
    self.reward_fn = lambda *args: reward_fn(*args)  # to prevent this from becoming a parameter
    self.threshold = threshold
    self.N = N
    self.fail_count = fail_count or N  # default to N if fail_count is not provided
    self.module_code = inspect.getsource(module.__class__)
    try:
        self.reward_fn_code = inspect.getsource(reward_fn)
    except TypeError:
        self.reward_fn_code = inspect.getsource(reward_fn.__class__)

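Note that `forward()` (below) invokes the reward function as `reward_fn(kwargs, outputs)`, i.e. with the program's input keyword arguments as a dict and the resulting `Prediction`. A hedged sketch of a reward that uses both arguments:

```python
# Sketch of a reward function that uses both arguments Refine passes:
# the program inputs as a dict and the resulting Prediction
# (forward() calls reward_fn(kwargs, outputs)).
import dspy

def not_echoing_the_question(inputs: dict, pred: dspy.Prediction) -> float:
    # Reward answers that do not simply repeat the question verbatim.
    return 0.0 if pred.answer.strip().lower() == inputs["question"].strip().lower() else 1.0
```
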
Functions

forward(**kwargs)

Source code in dspy/predict/refine.py
def forward(self, **kwargs):
    lm = self.module.get_lm() or dspy.settings.lm
    temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / self.N) for i in range(self.N)]
    temps = list(dict.fromkeys(temps))[: self.N]
    best_pred, best_trace, best_reward = None, None, -float("inf")
    advice = None
    adapter = dspy.settings.adapter or dspy.ChatAdapter()

    for idx, t in enumerate(temps):
        lm_ = lm.copy(temperature=t)
        mod = self.module.deepcopy()
        mod.set_lm(lm_)

        predictor2name = {predictor: name for name, predictor in mod.named_predictors()}
        signature2name = {predictor.signature: name for name, predictor in mod.named_predictors()}
        module_names = [name for name, _ in mod.named_predictors()]

        try:
            with dspy.context(trace=[]):
                if not advice:
                    outputs = mod(**kwargs)
                else:

                    class WrapperAdapter(adapter.__class__):
                        def __call__(self, lm, lm_kwargs, signature, demos, inputs):
                            inputs["hint_"] = advice.get(signature2name[signature], "N/A")
                            signature = signature.append(
                                "hint_", InputField(desc="A hint to the module from an earlier run")
                            )
                            return adapter(lm, lm_kwargs, signature, demos, inputs)

                    with dspy.context(adapter=WrapperAdapter()):
                        outputs = mod(**kwargs)

                trace = dspy.settings.trace.copy()

                # TODO: Remove the hint from the trace, if it's there.

                # NOTE: Not including the trace of reward_fn.
                reward = self.reward_fn(kwargs, outputs)

            if reward > best_reward:
                best_reward, best_pred, best_trace = reward, outputs, trace

            if self.threshold is not None and reward >= self.threshold:
                break

            if idx == self.N - 1:
                break

            modules = dict(program_code=self.module_code, modules_defn=inspect_modules(mod))
            trajectory = [dict(module_name=predictor2name[p], inputs=i, outputs=dict(o)) for p, i, o in trace]
            trajectory = dict(program_inputs=kwargs, program_trajectory=trajectory, program_outputs=dict(outputs))
            reward = dict(reward_code=self.reward_fn_code, target_threshold=self.threshold, reward_value=reward)

            advise_kwargs = dict(**modules, **trajectory, **reward, module_names=module_names)
            # advise_kwargs = {k: ujson.dumps(recursive_mask(v), indent=2) for k, v in advise_kwargs.items()}
            # only dumps if it's a list or dict
            advise_kwargs = {
                k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
                for k, v in advise_kwargs.items()
            }
            advice = dspy.Predict(OfferFeedback)(**advise_kwargs).advice
            # print(f"Advice for each module: {advice}")

        except Exception as e:
            print(f"Refine: Attempt failed with temperature {t}: {e}")
            if idx > self.fail_count:
                raise e
            self.fail_count -= 1
    if best_trace:
        dspy.settings.trace.extend(best_trace)
    return best_pred
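
For reference, the temperature schedule tried by `forward()` starts from the LM's configured temperature and then adds evenly spaced values from 0.5 upward, de-duplicated and truncated to `N` entries. A small illustration, assuming a configured temperature of 0.7 and `N = 3` (hypothetical values):

```python
# Illustration of the temperature schedule computed at the top of forward(),
# assuming the configured LM temperature is 0.7 and N = 3.
N = 3
base_temperature = 0.7
temps = [base_temperature] + [0.5 + i * (0.5 / N) for i in range(N)]
temps = list(dict.fromkeys(temps))[:N]   # de-duplicate, keep at most N
print(temps)  # [0.7, 0.5, 0.6666666666666666]
```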