Skip to content

dspy.retrievers.Embeddings

dspy.retrievers.Embeddings(corpus: List[str], embedder, k: int = 5, callbacks: Optional[List[Any]] = None, cache: bool = False, brute_force_threshold: int = 20000, normalize: bool = True)

Source code in dspy/retrievers/embeddings.py
def __init__(
    self,
    corpus: List[str],
    embedder,
    k: int = 5,
    callbacks: Optional[List[Any]] = None,
    cache: bool = False,
    brute_force_threshold: int = 20_000,
    normalize: bool = True
):
    """Embed ``corpus`` once up front and prepare it for similarity search.

    Args:
        corpus: Passages to search over. Stored as ``self.corpus`` and passed
            to ``embedder`` in a single call.
        embedder: Callable invoked as ``embedder(corpus)``. Its return value
            becomes ``self.corpus_embeddings`` (normalized first when
            ``normalize`` is true).
        k: Stored as ``self.k``. Presumably the number of passages returned
            per query; it is consumed outside this constructor, so confirm.
        callbacks: Accepted for signature compatibility; not referenced by
            this constructor.
        cache: Must be falsy. Caching is not supported for this retriever.
        brute_force_threshold: When ``len(corpus)`` is at least this value,
            ``self.index`` is built via ``self._build_faiss()``. Otherwise
            ``self.index`` is ``None``.
        normalize: When true, the corpus embeddings are passed through
            ``self._normalize`` before being stored.

    Raises:
        ValueError: If ``cache`` is truthy.
    """
    # An explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, which would silently accept the unsupported option.
    if cache:
        raise ValueError("Caching is not supported for embeddings-based retrievers")

    self.embedder = embedder
    self.k = k
    self.corpus = corpus
    self.normalize = normalize

    # The whole corpus is embedded once, here, rather than per query.
    self.corpus_embeddings = self.embedder(self.corpus)
    self.corpus_embeddings = self._normalize(self.corpus_embeddings) if self.normalize else self.corpus_embeddings

    # Below the threshold no index is built (presumably searched by brute
    # force, per the parameter name); at or above it, an index is built.
    self.index = self._build_faiss() if len(corpus) >= brute_force_threshold else None
    # ``search_fn`` is called with a single query in ``forward``; ``Unbatchify``
    # presumably adapts the batched ``_batch_forward`` to that — confirm.
    self.search_fn = Unbatchify(self._batch_forward)

Functions

__call__(query: str)

Source code in dspy/retrievers/embeddings.py
def __call__(self, query: str):
    """Let the retriever be invoked like a function.

    Args:
        query: The search query.

    Returns:
        Whatever ``self.forward(query)`` returns.
    """
    run = self.forward
    return run(query)

forward(query: str)

Source code in dspy/retrievers/embeddings.py
def forward(self, query: str):
    """Run a search for ``query`` and wrap the hits in a ``dspy.Prediction``.

    Args:
        query: The search query, passed unchanged to ``self.search_fn``.

    Returns:
        A ``dspy.Prediction`` whose ``passages`` field holds the result of
        ``self.search_fn(query)``.
    """
    # Imported inside the function, as in the original; presumably to avoid a
    # circular import at module load time — confirm.
    import dspy

    prediction_cls = dspy.Prediction
    found = self.search_fn(query)
    return prediction_cls(passages=found)