Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
d2ba10d
local changes
heliamoh Nov 29, 2024
7a882a0
Updated the Evaluator, Added a new function to RAGChain to return inp…
heliamoh Dec 23, 2024
36770a9
created a new function for RAGChainWithMemory that returns source doc…
heliamoh Dec 24, 2024
7789560
pulled main branch changes
heliamoh Dec 24, 2024
6d9fb56
Created a wrapper around the Tavily API, modified the nodes and workf…
heliamoh Jan 13, 2025
ec9af0a
created a wrapper around the Tavily API
heliamoh Jan 13, 2025
ecbac92
updated final result formatting of the wrappers
heliamoh Jan 14, 2025
4d7cf2f
modified query-rewriter and completeness-evaluator
heliamoh Jan 14, 2025
bc7d411
created a new query-handler to organize and re-rank external search (…
heliamoh Jan 14, 2025
ceafa7b
updated the node and workflow logic & added a new node to format and …
heliamoh Jan 14, 2025
cf9f27e
Update nodes.py
heliamoh Jan 14, 2025
dea0d30
Update nodes.py
heliamoh Jan 14, 2025
6f42eed
added instructions to make the LLM response more engaging and clear
heliamoh Jan 15, 2025
c771b1c
Merge branch 'main' into completeness-websearch
GFJHogue Jan 16, 2025
7cf292e
integrating external search flow (WIP)
GFJHogue Jan 16, 2025
aabed26
integrating external search flow (WIP2)
GFJHogue Jan 17, 2025
1ffacd9
external search flow added to main graph; feature switch added to con…
GFJHogue Jan 17, 2025
fa45de9
search results JSX element + code format + tavily rate limit
GFJHogue Jan 20, 2025
4833c7c
fix import
GFJHogue Jan 20, 2025
e0b1d5c
don't pass config to subgraph - searches are one-off and shouldn't ge…
GFJHogue Jan 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion .config.schema.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
$schema: "https://json-schema.org/draft/2020-12/schema"
type: object
properties:
features:
type: object
properties:
postprocessing:
type: object
properties:
enabled:
type: boolean
user_group:
type: string
enum: ["all", "logged_in"]
required: ["enabled"]
required: ["postprocessing"]
messages:
type: object
additionalProperties:
Expand Down Expand Up @@ -38,4 +51,4 @@ properties:
- required: ["event"]
- required: ["after_messages"]
required: ["message", "trigger"]
required: ["messages"]
required: ["features", "messages"]
18 changes: 15 additions & 3 deletions bin/chat-chainlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from conversational_chain.graph import RAGGraphWithMemory
from retreival_chain import create_retrieval_chain
from util.chainlit_helpers import static_messages
from util.chainlit_helpers import is_feature_enabled, static_messages
from util.config_yml import Config, TriggerEvent
from util.embedding_environment import EmbeddingEnvironment
from util.logging import logging
Expand Down Expand Up @@ -87,11 +87,23 @@ async def main(message: cl.Message) -> None:
stream_final_answer=True,
force_stream_final_answer=True, # we're not using prefix tokens
)
enable_postprocess: bool = is_feature_enabled(config, "postprocessing")
result: dict[str, Any] = await llm_graph.ainvoke(
message.content,
callbacks=[cb],
thread_id=thread_id,
enable_postprocess=enable_postprocess,
)
if len(result["additional_text"]) > 0:
await cl.Message(content=result["additional_text"]).send()
if (
enable_postprocess
and cb.final_stream
and len(result["additional_content"]["search_results"]) > 0
):
sent_message: cl.Message = cb.final_stream
search_results_element = cl.CustomElement(
name="SearchResults",
props={"results": result["additional_content"]["search_results"]},
)
sent_message.elements = [search_results_element] # type: ignore
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chainlit's own type annotations on this make it seemingly impossible to satisfy mypy here.

await sent_message.update()
await static_messages(config, after_messages=message_count)
2 changes: 1 addition & 1 deletion bin/chat-fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def verify_captcha_middleware(request: Request, call_next):
response = await call_next(request)
return response

host = request.headers.get('referer')
host = request.headers.get("referer")
if host and host.startswith("http:"):
url = request.url.replace(scheme="https")
return RedirectResponse(url=str(url))
Expand Down
5 changes: 5 additions & 0 deletions config_default.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# yaml-language-server: $schema=./.config.schema.yaml

features:
postprocessing: # external web search feature
enabled: true
user_group: all

messages:

welcome:
Expand Down
3 changes: 3 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ services:
- CHAINLIT_AUTH_SECRET=${CHAINLIT_AUTH_SECRET}
- CHAINLIT_URI=${CHAINLIT_URI}
- CHAINLIT_URL=${CHAINLIT_URL}
- TAVILY_API_KEY=${TAVILY_API_KEY}
ports:
- "8000:8000"
depends_on:
Expand All @@ -40,6 +41,8 @@ services:
- CLOUDFLARE_SECRET_KEY=${CLOUDFLARE_SECRET_KEY}
- CLOUDFLARE_SITE_KEY=${CLOUDFLARE_SITE_KEY}
- CHAINLIT_URI=${CHAINLIT_URI_NO_LOGIN}
- CHAINLIT_URL=${CHAINLIT_URL}
- TAVILY_API_KEY=${TAVILY_API_KEY}
ports:
- "8001:8000"
depends_on:
Expand Down
347 changes: 343 additions & 4 deletions poetry.lock

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions public/elements/SearchResults.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Return only the hostname portion (e.g. "example.com") of an absolute URL.
// Throws a TypeError if `url` is not a parseable absolute URL.
const getDomainFromUrl = (url) => new URL(url).hostname;

const SearchResults = () => {
return (
<div>
<div class="prose lg:prose-xl">
<p class="leading-7 [&amp;:not(:first-child)]:mt-4 whitespace-pre-wrap break-words">
Here are some external resources you may find helpful:
</p>
</div>
<div className="flex flex-col gap-2 p-4 pt-0">
{props.results.map((result) => (
<a
key={result.id}
href={result.url}
className="flex flex-col items-start gap-2 rounded-lg border p-3 text-left text-sm transition-all hover:bg-accent"
>
<div className="flex w-full flex-col gap-1">
<div className="flex items-center">
<div className="flex items-center gap-2">
<div className="font-semibold">
{result.title}
</div>
</div>
<div className="ml-auto text-xs text-muted-foreground">
{getDomainFromUrl(result.url)}
</div>
</div>
<div
className="text-xs text-muted-foreground"
style={{ // line-clamp-2 class not working for some reason
display: '-webkit-box',
WebkitBoxOrient: 'vertical',
WebkitLineClamp: 2,
overflow: 'hidden',
}}
>
{result.content.substring(0, 300)}
</div>
</div>
</a>
))}
</div>
</div>
)
};

export default SearchResults;
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ rank-bm25 = "^0.2.2"
psycopg = {extras = ["binary"], version = "^3.2.3"}
pydantic = "^2.10.5"
pyyaml = "^6.0.2"
tavily-python = "^0.5.0"

[tool.poetry.group.dev.dependencies]
ruff = "^0.7.1"
Expand All @@ -54,6 +55,8 @@ isort = "^5.13.2"
pandas-stubs = "^2.2.3.241009"
types-requests = "^2.32.0.20241016"
types-pyyaml = "^6.0.12.20241230"
datasets = "^3.2.0"
ragas = "^0.2.11"

[[tool.poetry.source]]
name = "PyPI"
Expand Down
57 changes: 38 additions & 19 deletions src/conversational_chain/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from psycopg_pool import AsyncConnectionPool

from conversational_chain.chain import create_rag_chain
from external_search.state import WebSearchResult
from external_search.workflow import create_search_workflow
from util.logging import logging

LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
Expand All @@ -30,21 +32,25 @@
logging.warning("POSTGRES_LANGGRAPH_DB undefined; falling back to MemorySaver.")


class ChatResponse(TypedDict):
chat_history: Annotated[Sequence[BaseMessage], add_messages]
context: list[Document]
answer: str # primary LLM response that is streamed to the user
class AdditionalContent(TypedDict):
    """Extra payload attached to the chat state after the graph completes."""

    # Web search results produced by the external-search workflow in the
    # postprocess step; left empty when postprocessing is disabled.
    search_results: list[WebSearchResult]


class ChatState(ChatResponse):
class ChatState(TypedDict):
input: str
additional_text: str # additional text to send after graph completes
chat_history: Annotated[Sequence[BaseMessage], add_messages]
context: list[Document]
answer: str # primary LLM response that is streamed to the user
additional_content: (
AdditionalContent # additional content to send after graph completes
)


class RAGGraphWithMemory:
def __init__(self, retriever: BaseRetriever, llm: BaseChatModel) -> None:
# Set up runnables
self.rag_chain: Runnable = create_rag_chain(llm, retriever)
self.search_workflow: CompiledStateGraph = create_search_workflow(llm)

# Create graph
state_graph: StateGraph = StateGraph(ChatState)
Expand Down Expand Up @@ -91,35 +97,48 @@ async def close_pool(self) -> None:

async def call_model(
self, state: ChatState, config: RunnableConfig
) -> ChatResponse:
response = await self.rag_chain.ainvoke(state, config)
) -> dict[str, Any]:
result = await self.rag_chain.ainvoke(state, config)
return {
"chat_history": [
HumanMessage(state["input"]),
AIMessage(response["answer"]),
AIMessage(result["answer"]),
],
"context": response["context"],
"answer": response["answer"],
"context": result["context"],
"answer": result["answer"],
}

async def postprocess(
self, state: ChatResponse, config: RunnableConfig
) -> dict[str, str]:
# TODO: add completeness checking flow here
self, state: ChatState, config: RunnableConfig
) -> dict[str, dict[str, list[WebSearchResult]]]:
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
result: dict[str, Any] = await self.search_workflow.ainvoke(
{"question": state["input"], "generation": state["answer"]},
)
search_results = result["search_results"]
return {
"additional_text": "",
"additional_content": {"search_results": search_results},
}

async def ainvoke(
self, user_input: str, callbacks: Callbacks, thread_id: str
self,
user_input: str,
*,
callbacks: Callbacks,
thread_id: str,
enable_postprocess: bool = True,
) -> dict[str, Any]:
if self.graph is None:
self.graph = await self.initialize()
response: dict[str, Any] = await self.graph.ainvoke(
result: dict[str, Any] = await self.graph.ainvoke(
{"input": user_input},
config=RunnableConfig(
callbacks=callbacks,
configurable={"thread_id": thread_id},
configurable={
"thread_id": thread_id,
"enable_postprocess": enable_postprocess,
},
),
)
return response
return result
Loading