diff --git a/src/conversational_chain/chain.py b/src/conversational_chain/chain.py
index 3cb7e8c..5b59510 100644
--- a/src/conversational_chain/chain.py
+++ b/src/conversational_chain/chain.py
@@ -1,22 +1,20 @@
 from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.chains.history_aware_retriever import \
-    create_history_aware_retriever
 from langchain.chains.retrieval import create_retrieval_chain
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import Runnable
 
 from system_prompt.reactome_prompt import contextualize_q_prompt, qa_prompt
 
 
-def create_rag_chain(llm: BaseChatModel, retriever: BaseRetriever) -> Runnable:
-    # Create the history-aware retriever
-    history_aware_retriever: Runnable = create_history_aware_retriever(
-        llm=llm,
-        retriever=retriever,
-        prompt=contextualize_q_prompt,
+def create_rephrase_chain(llm: BaseChatModel) -> Runnable:
+    return (contextualize_q_prompt | llm | StrOutputParser()).with_config(
+        run_name="rephrase_question"
     )
 
+
+def create_rag_chain(llm: BaseChatModel, retriever: BaseRetriever) -> Runnable:
     # Create the documents chain
     question_answer_chain: Runnable = create_stuff_documents_chain(
         llm=llm.model_copy(update={"streaming": True}),
@@ -25,7 +23,7 @@ def create_rag_chain(llm: BaseChatModel, retriever: BaseRetriever) -> Runnable:
 
     # Create the retrieval chain
     rag_chain: Runnable = create_retrieval_chain(
-        retriever=history_aware_retriever,
+        retriever=retriever,
         combine_docs_chain=question_answer_chain,
     )
 
diff --git a/src/conversational_chain/graph.py b/src/conversational_chain/graph.py
index 38ac7da..a6e937c 100644
--- a/src/conversational_chain/graph.py
+++ b/src/conversational_chain/graph.py
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Annotated, Any, Sequence, TypedDict
+from typing import Annotated, Any, TypedDict
 
 from langchain_core.callbacks.base import Callbacks
 from langchain_core.documents import Document
@@ -16,7 +16,7 @@
 from psycopg import AsyncConnection
 from psycopg_pool import AsyncConnectionPool
 
-from conversational_chain.chain import create_rag_chain
+from conversational_chain.chain import create_rag_chain, create_rephrase_chain
 from external_search.state import WebSearchResult
 from external_search.workflow import create_search_workflow
 from util.logging import logging
@@ -37,8 +37,9 @@ class AdditionalContent(TypedDict):
 
 
 class ChatState(TypedDict):
-    input: str
-    chat_history: Annotated[Sequence[BaseMessage], add_messages]
+    input: str  # User input text
+    query: str  # LLM-generated query from user input
+    chat_history: Annotated[list[BaseMessage], add_messages]
     context: list[Document]
     answer: str  # primary LLM response that is streamed to the user
     additional_content: (
@@ -50,15 +51,18 @@ class RAGGraphWithMemory:
     def __init__(self, retriever: BaseRetriever, llm: BaseChatModel) -> None:
         # Set up runnables
         self.rag_chain: Runnable = create_rag_chain(llm, retriever)
+        self.rephrase_chain: Runnable = create_rephrase_chain(llm)
         self.search_workflow: CompiledStateGraph = create_search_workflow(llm)
 
         # Create graph
         state_graph: StateGraph = StateGraph(ChatState)
 
         # Set up nodes
+        state_graph.add_node("preprocess", self.preprocess)
         state_graph.add_node("model", self.call_model)
         state_graph.add_node("postprocess", self.postprocess)
         # Set up edges
-        state_graph.set_entry_point("model")
+        state_graph.set_entry_point("preprocess")
+        state_graph.add_edge("preprocess", "model")
         state_graph.add_edge("model", "postprocess")
         state_graph.set_finish_point("postprocess")
@@ -95,10 +99,16 @@ async def close_pool(self) -> None:
         if self.pool:
             await self.pool.close()
 
+    async def preprocess(
+        self, state: ChatState, config: RunnableConfig
+    ) -> dict[str, str]:
+        query: str = await self.rephrase_chain.ainvoke(state, config)
+        return {"query": query}
+
     async def call_model(
         self, state: ChatState, config: RunnableConfig
     ) -> dict[str, Any]:
-        result = await self.rag_chain.ainvoke(state, config)
+        result: dict[str, Any] = await self.rag_chain.ainvoke(state, config)
         return {
             "chat_history": [
                 HumanMessage(state["input"]),
@@ -114,7 +124,7 @@ async def postprocess(
         search_results: list[WebSearchResult] = []
         if config["configurable"]["enable_postprocess"]:
             result: dict[str, Any] = await self.search_workflow.ainvoke(
-                {"question": state["input"], "generation": state["answer"]},
+                {"question": state["query"], "generation": state["answer"]},
                 config=RunnableConfig(callbacks=config["callbacks"]),
             )
             search_results = result["search_results"]
diff --git a/src/system_prompt/reactome_prompt.py b/src/system_prompt/reactome_prompt.py
index b3f3908..acb6642 100644
--- a/src/system_prompt/reactome_prompt.py
+++ b/src/system_prompt/reactome_prompt.py
@@ -44,6 +44,6 @@
     [
         ("system", qa_system_prompt),
         MessagesPlaceholder(variable_name="chat_history"),
-        ("user", "Context:\n{context}\n\nQuestion: {input}"),
+        ("user", "Context:\n{context}\n\nQuestion: {query}"),
     ]
 )
diff --git a/src/util/chainlit_helpers.py b/src/util/chainlit_helpers.py
index 48fe2a3..afd07d3 100644
--- a/src/util/chainlit_helpers.py
+++ b/src/util/chainlit_helpers.py
@@ -80,11 +80,11 @@ async def message_rate_limited(config: Config | None) -> bool:
             f"You are allowed a maximum of {rate_limit.max_messages} messages every {rate_limit.interval}."
         )
     else:
-        quota_message = "Public messages quota reached. "
+        quota_message = "Our servers are currently overloaded.\n"
        login_uri: str | None = os.getenv("CHAINLIT_URI_LOGIN", "")
        if login_uri:
            quota_message += (
-                f"[Log in]({login_uri}) to continue chatting with fewer limits."
+                f"Please [log in]({login_uri}) to continue chatting and enjoy features like saved chat history and fewer limits."
            )
        else:
            quota_message += "Please try again later."
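
Usage note (illustrative only, not part of the diff): a minimal sketch of how the new rephrase step feeds the RAG chain, mirroring the graph's preprocess -> model flow. The `answer_question` helper and its `llm`, `retriever`, `user_input`, and `chat_history` arguments are hypothetical stand-ins, assuming any BaseChatModel and BaseRetriever.

# Hypothetical sketch composing the two runnables introduced in chain.py.
from conversational_chain.chain import create_rag_chain, create_rephrase_chain

async def answer_question(llm, retriever, user_input, chat_history):
    rephrase = create_rephrase_chain(llm)
    rag = create_rag_chain(llm, retriever)
    state = {"input": user_input, "chat_history": chat_history}
    # preprocess: condense the input and history into a standalone query
    query = await rephrase.ainvoke(state)
    # model: answer the standalone query over the retrieved context
    result = await rag.ainvoke({**state, "query": query})
    return result["answer"]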