Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.files/
.ruff_cache/
reactome_github/
reactomegit/
get-pip.py
.DS_Store

.chainlit/translations/*
!.chainlit/translations/en-US.json
csv_files/
Expand Down
2 changes: 1 addition & 1 deletion bin/chat-fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ async def landing_page():
<div class="button-container">
<a class="button" href="$CHAINLIT_URL/chat/guest/" target="_blank">Guest Access</a>
<a class="button" href="$CHAINLIT_URL/chat/personal/" target="_blank">Log In</a>
<a class="button feedback-button" href="https://docs.google.com/forms/d/e/1FAIpQLSeWajgdJGV2gETj2bo-_jqU54Ryy6d7acJkvMo-KkflYUmfTg/viewform" target="_blank">Feedback</a>
<a class="button feedback-button" href="https://forms.gle/Rvzb8EA73yZs7wd38" target="_blank">Feedback</a>
</div>

<p class="left-justified">Choose <strong>Guest Access</strong> to try the chatbot out. <strong>Log In</strong> gives you an increased query allowance and securely stores your chat history so you can save and continue conversations.</p>
Expand Down
15 changes: 8 additions & 7 deletions src/conversational_chain/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class AdditionalContent(TypedDict):


class ChatState(TypedDict):
input: str # User input text
query: str # LLM-generated query from user input
user_input: str # User input text
rephrased_input: str # LLM-generated query from user input
chat_history: Annotated[list[BaseMessage], add_messages]
context: list[Document]
answer: str # primary LLM response that is streamed to the user
Expand Down Expand Up @@ -103,21 +103,22 @@ async def preprocess(
self, state: ChatState, config: RunnableConfig
) -> dict[str, str]:
query: str = await self.rephrase_chain.ainvoke(state, config)
return {"query": query}
return {"rephrased_input": query}

async def call_model(
self, state: ChatState, config: RunnableConfig
) -> dict[str, Any]:
result: dict[str, Any] = await self.rag_chain.ainvoke(
{
"input": state["query"],
"input": state["rephrased_input"],
"user_input": state["user_input"],
"chat_history": state["chat_history"],
},
config,
)
return {
"chat_history": [
HumanMessage(state["input"]),
HumanMessage(state["user_input"]),
AIMessage(result["answer"]),
],
"context": result["context"],
Expand All @@ -130,7 +131,7 @@ async def postprocess(
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
result: dict[str, Any] = await self.search_workflow.ainvoke(
{"question": state["query"], "generation": state["answer"]},
{"question": state["rephrased_input"], "generation": state["answer"]},
config=RunnableConfig(callbacks=config["callbacks"]),
)
search_results = result["search_results"]
Expand All @@ -149,7 +150,7 @@ async def ainvoke(
if self.graph is None:
self.graph = await self.initialize()
result: dict[str, Any] = await self.graph.ainvoke(
{"input": user_input},
{"user_input": user_input},
config=RunnableConfig(
callbacks=callbacks,
configurable={
Expand Down
55 changes: 40 additions & 15 deletions src/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from langchain_community.retrievers import BM25Retriever
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from ragas import evaluate
from ragas.metrics import answer_relevancy, context_utilization, faithfulness
from ragas.metrics import (answer_relevancy, context_recall,
context_utilization, faithfulness)

from conversational_chain.chain import create_rag_chain
from reactome.metadata_info import descriptions_info, field_info
Expand All @@ -34,6 +35,12 @@ def parse_arguments():
default="gpt-4o-mini",
help="Language model to use for evaluation",
)
parser.add_argument(
"--rag_type",
choices=["basic", "advanced"],
required=True,
help="Type of RAG system to use for evaluation",
)
return parser.parse_args()


Expand All @@ -50,8 +57,8 @@ def load_dataset(testset_path):
raise ValueError(f"Error reading the Excel file: {e}")


def initialize_rag_chain(embeddings_directory, model_name):
"""Initialize the RAG chain system."""
def initialize_rag_chain_with_memory(embeddings_directory, model_name, rag_type):
"""Initialize the RAGChainWithMemory system."""
llm = ChatOpenAI(temperature=0.0, verbose=True, model=model_name)
retriever_list = []

Expand All @@ -60,7 +67,7 @@ def initialize_rag_chain(embeddings_directory, model_name):
)
data = loader.load()
bm25_retriever = BM25Retriever.from_documents(data)
bm25_retriever.k = 15
bm25_retriever.k = 7

# Set up vectorstore SelfQuery retriever
embedding = OpenAIEmbeddings(model="text-embedding-3-large")
Expand All @@ -69,17 +76,22 @@ def initialize_rag_chain(embeddings_directory, model_name):
embedding_function=embedding,
)

vectordb_retriever = vectordb.as_retriever(search_kwargs={"k": 7})

selfq_retriever = SelfQueryRetriever.from_llm(
llm=llm,
vectorstore=vectordb,
document_contents=descriptions_info["summations"],
metadata_field_info=field_info["summations"],
search_kwargs={"k": 15},
search_kwargs={"k": 7},
)
rrf_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8]
)
retriever_list.append(rrf_retriever)
if rag_type == "basic":
retriever_list.append(vectordb_retriever)
elif rag_type == "advanced":
retriever_list.append(rrf_retriever)

reactome_retriever = MergerRetriever(retrievers=retriever_list)

Expand All @@ -91,11 +103,15 @@ def initialize_rag_chain(embeddings_directory, model_name):


def process_testset(
testset_path, qa_system, embeddings_directory, response_dir, eval_dir, model_name
testset_path,
qa_system,
embeddings_directory,
response_dir,
eval_dir,
model_name,
rag_type,
):
"""Process a single testset file."""
args = parse_arguments()

testset = load_dataset(testset_path)
questions = [item["question"] for item in testset]
ground_truths = [item["ground_truth"] for item in testset]
Expand All @@ -108,6 +124,11 @@ def process_testset(
answers.append(response["answer"])
contexts.append([context.page_content for context in response["context"]])

rag_response_dir = os.path.join(response_dir, rag_type)
rag_eval_dir = os.path.join(eval_dir, rag_type)
os.makedirs(rag_response_dir, exist_ok=True)
os.makedirs(rag_eval_dir, exist_ok=True)

# Save responses to an Excel file
data = {
"question": questions,
Expand All @@ -117,8 +138,8 @@ def process_testset(
}
df_ans = pd.DataFrame(data)
response_filename = os.path.join(
response_dir,
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{args.model}_responses.xlsx",
rag_response_dir,
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{model_name}_responses_{rag_type}.xlsx",
)
df_ans.to_excel(response_filename, index=False)
print(f"Responses saved to {response_filename}")
Expand All @@ -128,13 +149,13 @@ def process_testset(
result = evaluate(
llm=ChatOpenAI(temperature=0.0, verbose=True, model="gpt-4o"),
dataset=dataset,
metrics=[answer_relevancy, context_utilization, faithfulness],
metrics=[answer_relevancy, context_utilization, faithfulness, context_recall],
)

# Save evaluation results to an Excel file
evaluation_filename = os.path.join(
eval_dir,
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{args.model}_evaluation.xlsx",
rag_eval_dir,
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{model_name}_evaluation_{rag_type}.xlsx",
)
df_eval = result.to_pandas()
df_eval.to_excel(evaluation_filename, index=False)
Expand All @@ -144,14 +165,17 @@ def process_testset(
def main():
args = parse_arguments()
model_name = args.model
rag_type = args.rag_type
response_dir = os.path.join(args.testset_dir, "response")
eval_dir = os.path.join(args.testset_dir, "evals")
os.makedirs(response_dir, exist_ok=True)
os.makedirs(eval_dir, exist_ok=True)

# Initialize RAG Chain
embeddings_directory = "/Users/hmohammadi/Desktop/react_to_me_github/reactome_chatbot/embeddings/openai/text-embedding-3-large/reactome/Release90/summations"
qa_system = initialize_rag_chain(embeddings_directory, model_name)
qa_system = initialize_rag_chain_with_memory(
embeddings_directory, model_name, rag_type
)

# Iterate over all .xlsx files in the directory
for filename in os.listdir(args.testset_dir):
Expand All @@ -166,6 +190,7 @@ def main():
response_dir,
eval_dir,
model_name,
rag_type,
)


Expand Down
4 changes: 2 additions & 2 deletions src/retreival_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def create_retrieval_chain(
loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name)
data = loader.load()
bm25_retriever = BM25Retriever.from_documents(data)
bm25_retriever.k = 15
bm25_retriever.k = 10

# set up vectorstore SelfQuery retriever
embedding = embedding_callable()
Expand All @@ -123,7 +123,7 @@ def create_retrieval_chain(
vectorstore=vectordb,
document_contents=descriptions_info[subdirectory],
metadata_field_info=field_info[subdirectory],
search_kwargs={"k": 15},
search_kwargs={"k": 10},
)
rrf_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8]
Expand Down
17 changes: 10 additions & 7 deletions src/system_prompt/reactome_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
# Contextualize question prompt
contextualize_q_system_prompt = """
You are an expert in question formulation with deep expertise in molecular biology and experience as a Reactome curator. Your task is to analyze the conversation history and the user’s latest query to fully understand their intent and what they seek to learn.
Reformulate the user’s question into a standalone version that retains its full meaning without requiring prior context. The reformulated question should be:**
Clear, concise, and precise
Optimized for both vector search (semantic meaning) and case-sensitive keyword search
Faithful to the user’s intent and scientific accuracy
If the user’s question is already self-contained and well-formed, return it as is.
If the user's question is not in English, reformulate the question and translate it to English, ensuring the meaning and intent are preserved.
Reformulate the user’s question into a standalone version that retains its full meaning without requiring prior context. The reformulated question should be:
- Clear, concise, and precise
- Optimized for both vector search (semantic meaning) and case-sensitive keyword search
- Faithful to the user’s intent and scientific accuracy

The returned question should always be in English.
If the user’s question is already in English, self-contained and well-formed, return it as is.
Do NOT answer the question or provide explanations.
"""

contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{input}"),
("human", "{user_input}"),
]
)

Expand Down Expand Up @@ -46,6 +49,6 @@
[
("system", qa_system_prompt),
MessagesPlaceholder(variable_name="chat_history"),
("user", "Context:\n{context}\n\nQuestion: {input}"),
("user", "Context:\n{context}\n\nQuestion: {user_input}"),
]
)
4 changes: 1 addition & 3 deletions src/util/chainlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ async def message_rate_limited(config: Config | None) -> bool:
quota_message = "Our servers are currently overloaded.\n"
login_uri: str | None = os.getenv("CHAINLIT_URI_LOGIN", "")
if login_uri:
quota_message += (
f"Please [log in]({login_uri}) to continue chatting and enjoy features like saved chat history and fewer limits."
)
quota_message += f"Please [log in]({login_uri}) to continue chatting and enjoy features like saved chat history and fewer limits."
else:
quota_message += "Please try again later."
await send_messages([quota_message])
Expand Down