From 9a4e442592e363126526cb56c01983bd77587118 Mon Sep 17 00:00:00 2001 From: bhavyakeerthi3 Date: Mon, 9 Mar 2026 12:01:03 +0530 Subject: [PATCH] feat: add EmbeddingsFilter contextual compression to HybridRetriever (closes #132) --- src/retrievers/csv_chroma.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index a792c93..e15a5a7 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -5,6 +5,8 @@ import chromadb.config from langchain.chains.query_constructor.schema import AttributeInfo from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever +from langchain.retrievers.contextual_compression import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import EmbeddingsFilter from langchain.retrievers.merger_retriever import MergerRetriever from langchain.retrievers.self_query.base import SelfQueryRetriever from langchain_chroma.vectorstores import Chroma @@ -71,8 +73,12 @@ def create_bm25_chroma_ensemble_retriever( *, descriptions_info: dict[str, str], field_info: dict[str, list[AttributeInfo]], -) -> MergerRetriever: - return HybridRetriever.from_subdirectory( +) -> ContextualCompressionRetriever: + """Create a HybridRetriever wrapped with EmbeddingsFilter-based + Contextual Compression to filter out low-relevance documents before + they are passed to the LLM, improving answer quality and reducing + hallucinations caused by noisy retrieval.""" + base_retriever = HybridRetriever.from_subdirectory( llm, embedding, embeddings_directory, @@ -80,6 +86,12 @@ def create_bm25_chroma_ensemble_retriever( field_info=field_info, include_original=True, ) + embeddings_filter = EmbeddingsFilter( + embeddings=embedding, similarity_threshold=0.76 + ) + return ContextualCompressionRetriever( + base_compressor=embeddings_filter, base_retriever=base_retriever + ) class RetrieverDict(TypedDict): @@ -177,7 +189,7 @@ def retrieve_documents(self, queries: list[str], run_manager) -> list[Document]: ) }, ) - doc_lists.append(bm25_docs + vector_docs) + doc_lists.extend([bm25_docs, vector_docs]) subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists)) return subdirectory_docs @@ -214,9 +226,8 @@ async def aretrieve_documents( subdirectory_docs: list[Document] = [] for subdir_results in subdirectory_results.values(): results_iter = iter(await asyncio.gather(*subdir_results)) - doc_lists: list[list[Document]] = [ - bm25_results + vector_results - for bm25_results, vector_results in zip(results_iter, results_iter) - ] + doc_lists: list[list[Document]] = [] + for bm25_results, vector_results in zip(results_iter, results_iter): + doc_lists.extend([bm25_results, vector_results]) subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists)) return subdirectory_docs