dspy.MultiChainComparison
This page is outdated and may not be fully accurate in DSPy 2.5
Constructor
The constructor initializes the MultiChainComparison class and sets up its attributes. As the code below shows, the class inherits from Module and wraps a Predict instance whose signature is extended for multiple chain comparisons.
The extended signature adds one field per student reasoning attempt, followed by a rationale field in which the model writes the corrected reasoning and selects the best path out of the available attempts.
from .predict import Predict
from ..primitives.program import Module
import dsp

class MultiChainComparison(Module):
    def __init__(self, signature, M=3, temperature=0.7, **config):
        super().__init__()

        self.M = M
        signature = Predict(signature).signature
        *keys, last_key = signature.kwargs.keys()

        extended_kwargs = {key: signature.kwargs[key] for key in keys}

        for idx in range(M):
            candidate_type = dsp.Type(prefix=f"Student Attempt #{idx+1}:", desc="${reasoning attempt}")
            extended_kwargs.update({f'reasoning_attempt_{idx+1}': candidate_type})

        rationale_type = dsp.Type(prefix="Accurate Reasoning: Thank you everyone. Let's now holistically", desc="${corrected reasoning}")
        extended_kwargs.update({'rationale': rationale_type, last_key: signature.kwargs[last_key]})

        signature = dsp.Template(signature.instructions, **extended_kwargs)
        self.predict = Predict(signature, temperature=temperature, **config)
        self.last_key = last_key
Parameters:
- signature (Any): Signature of the predictive model.
- M (int, optional): Number of student reasoning attempts. Defaults to 3.
- temperature (float, optional): Temperature parameter for prediction. Defaults to 0.7.
- **config (dict): Additional configuration parameters for the model.
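To make the field ordering in the constructor easier to follow, here is a minimal sketch in plain Python that reproduces only the key-ordering logic. It does not use dspy or dsp; the helper name extended_field_names is hypothetical and exists only for illustration.

```python
from typing import List


def extended_field_names(field_names: List[str], M: int = 3) -> List[str]:
    """Return the field order of the extended signature (illustrative only)."""
    if not field_names:
        raise ValueError("the signature needs at least one field")
    # Mirror `*keys, last_key = signature.kwargs.keys()` from the constructor.
    *keys, last_key = field_names
    extended = list(keys)
    # One slot per student attempt, numbered from 1.
    extended += [f"reasoning_attempt_{idx + 1}" for idx in range(M)]
    # The corrected reasoning comes right before the final output field.
    extended += ["rationale", last_key]
    return extended


print(extended_field_names(["question", "answer"], M=3))
# ['question', 'reasoning_attempt_1', 'reasoning_attempt_2',
#  'reasoning_attempt_3', 'rationale', 'answer']
```

The last field of the original signature is treated as the final output, so the attempts and the rationale are always placed between the inputs and that output.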
Method
forward(self, completions, **kwargs)
This method aggregates all the student reasoning attempts from the given completions and calls the internal predict module, which uses the extended signature, to obtain the best reasoning.
Parameters:
- completions: List of completion objects which include student reasoning attempts.
- **kwargs: Additional keyword arguments.
Returns:
- The result of the predict method for the best reasoning.
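The aggregation step described above can be sketched roughly as follows. This is not the library's implementation: the helper build_attempt_kwargs is hypothetical, the wording of each attempt string is an assumption, and so are the choice to keep only the first line of each rationale and answer and the check that the number of completions equals M. SimpleNamespace stands in for a completion object so the sketch runs without dspy.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


def build_attempt_kwargs(
    completions: List[Any], last_key: str, M: int, **kwargs: Any
) -> Dict[str, Any]:
    """Turn completions into reasoning_attempt_* keyword arguments (sketch)."""
    attempts = []
    for c in completions:
        # Assumption: keep only the first line of each rationale and answer.
        rationale = c.rationale.strip().split("\n")[0].strip()
        answer = getattr(c, last_key).strip().split("\n")[0].strip()
        # Assumption: the exact attempt wording is illustrative.
        attempts.append(f"Reasoning: {rationale} Prediction: {answer}")
    # Assumption: the extended signature has exactly M attempt slots.
    if len(attempts) != M:
        raise ValueError(f"expected {M} attempts, got {len(attempts)}")
    named = {f"reasoning_attempt_{i + 1}": a for i, a in enumerate(attempts)}
    # Caller-supplied inputs (such as question) are passed through unchanged.
    return {**named, **kwargs}


completions = [
    SimpleNamespace(rationale="Clear days look this color.", answer="blue"),
    SimpleNamespace(rationale="Common knowledge says so.", answer="green"),
]
print(build_attempt_kwargs(completions, "answer", M=2, question="Sky color?"))
```

In the real module, a dictionary like this would be passed to the internal predict call, so each attempt fills one of the "Student Attempt" fields created by the constructor.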
Examples
import dspy

# Assumes a language model has already been configured for dspy.

class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")

# Example completions generated by a model for reference
completions = [
    dspy.Prediction(rationale="I recall that during clear days, the sky often appears this color.", answer="blue"),
    dspy.Prediction(rationale="Based on common knowledge, I believe the sky is typically seen as this color.", answer="green"),
    dspy.Prediction(rationale="From images and depictions in media, the sky is frequently represented with this hue.", answer="blue"),
]

# Pass signature to MultiChainComparison module (M defaults to 3, matching the three completions)
compare_answers = dspy.MultiChainComparison(BasicQA)

# Call the MultiChainComparison on the completions
question = 'What is the color of the sky?'
final_pred = compare_answers(completions, question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after comparison): {final_pred.answer}")
print(f"Final Rationale: {final_pred.rationale}")