Skip to content
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ nltk = "^3.9.1"
[tool.poetry.group.dev.dependencies]
ruff = "^0.7.1"
pytest = "^8.3.3"
pytest-mock = "^3.14.0"
mypy = "^1.13.0"
black = "^24.10.0"
isort = "^5.13.2"
Expand Down
64 changes: 57 additions & 7 deletions src/agent/profiles/cross_database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Annotated, Any, Literal, TypedDict

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
Expand All @@ -15,8 +15,11 @@
create_uniprot_rewriter_w_reactome
from agent.tasks.cross_database.summarize_reactome_uniprot import \
create_reactome_uniprot_summarizer
from agent.tasks.flow_reasoner import create_flow_reasoner
from retrievers.reactome.rag import create_reactome_rag
from retrievers.uniprot.rag import create_uniprot_rag
from tools.reactome_topology import ReactomeTopologyTool
import re


class CrossDatabaseState(BaseState):
Expand All @@ -28,6 +31,8 @@ class CrossDatabaseState(BaseState):
uniprot_answer: str # LLM-generated answer from UniProt
uniprot_completeness: str # LLM-assessed completeness of the UniProt answer

flow_context: str # Topological flow data fetched by identify_flow() for mechanistic queries


class CrossDatabaseGraphBuilder(BaseGraphBuilder):
def __init__(
Expand All @@ -47,6 +52,8 @@ def __init__(
self.summarize_final_answer = create_reactome_uniprot_summarizer(
llm, streaming=True
)
self.flow_reasoner = create_flow_reasoner(llm)
self.topology_tool = ReactomeTopologyTool()

# Create graph
state_graph = StateGraph(CrossDatabaseState)
Expand All @@ -62,6 +69,8 @@ def __init__(
state_graph.add_node("rewrite_uniprot_answer", self.rewrite_uniprot_answer)
state_graph.add_node("assess_completeness", self.assess_completeness)
state_graph.add_node("decide_next_steps", self.decide_next_steps)
state_graph.add_node("identify_flow", self.identify_flow)
state_graph.add_node("verify_mechanism", self.verify_mechanism)
state_graph.add_node("generate_final_response", self.generate_final_response)
state_graph.add_node("postprocess", self.postprocess)
# Set up edges
Expand All @@ -84,8 +93,11 @@ def __init__(
"perform_web_search": "generate_final_response",
"rewrite_reactome_query": "rewrite_reactome_query",
"rewrite_uniprot_query": "rewrite_uniprot_query",
"identify_flow": "identify_flow",
},
)
state_graph.add_edge("identify_flow", "verify_mechanism")
state_graph.add_edge("verify_mechanism", "generate_final_response")
state_graph.add_edge("rewrite_reactome_query", "rewrite_reactome_answer")
state_graph.add_edge("rewrite_uniprot_query", "rewrite_uniprot_answer")
state_graph.add_edge("rewrite_reactome_answer", "generate_final_response")
Expand Down Expand Up @@ -203,14 +215,23 @@ async def assess_completeness(
uniprot_completeness=uniprot_completeness.binary_score,
)

async def decide_next_steps(self, state: CrossDatabaseState) -> Literal[
"generate_final_response",
"perform_web_search",
"rewrite_reactome_query",
"rewrite_uniprot_query",
]:
async def decide_next_steps(self, state: CrossDatabaseState) -> Literal["identify_flow", "generate_final_response", "perform_web_search", "rewrite_reactome_query", "rewrite_uniprot_query"]:
"""Decide the next step based on the research results and context."""
user_query = state.get("rephrased_input", "").lower()
reactome_answer = state.get("reactome_answer", "")
uniprot_answer = state.get("uniprot_answer", "")

# Tightened keyword matching for mechanistic flow detection
flow_pattern = r"\b(after|consequence|downstream|flow|precede|trigger|following|upstream|mechanism)\b"
is_mechanistic = bool(re.search(flow_pattern, user_query))

if is_mechanistic and reactome_answer and "error" not in reactome_answer.lower():
return "identify_flow"

reactome_complete = state["reactome_completeness"] != "No"
uniprot_complete = state["uniprot_completeness"] != "No"


if reactome_complete and uniprot_complete:
return "generate_final_response"
elif not reactome_complete and uniprot_complete:
Expand All @@ -220,6 +241,35 @@ async def decide_next_steps(self, state: CrossDatabaseState) -> Literal[
else:
return "perform_web_search"

async def identify_flow(self, state: CrossDatabaseState, config: RunnableConfig) -> dict[str, Any]:
    """Collect topological flow context for Reactome events cited in the answer.

    Extracts Reactome stable identifiers (e.g. ``R-HSA-123456``) from the
    current ``reactome_answer`` and queries the Content Service topology tool
    for each, concatenating the per-ID summaries into ``flow_context``.

    Returns a partial state update: ``{"flow_context": <summary text>}``
    (empty string when no IDs were found or no topology was returned).
    """
    # Reactome stable IDs look like R-<3-letter species code>-<digits>.
    id_pattern = re.compile(r"R-[A-Z]{3}-\d+")
    reactome_ans = state.get("reactome_answer", "")
    # Deduplicate while preserving first-seen order: a plain set() made the
    # 5-ID cap below select IDs nondeterministically; dict.fromkeys keeps
    # the IDs mentioned earliest in the answer.
    st_ids = list(dict.fromkeys(id_pattern.findall(reactome_ans)))

    context_parts: list[str] = []
    for st_id in st_ids[:5]:  # Limit to top 5 IDs for token safety
        context = self.topology_tool.get_flow_context(st_id)
        if context:
            context_parts.append(f"\n---\n{context}")

    return {"flow_context": "".join(context_parts)}

async def verify_mechanism(self, state: CrossDatabaseState, config: RunnableConfig) -> dict[str, Any]:
    """Verify and enrich the Reactome answer against gathered flow topology.

    Invokes the flow-reasoner chain with the user question, the draft
    Reactome answer, and the topological context produced by
    ``identify_flow``. Returns an updated ``reactome_answer``, or an empty
    state update when there is no flow context to reason over.
    """
    flow_context = state.get("flow_context")
    if not flow_context:
        # Nothing was fetched upstream; leave the state untouched.
        return {}

    reasoner_input = {
        "input": state["rephrased_input"],
        "initial_answer": state["reactome_answer"],
        "flow_context": flow_context,
    }
    refined_answer: str = await self.flow_reasoner.ainvoke(reasoner_input, config)
    return {"reactome_answer": refined_answer}

async def generate_final_response(
self, state: CrossDatabaseState, config: RunnableConfig
) -> CrossDatabaseState:
Expand Down
41 changes: 41 additions & 0 deletions src/agent/tasks/flow_reasoner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable

# System prompt for the flow-verification step: instructs the LLM to check a
# draft RAG answer against raw Reactome graph topology and produce a corrected,
# mechanism-focused narrative (topology wins on conflicts).
flow_reasoning_message = """
You are a senior curator and mechanistic reasoner for the Reactome Pathway Knowledgebase.
Your task is to take a biological explanation and the corresponding topological data (next/previous steps, participants) from the Reactome Graph to verify and enrich the answer.

Context provided:
1. Initial Explanation: The draft answer generated by the RAG system.
2. Topological Data: Raw data from the Reactome Graph about the reactions and pathways mentioned.

Objective:
- Verify the sequence of events: Does the Graph data confirm that Reaction A leads to Reaction B as described in the initial answer?
- Identify missing mechanistic links: If there's a gap between two steps in the explanation, use the topological data to fill it (e.g., mention a missing intermediate metabolite or catalyst).
- Correct errors: If the initial answer claimed a protein is an input but the graph says it's a catalyst, correct it.

Output Requirements:
- Provide a refined, highly accurate mechanistic description of the pathway/process.
- Highlight the "flow" of information or matter (e.g., "First X happens, which triggers Y, resulting in Z").
- Maintain all citations from the original context and add new ones if new IDs from the topology are used.
- If the topological data contradicts the initial findings, prioritize the topological data as the 'ground truth' of the Reactome Graph.

Strict Rule: Focus ONLY on the biological mechanism and flow. Do not add generic filler.
"""

# Prompt template consumed by create_flow_reasoner(); expects the input
# variables ``initial_answer``, ``flow_context`` and ``input``.
flow_reasoning_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", flow_reasoning_message),
        (
            "human",
            "Initial Explanation: {initial_answer} \n\n Topological Data: \n {flow_context} \n\n User Question: {input}",
        ),
    ]
)

def create_flow_reasoner(llm: BaseChatModel) -> Runnable:
    """Build the flow-reasoning chain: prompt -> LLM -> plain-string output.

    The chain takes the ``input``, ``initial_answer`` and ``flow_context``
    variables expected by ``flow_reasoning_prompt`` and is tagged with the
    run name ``flow_reasoning`` for tracing.
    """
    chain = flow_reasoning_prompt | llm | StrOutputParser()
    return chain.with_config(run_name="flow_reasoning")
27 changes: 19 additions & 8 deletions src/retrievers/csv_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Annotated, Any, Coroutine, TypedDict

import chromadb.config
from langchain.chains.query_constructor.schema import AttributeInfo
from util.langchain_compat import AttributeInfo
from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_chroma.vectorstores import Chroma
Expand Down Expand Up @@ -71,15 +73,25 @@ def create_bm25_chroma_ensemble_retriever(
*,
descriptions_info: dict[str, str],
field_info: dict[str, list[AttributeInfo]],
) -> MergerRetriever:
return HybridRetriever.from_subdirectory(
) -> ContextualCompressionRetriever:
"""Create a HybridRetriever wrapped with EmbeddingsFilter-based
Contextual Compression to filter out low-relevance documents before
they are passed to the LLM, improving answer quality and reducing
hallucinations caused by noisy retrieval."""
base_retriever = HybridRetriever.from_subdirectory(
llm,
embedding,
embeddings_directory,
descriptions_info=descriptions_info,
field_info=field_info,
include_original=True,
)
embeddings_filter = EmbeddingsFilter(
embeddings=embedding, similarity_threshold=0.76
)
return ContextualCompressionRetriever(
base_compressor=embeddings_filter, base_retriever=base_retriever
)


class RetrieverDict(TypedDict):
Expand Down Expand Up @@ -177,7 +189,7 @@ def retrieve_documents(self, queries: list[str], run_manager) -> list[Document]:
)
},
)
doc_lists.append(bm25_docs + vector_docs)
doc_lists.extend([bm25_docs, vector_docs])
subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
return subdirectory_docs

Expand Down Expand Up @@ -214,9 +226,8 @@ async def aretrieve_documents(
subdirectory_docs: list[Document] = []
for subdir_results in subdirectory_results.values():
results_iter = iter(await asyncio.gather(*subdir_results))
doc_lists: list[list[Document]] = [
bm25_results + vector_results
for bm25_results, vector_results in zip(results_iter, results_iter)
]
doc_lists: list[list[Document]] = []
for bm25_results, vector_results in zip(results_iter, results_iter):
doc_lists.extend([bm25_results, vector_results])
subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
return subdirectory_docs
2 changes: 1 addition & 1 deletion src/retrievers/reactome/metadata_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.chains.query_constructor.base import AttributeInfo
from util.langchain_compat import AttributeInfo

pathway_id_description = "A Reactome Identifier unique to each pathway. A pathway name may appear multiple times in the dataset\
This ID allows for the specific identification and exploration of each pathway's details within the Reactome Database."
Expand Down
2 changes: 1 addition & 1 deletion src/retrievers/uniprot/metadata_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.chains.query_constructor.base import AttributeInfo
from util.langchain_compat import AttributeInfo

uniprot_descriptions_info = {
"uniprot_data": "Contains detailed protein information about gene names, protein names, subcellular localizations, family classifications, biological pathway associations, domains, motifs, disease associations, and functional descriptions. ",
Expand Down
79 changes: 79 additions & 0 deletions src/tools/reactome_topology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import requests
from typing import Any, Optional
from util.logging import logging

class ReactomeTopologyTool:
"""
A tool to query the Reactome Content Service for topological information
about pathways and reactions (e.g., inputs, outputs, preceding/subsequent events).
"""

BASE_URL = "https://reactome.org/ContentService/data"

def __init__(self):
self.session = requests.Session()

def query_id(self, st_id: str) -> dict[str, Any] | None:
"""Query the Content Service for a single ID."""
url = f"{self.BASE_URL}/query/{st_id}"
try:
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.json()
except Exception as e:
logging.debug(f"Error querying {st_id}: {e}")
return None

def get_flow_context(self, st_id: str, max_depth: int = 2) -> str:
"""
Get a human-readable summary of the topological flow for an event,
traversing multiple hops (upstream and hierarchical).
"""
visited = set()

def _traverse(target_id: str, depth: int) -> str:
if depth > max_depth or target_id in visited:
return ""

visited.add(target_id)
data = self.query_id(target_id)
if not data:
return ""

display_name = data.get("displayName", target_id)
cls_name = data.get("className", "Event")
indent = " " * (depth - 1)

lines = [f"{indent}- {cls_name}: {display_name} ({target_id})"]

# Reactions: Inputs/Outputs/Catalysts
if depth == 1:
inputs = [i.get("displayName") for i in data.get("input", [])]
outputs = [o.get("displayName") for o in data.get("output", [])]
catalysts = [c.get("physicalEntity", {}).get("displayName") for c in data.get("catalystActivity", []) if c.get("physicalEntity")]

if inputs: lines.append(f"{indent} Inputs: {', '.join(inputs)}")
if outputs: lines.append(f"{indent} Outputs: {', '.join(outputs)}")
if catalysts: lines.append(f"{indent} Catalysts: {', '.join(catalysts)}")

# Causal connection: Preceding Events
preceding = data.get("precedingEvent", [])
if preceding:
lines.append(f"{indent} Preceding ({len(preceding)}):")
for p in preceding[:3]: # Cap per level to avoid overflow
st_id_p = p.get("stId")
if st_id_p:
lines.append(_traverse(st_id_p, depth + 1))

# Hierarchical connection: Sub-events (for Pathways)
sub_events = data.get("hasEvent", [])
if sub_events:
lines.append(f"{indent} Sub-events ({len(sub_events)}):")
for s in sub_events[:3]: # Cap per level
st_id_s = s.get("stId")
if st_id_s:
lines.append(_traverse(st_id_s, depth + 1))

return "\n".join(filter(None, lines))

return _traverse(st_id, 1) or "No topological data available."
17 changes: 17 additions & 0 deletions src/util/langchain_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""LangChain compatibility utility to handle environment-specific import variations."""

try:
from langchain.chains.query_constructor.base import AttributeInfo
except ImportError:
try:
from langchain.chains.query_constructor.schema import AttributeInfo
except ImportError:
try:
from langchain_classic.chains.query_constructor.schema import AttributeInfo
except ImportError:
# Fallback for environments where these imports are totally unavailable
class AttributeInfo:
def __init__(self, name: str, description: str, type: str):
self.name = name
self.description = description
self.type = type
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import sys
from pathlib import Path

# Make the project's ``src`` directory importable by the test suite.
_SRC = str(Path(__file__).parent.parent.absolute() / "src")
if _SRC not in sys.path:
    sys.path.insert(0, _SRC)
Loading