Creating Custom RM Client
DSPy provides support for various retrieval modules out of the box like ColBERTv2
, AzureCognitiveSearch
, Lancedb
, Pinecone
, Weaviate
, etc. Unlike Language Model (LM) modules, creating a custom RM module is much more simple and flexible.
As of now, DSPy offers 2 ways to create a custom RM: the Pythonic way and the DSPythonic way. We'll take a look at both, understand why both are performing the same behavior, and how you can implement each!
I/O of RM Client
Before understanding the implementation, let's understand the idea and I/O within RM modules.
The input to an RM module is either 1) a single query or 2) a list of queries.
The output is the top-k
passages per query retrieved from a retrieval model, vector store, or search client.
Conventionally, we simply call the RM module object through the __call__
method, inputting the query/queries as argument(s) of the call with the corresponding output returned as a list of strings.
We'll see how this I/O is essentially same in both methods of implementation but differs in their delivery.
The Pythonic Way
To account for our RM I/O, we create a class that conducts the retrieval logic, which we implement in the __init__
and __call__
methods:
from typing import List, Union
class PythonicRMClient:
def __init__(self):
pass
def __call__(self, query: Union[str, List[str]], k:int) -> Union[List[str], List[List[str]]]:
pass
Info
Don't worry about the extensive type-hinting above. typing
is a package that provides type-definitions for function inputs and outputs.
Union
covers all possible types of the argument/output. So:
* Union[str, List[str]]
: Assigned to query
to work with a single query string or a list of queries strings.
* Union[List[str], List[List[str]]]
: Assigned to the output of __call__
to work with a single query string as a list or a list of multiple query string lists.
So let's start by implementing PythonicRMClient
for a local retrieval model hosted on a API with endpoint being /
. We'll start by implementing the __init__
method, which simply initializes the class attributes, url
and port
, and attaches the port to the url if present.
Now it's time to write the retrieval logic in __call__
method:
def __call__(self, query:str, k:int) -> List[str]:
params = {"query": query, "k": k}
response = requests.get(self.url, params=params)
response = response.json()["retrieved_passages"] # List of top-k passages
return response
This serves to represent our API request call to retrieve our list of top-k passages which we return as the response. Let's bring it all together and see how our RM class looks like:
from typing import List
class PythonicRMClient:
def __init__(self, url: str, port:int = None):
self.url = f`{url}:{port}` if port else url
def __call__(self, query:str, k:int) -> List[str]:
# Only accept single query input, feel free to modify it to support
params = {"query": query, "k": k}
response = requests.get(self.url, params=params)
response = response.json()["retrieved_passages"] # List of top k passages
return response
That's all!! This is the most basic way to implement a RM model and mirrors DSP-v1-hosted RM models like ColBERTv2
and AzureCognitiveSearch
.
Now let's take a look at how we streamline this process in DSPy!
The DSPythonic Way
The DSPythonic way mirrors the Pythonic way in maintaining the same input but now returning an object of dspy.Prediction
class, the standard output format for any DSPy module as we've seen in previous docs. Additionally, this class would now inherit the dspy.Retrieve
class to maintain state management within the DSPy library.
So let's implement __init__
and forward
method where our class's __call__
is calling the forward
method as is=:
import dspy
from typing import List, Union, Optional
class DSPythonicRMClient(dspy.Retrieve):
def __init__(self, k:int):
pass
def forward(self, query: Union[str, List[str]], k:Optional[str]) -> dspy.Prediction:
pass
Unlike PythonicRMClient
, we initialize k
as part of the initialization call and the forward
method will take query/queries as arguments and the k
number of retrieved passages as an optional argument. k
is used within the inherited dspy.Retrieve
initialization when we call super().__init__()
.
We'll be implementing DSPythonicRMClient
for the same local retrieval model API we used above. We'll start by implementing the __init__
method, which mirrors the PythonicRMClient
.
def __init__(self, url: str, port:int = None, k:int = 3):
super().__init__(k=k)
self.url = f`{url}:{port}` if port else url
We'll now implement the forward
method, returning the output as a dspy.Prediction
object under the passage
attribute which is standard among all the RM modules. The call will default to the defined self.k
argument unless overridden in this call.
def forward(self, query:str, k:Optional[int]) -> dspy.Prediction:
params = {"query": query, "k": k if k else self.k}
response = requests.get(self.url, params=params)
response = response.json()["retrieved_passages"] # List of top k passages
return dspy.Prediction(
passages=response
)
Let's bring it all together and see how our RM class looks like:
import dspy
from typing import List, Union, Optional
class DSPythonicRMClient(dspy.Retrieve):
def __init__(self, url: str, port:int = None, k:int = 3):
super().__init__(k=k)
self.url = f`{url}:{port}` if port else url
def forward(self, query_or_queries:str, k:Optional[int]) -> dspy.Prediction:
params = {"query": query_or_queries, "k": k if k else self.k}
response = requests.get(self.url, params=params)
response = response.json()["retrieved_passages"] # List of top k passages
return dspy.Prediction(
passages=response
)
That's all!! This is the way to implement a custom RM model client within DSPy and how more recently-supported RM models like QdrantRM
, WeaviateRM
, LancedbRM
, etc. are implemented in DSPy.
Let's take a look at how we use these retrievers.
Using Custom RM Models
DSPy offers two ways of using custom RM clients: Direct Method and using dspy.Retrieve
.
Direct Method
The most straightforward way to use your custom RM is by directly using its object within the DSPy Pipeline.
Let's take a look at the following pseudocode of a DSPy Pipeline as an example:
class DSPyPipeline(dspy.Module):
def __init__(self):
super().__init__()
url = "http://0.0.0.0"
port = 3000
self.pythonic_rm = PythonicRMClient(url=url, port=port)
self.dspythonic_rm = DSPythonicRMClient(url=url, port=port, k=3)
...
def forward(self, *args):
...
passages_from_pythonic_rm = self.pythonic_rm(query)
passages_from_dspythonic_rm = self.dspythonic_rm(query).passages
...
This ensures you retrieve a list of passages from your RM client and can interact with the results within your forward pass in whichever way needed for your pipeline's purpose!
Using dspy.Retrieve
This way is more experimental in essence, allowing you to maintain the same pipeline and experiment with different RMs. How? By configuring it!
import dspy
lm = ...
url = "http://0.0.0.0"
port = 3000
# pythonic_rm = PythonicRMClient(url=url, port=port)
dspythonic_rm = DSPythonicRMClient(url=url, port=port, k=3)
dspy.settings.configure(lm=lm, rm=dspythonic_rm)
Now, in the pipeline, you just need to use dspy.Retrieve
which will use this rm
client to get the top-k passage for a given query!
class DSPyPipeline(dspy.Module):
def __init__(self):
super().__init__()
url = "http://0.0.0.0"
port = 3000
self.rm = dspy.Retrieve(k=3)
...
def forward(self, *args):
...
passages = self.rm(query)
...
Now if you'd like to use a different RM, you can just update the rm
parameter via dspy.settings.configure
.
How dspy.Retrieve
uses rm
When we call dspy.Retrieve
the __call__
method will execute the forward
method as is. In forward
, the top-k passages are received by the dsp.retrieveEnsemble
method in search.py.
If an rm
is not initialized in dsp.settings
, this would raise an error.