16 changes: 7 additions & 9 deletions src/conversational_chain/chain.py
@@ -1,22 +1,20 @@
 from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.chains.history_aware_retriever import \
-    create_history_aware_retriever
 from langchain.chains.retrieval import create_retrieval_chain
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import Runnable
 
 from system_prompt.reactome_prompt import contextualize_q_prompt, qa_prompt
 
 
-def create_rag_chain(llm: BaseChatModel, retriever: BaseRetriever) -> Runnable:
-    # Create the history-aware retriever
-    history_aware_retriever: Runnable = create_history_aware_retriever(
-        llm=llm,
-        retriever=retriever,
-        prompt=contextualize_q_prompt,
+def create_rephrase_chain(llm: BaseChatModel) -> Runnable:
+    return (contextualize_q_prompt | llm | StrOutputParser()).with_config(
+        run_name="rephrase_question"
     )
 
+
+def create_rag_chain(llm: BaseChatModel, retriever: BaseRetriever) -> Runnable:
     # Create the documents chain
     question_answer_chain: Runnable = create_stuff_documents_chain(
         llm=llm.model_copy(update={"streaming": True}),
@@ -25,7 +23,7 @@ def create_rag_chain(llm: BaseChatModel, retriever: BaseRetriever) -> Runnable:
 
     # Create the retrieval chain
     rag_chain: Runnable = create_retrieval_chain(
-        retriever=history_aware_retriever,
+        retriever=retriever,
         combine_docs_chain=question_answer_chain,
     )
 
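With this change the question-rephrasing step is a standalone LCEL runnable (`prompt | llm | StrOutputParser()`) instead of being buried inside `create_history_aware_retriever`, so it can be exercised in isolation. A minimal sketch of invoking it, assuming `contextualize_q_prompt` expects `input` and `chat_history` (as `ChatState` suggests); the fake model and sample values are illustrative, not part of this PR:

```python
# Hedged sketch: exercising the new rephrase chain on its own.
# FakeListChatModel and all sample values are illustrative assumptions.
from langchain_core.language_models import FakeListChatModel

from conversational_chain.chain import create_rephrase_chain

llm = FakeListChatModel(responses=["What pathways regulate apoptosis?"])
rephrase = create_rephrase_chain(llm)

# StrOutputParser means the chain returns a plain string query.
query: str = rephrase.invoke(
    {"input": "What about apoptosis?", "chat_history": []}
)
print(query)  # -> "What pathways regulate apoptosis?"
```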
24 changes: 17 additions & 7 deletions src/conversational_chain/graph.py
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Annotated, Any, Sequence, TypedDict
+from typing import Annotated, Any, TypedDict
 
 from langchain_core.callbacks.base import Callbacks
 from langchain_core.documents import Document
@@ -16,7 +16,7 @@
 from psycopg import AsyncConnection
 from psycopg_pool import AsyncConnectionPool
 
-from conversational_chain.chain import create_rag_chain
+from conversational_chain.chain import create_rag_chain, create_rephrase_chain
 from external_search.state import WebSearchResult
 from external_search.workflow import create_search_workflow
 from util.logging import logging
@@ -37,8 +37,9 @@ class AdditionalContent(TypedDict):
 
 
 class ChatState(TypedDict):
-    input: str
-    chat_history: Annotated[Sequence[BaseMessage], add_messages]
+    input: str  # User input text
+    query: str  # LLM-generated query from user input
+    chat_history: Annotated[list[BaseMessage], add_messages]
     context: list[Document]
     answer: str  # primary LLM response that is streamed to the user
     additional_content: (
@@ -50,15 +51,18 @@ class RAGGraphWithMemory:
     def __init__(self, retriever: BaseRetriever, llm: BaseChatModel) -> None:
         # Set up runnables
         self.rag_chain: Runnable = create_rag_chain(llm, retriever)
+        self.rephrase_chain: Runnable = create_rephrase_chain(llm)
         self.search_workflow: CompiledStateGraph = create_search_workflow(llm)
 
         # Create graph
         state_graph: StateGraph = StateGraph(ChatState)
         # Set up nodes
+        state_graph.add_node("preprocess", self.preprocess)
         state_graph.add_node("model", self.call_model)
         state_graph.add_node("postprocess", self.postprocess)
         # Set up edges
-        state_graph.set_entry_point("model")
+        state_graph.set_entry_point("preprocess")
+        state_graph.add_edge("preprocess", "model")
         state_graph.add_edge("model", "postprocess")
         state_graph.set_finish_point("postprocess")
 
@@ -95,10 +99,16 @@ async def close_pool(self) -> None:
         if self.pool:
             await self.pool.close()
 
+    async def preprocess(
+        self, state: ChatState, config: RunnableConfig
+    ) -> dict[str, str]:
+        query: str = await self.rephrase_chain.ainvoke(state, config)
+        return {"query": query}
+
     async def call_model(
         self, state: ChatState, config: RunnableConfig
     ) -> dict[str, Any]:
-        result = await self.rag_chain.ainvoke(state, config)
+        result: dict[str, Any] = await self.rag_chain.ainvoke(state, config)
         return {
             "chat_history": [
                 HumanMessage(state["input"]),
@@ -114,7 +124,7 @@ async def postprocess(
         search_results: list[WebSearchResult] = []
         if config["configurable"]["enable_postprocess"]:
             result: dict[str, Any] = await self.search_workflow.ainvoke(
-                {"question": state["input"], "generation": state["answer"]},
+                {"question": state["query"], "generation": state["answer"]},
                 config=RunnableConfig(callbacks=config["callbacks"]),
             )
             search_results = result["search_results"]
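The graph now runs preprocess → model → postprocess: `preprocess` distills `input` plus `chat_history` into a standalone `query`, the RAG chain answers from it, and the web-search workflow reuses the same `query` instead of the raw input. A self-contained sketch of that topology with stand-in node bodies (not the PR's real chains):

```python
# Hedged sketch of the new three-node topology; node bodies are
# stand-ins for the real rephrase/RAG/search runnables.
from typing import TypedDict

from langgraph.graph import StateGraph


class MiniState(TypedDict, total=False):
    input: str
    query: str
    answer: str


def preprocess(state: MiniState) -> dict[str, str]:
    # Stand-in for rephrase_chain: rewrite raw input as a standalone query.
    return {"query": f"standalone: {state['input']}"}


def model(state: MiniState) -> dict[str, str]:
    # Stand-in for the RAG chain: answer from the rephrased query.
    return {"answer": f"answer to: {state['query']}"}


def postprocess(state: MiniState) -> dict[str, str]:
    # Stand-in for the search workflow, which now sees the query too.
    return {"answer": state["answer"] + " [+ web results]"}


builder = StateGraph(MiniState)
builder.add_node("preprocess", preprocess)
builder.add_node("model", model)
builder.add_node("postprocess", postprocess)
builder.set_entry_point("preprocess")
builder.add_edge("preprocess", "model")
builder.add_edge("model", "postprocess")
builder.set_finish_point("postprocess")

print(builder.compile().invoke({"input": "What about apoptosis?"}))
```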
2 changes: 1 addition & 1 deletion src/system_prompt/reactome_prompt.py
@@ -44,6 +44,6 @@
     [
         ("system", qa_system_prompt),
         MessagesPlaceholder(variable_name="chat_history"),
-        ("user", "Context:\n{context}\n\nQuestion: {input}"),
+        ("user", "Context:\n{context}\n\nQuestion: {query}"),
     ]
 )
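Since the QA prompt now interpolates `{query}` rather than `{input}`, anything that invokes it must supply a `query` key — which the new `preprocess` node guarantees. A quick hedged check of the template, assuming `qa_system_prompt` itself takes no input variables; sample values are illustrative:

```python
# Hedged sanity check that qa_prompt formats the rephrased query;
# sample values are illustrative, and qa_system_prompt is assumed
# to have no template variables of its own.
from system_prompt.reactome_prompt import qa_prompt

prompt_value = qa_prompt.invoke(
    {
        "chat_history": [],
        "context": "Apoptosis is a programmed cell death pathway...",
        "query": "What pathways regulate apoptosis?",
    }
)
print(prompt_value.to_messages()[-1].content)
# Context:
# Apoptosis is a programmed cell death pathway...
#
# Question: What pathways regulate apoptosis?
```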
4 changes: 2 additions & 2 deletions src/util/chainlit_helpers.py
@@ -80,11 +80,11 @@ async def message_rate_limited(config: Config | None) -> bool:
                 f"You are allowed a maximum of {rate_limit.max_messages} messages every {rate_limit.interval}."
             )
         else:
-            quota_message = "Public messages quota reached. "
+            quota_message = "Our servers are currently overloaded.\n"
             login_uri: str | None = os.getenv("CHAINLIT_URI_LOGIN", "")
             if login_uri:
                 quota_message += (
-                    f"[Log in]({login_uri}) to continue chatting with fewer limits."
+                    f"Please [log in]({login_uri}) to continue chatting and enjoy features like saved chat history and fewer limits."
                 )
             else:
                 quota_message += "Please try again later."