diff --git a/FinanceAgent/README.md b/FinanceAgent/README.md index 36cb8b3a33..e37d3df0e2 100644 --- a/FinanceAgent/README.md +++ b/FinanceAgent/README.md @@ -1,16 +1,20 @@ # Finance Agent -## 1. Overview +## 1. Overview ## 2. Getting started -### 2.1 Dowload repos + +### 2.1 Download repos + ```bash mkdir /path/to/your/workspace/ export WORKDIR=/path/to/your/workspace/ genaicomps genaiexamples ``` + ### 2.2 Set up env vars + ```bash export HF_CACHE_DIR=/path/to/your/model/cache/ export HF_TOKEN= @@ -18,16 +22,20 @@ export HF_TOKEN= ``` ### 2.3 Build docker images + Build docker images for dataprep, agent, agent-ui. + ```bash # use docker image build ``` If deploy on Gaudi, also need to build vllm image. + ```bash cd $WORKDIR git clone https://github.com/HabanaAI/vllm-fork.git # get the latest release tag of vllm gaudi +cd vllm-fork VLLM_VER=$(git describe --tags "$(git rev-list --tags --max-count=1)") echo "Check out vLLM tag ${VLLM_VER}" git checkout ${VLLM_VER} @@ -35,8 +43,11 @@ docker build --no-cache -f Dockerfile.hpu -t opea/vllm-gaudi:latest --shm-size=1 ``` ## 3. Deploy with docker compose + ### 3.1 Launch vllm endpoint + Below is the command to launch a vllm endpoint on Gaudi that serves `meta-llama/Llama-3.3-70B-Instruct` model on 4 Gaudi cards. + ```bash export vllm_port=8086 export vllm_volume=$HF_CACHE_DIR @@ -44,14 +55,17 @@ export max_length=16384 export model="meta-llama/Llama-3.3-70B-Instruct" docker run -d --runtime=habana --rm --name "vllm-gaudi-server" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:8000 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:comps --model ${model} --max-seq-len-to-capture $max_length --tensor-parallel-size 4 ``` + ### 3.2 Prepare knowledge base + The commands below will upload some example files into the knowledge base. You can also upload files through UI. First, launch the redis databases and the dataprep microservice. + ```bash -docker compose -f $WORKDIR/GenAIExamples/FinanceAgent/docker_compose/intel/hpu/gaudi/dataprep.yaml up -d +docker compose -f $WORKDIR/GenAIExamples/FinanceAgent/docker_compose/intel/hpu/gaudi/dataprep_compose.yaml up -d ``` ### 3.3 Launch the multi-agent system -### 3.4 Validate agents +### 3.4 Validate agents diff --git a/FinanceAgent/docker_compose/intel/hpu/gaudi/compose.yaml b/FinanceAgent/docker_compose/intel/hpu/gaudi/compose.yaml index ba243dbbec..9a06b965ac 100644 --- a/FinanceAgent/docker_compose/intel/hpu/gaudi/compose.yaml +++ b/FinanceAgent/docker_compose/intel/hpu/gaudi/compose.yaml @@ -95,4 +95,28 @@ services: https_proxy: ${https_proxy} WORKER_FINQA_AGENT_URL: $WORKER_FINQA_AGENT_URL WORKER_RESEARCH_AGENT_URL: $WORKER_RESEARCH_AGENT_URL + #WORKER_SUM_AGENT_URL: $WORKER_SUM_AGENT_URL + DOCSUM_ENDPOINT: $DOCSUM_ENDPOINT + REDIS_URL_VECTOR: $REDIS_URL_VECTOR + REDIS_URL_KV: $REDIS_URL_KV + TEI_EMBEDDING_ENDPOINT: $TEI_EMBEDDING_ENDPOINT port: 9090 + + docsum-vllm-gaudi: + image: opea/llm-docsum:latest + container_name: docsum-vllm-gaudi + ports: + - ${DOCSUM_PORT:-9000}:9000 + ipc: host + environment: + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + LLM_ENDPOINT: ${LLM_ENDPOINT} + LLM_MODEL_ID: ${LLM_MODEL_ID} + HF_TOKEN: ${HF_TOKEN} + LOGFLAG: ${LOGFLAG:-False} + MAX_INPUT_TOKENS: ${MAX_INPUT_TOKENS} + MAX_TOTAL_TOKENS: ${MAX_TOTAL_TOKENS} + DocSum_COMPONENT_NAME: ${DocSum_COMPONENT_NAME:-OpeaDocSumvLLM} + restart: unless-stopped diff --git a/FinanceAgent/docker_compose/intel/hpu/gaudi/dataprep_compose.yaml b/FinanceAgent/docker_compose/intel/hpu/gaudi/dataprep_compose.yaml index a1f78bebab..5e4333c7d2 100644 --- a/FinanceAgent/docker_compose/intel/hpu/gaudi/dataprep_compose.yaml +++ b/FinanceAgent/docker_compose/intel/hpu/gaudi/dataprep_compose.yaml @@ -22,7 +22,7 @@ services: interval: 10s timeout: 6s retries: 48 - + redis-vector-db: image: redis/redis-stack:7.2.0-v9 container_name: redis-vector-db @@ -54,7 +54,7 @@ services: timeout: 10s retries: 3 start_period: 10s - + dataprep-redis-finance: image: ${REGISTRY:-opea}/dataprep:${TAG:-latest} container_name: dataprep-redis-server-finance @@ -80,4 +80,3 @@ services: HUGGINGFACEHUB_API_TOKEN: ${HF_TOKEN} HF_TOKEN: ${HF_TOKEN} LOGFLAG: true - diff --git a/FinanceAgent/docker_compose/intel/hpu/gaudi/launch_agents.sh b/FinanceAgent/docker_compose/intel/hpu/gaudi/launch_agents.sh index cc2d979f21..9a97cfd631 100644 --- a/FinanceAgent/docker_compose/intel/hpu/gaudi/launch_agents.sh +++ b/FinanceAgent/docker_compose/intel/hpu/gaudi/launch_agents.sh @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + export ip_address=$(hostname -I | awk '{print $1}') export HUGGINGFACEHUB_API_TOKEN=${HF_TOKEN} export TOOLSET_PATH=$WORKDIR/GenAIExamples/FinanceAgent/tools/ echo "TOOLSET_PATH=${TOOLSET_PATH}" export PROMPT_PATH=$WORKDIR/GenAIExamples/FinanceAgent/prompts/ +echo "PROMPT_PATH=${PROMPT_PATH}" export recursion_limit_worker=12 export recursion_limit_supervisor=10 @@ -16,10 +20,16 @@ export MAX_TOKENS=4096 export WORKER_FINQA_AGENT_URL="http://${ip_address}:9095/v1/chat/completions" export WORKER_RESEARCH_AGENT_URL="http://${ip_address}:9096/v1/chat/completions" +export EMBEDDING_MODEL_ID="BAAI/bge-base-en-v1.5" export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:10221" export REDIS_URL_VECTOR="redis://${ip_address}:6379" export REDIS_URL_KV="redis://${ip_address}:6380" +export MAX_INPUT_TOKENS=2048 +export MAX_TOTAL_TOKENS=4096 +export DocSum_COMPONENT_NAME="OpeaDocSumvLLM" +export DOCSUM_ENDPOINT="http://${ip_address}:9000/v1/docsum" + docker compose -f compose.yaml up -d diff --git a/FinanceAgent/prompts/finqa_prompt.py b/FinanceAgent/prompts/finqa_prompt.py index 877aa342ed..9dda6dc22c 100644 --- a/FinanceAgent/prompts/finqa_prompt.py +++ b/FinanceAgent/prompts/finqa_prompt.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + REACT_AGENT_LLAMA_PROMPT = """\ You are a helpful assistant engaged in multi-turn conversations with Financial analysts. You have access to the following two tools: diff --git a/FinanceAgent/prompts/supervisor_prompt.py b/FinanceAgent/prompts/supervisor_prompt.py index cf5883da80..404ba1a481 100644 --- a/FinanceAgent/prompts/supervisor_prompt.py +++ b/FinanceAgent/prompts/supervisor_prompt.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + REACT_AGENT_LLAMA_PROMPT = """\ You are a helpful assistant engaged in multi-turn conversations with users. You have the following worker agents working for you. You can call them as calling tools. diff --git a/FinanceAgent/tests/test.py b/FinanceAgent/tests/test.py index a8c73bbd88..5162a58f67 100644 --- a/FinanceAgent/tests/test.py +++ b/FinanceAgent/tests/test.py @@ -57,6 +57,7 @@ def test_chat_completion_multi_turn(args): add_message_and_run(url, user_message, thread_id, stream=args.stream) print("===============End of second turn==================") + def test_supervisor_agent_single_turn(args): url = f"http://{args.ip_addr}:{args.ext_port}/v1/chat/completions" query_list = [ @@ -67,9 +68,7 @@ def test_supervisor_agent_single_turn(args): for query in query_list: thread_id = f"{uuid.uuid4()}" add_message_and_run(url, query, thread_id, stream=args.stream) - print("="*50) - - + print("=" * 50) if __name__ == "__main__": diff --git a/FinanceAgent/tests/test_compose_on_gaudi.sh b/FinanceAgent/tests/test_compose_on_gaudi.sh index d7d0826a1c..e232008887 100644 --- a/FinanceAgent/tests/test_compose_on_gaudi.sh +++ b/FinanceAgent/tests/test_compose_on_gaudi.sh @@ -150,7 +150,7 @@ function stop_dataprep() { echo "Stopping databases" cid=$(docker ps -aq --filter "name=dataprep-redis-server*" --filter "name=redis-*" --filter "name=tei-embedding-*") if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid && sleep 1s; fi - + } function start_agents() { diff --git a/FinanceAgent/tests/test_redis_finance.py b/FinanceAgent/tests/test_redis_finance.py index f98360d69a..b62d4deb2d 100644 --- a/FinanceAgent/tests/test_redis_finance.py +++ b/FinanceAgent/tests/test_redis_finance.py @@ -67,4 +67,4 @@ def test_get(url): url = f"http://localhost:{port}/v1/dataprep/get" test_get(url) else: - raise ValueError("Invalid test_option value. Please choose from ingest, get, delete.") \ No newline at end of file + raise ValueError("Invalid test_option value. Please choose from ingest, get, delete.") diff --git a/FinanceAgent/tools/finqa_agent_tools.yaml b/FinanceAgent/tools/finqa_agent_tools.yaml index bcc5cfd1c0..c118e222a2 100644 --- a/FinanceAgent/tools/finqa_agent_tools.yaml +++ b/FinanceAgent/tools/finqa_agent_tools.yaml @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + search_knowledge_base: description: Search knowledge base of SEC filings. callable_api: finqa_tools.py:get_context_bm25_llm @@ -14,4 +17,4 @@ search_knowledge_base: quarter: type: str description: the quarter of interest, can only specify one quarter. can be 'Q1', 'Q2', 'Q3', 'Q4'. can be an empty string. - return_output: retrieved_data \ No newline at end of file + return_output: retrieved_data diff --git a/FinanceAgent/tools/finqa_tools.py b/FinanceAgent/tools/finqa_tools.py index 55ea015d87..57a0ff95d0 100644 --- a/FinanceAgent/tools/finqa_tools.py +++ b/FinanceAgent/tools/finqa_tools.py @@ -1,24 +1,28 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from tools.utils import * -def get_context_bm25_llm(query, company, year, quarter = ""): + +def get_context_bm25_llm(query, company, year, quarter=""): k = 5 - + company_list = get_company_list() company = get_company_name_in_kb(company, company_list) if "Cannot find" in company or "Database is empty" in company: return company - + print(f"Company: {company}") # chunks - index_name=f"chunks_{company}" + index_name = f"chunks_{company}" vector_store = get_vectorstore(index_name) chunks_bm25 = bm25_search_broad(query, company, year, quarter, k=k, doc_type="chunks") chunks_sim = similarity_search(vector_store, k, query, company, year, quarter) chunks = chunks_bm25 + chunks_sim - + # tables try: - index_name=f"tables_{company}" + index_name = f"tables_{company}" vector_store_table = get_vectorstore(index_name) # get tables matching metadata tables_bm25 = bm25_search_broad(query, company, year, quarter, k=k, doc_type="tables") @@ -28,7 +32,7 @@ def get_context_bm25_llm(query, company, year, quarter = ""): tables = [] # get unique results - context = get_unique_docs(chunks+tables) + context = get_unique_docs(chunks + tables) print("Context:\n", context[:500]) if context: @@ -50,7 +54,7 @@ def search_full_doc(query, company): company = get_company_name_in_kb(company, company_list) if "Cannot find" in company or "Database is empty" in company: return company - + # search most similar doc title index_name = f"titles_{company}" vector_store = get_vectorstore_titles(index_name) @@ -60,8 +64,8 @@ def search_full_doc(query, company): doc = docs[0] doc_title = doc.page_content print(f"Most similar doc title: {doc_title}") - - kvstore= RedisKVStore(redis_uri=REDIS_URL_KV) + + kvstore = RedisKVStore(redis_uri=REDIS_URL_KV) doc = kvstore.get(doc_title, f"full_doc_{company}") content = doc["full_doc"] doc_length = doc["doc_length"] @@ -80,17 +84,17 @@ def search_full_doc(query, company): # year="2024" # quarter="Q4" - company="Costco" - year="2025" - quarter="Q2" + company = "Costco" + year = "2025" + quarter = "Q2" - collection_name=f"chunks_{company}" + collection_name = f"chunks_{company}" search_metadata = ("company", company) - + resp = get_context_bm25_llm("revenue", company, year, quarter) print("***Response:\n", resp) - print("="*50) + print("=" * 50) print("testing retrieve full doc") query = f"{company} {year} {quarter} earning call" - search_full_doc(query, company) \ No newline at end of file + search_full_doc(query, company) diff --git a/FinanceAgent/tools/redis_kv.py b/FinanceAgent/tools/redis_kv.py index 9f88d02a6b..3ded5bffdb 100644 --- a/FinanceAgent/tools/redis_kv.py +++ b/FinanceAgent/tools/redis_kv.py @@ -143,4 +143,4 @@ def from_host_and_port( port (int): Redis port """ url = f"redis://{host}:{port}".format(host=host, port=port) - return cls(redis_uri=url) \ No newline at end of file + return cls(redis_uri=url) diff --git a/FinanceAgent/tools/research_agent_tools.yaml b/FinanceAgent/tools/research_agent_tools.yaml index e69de29bb2..4057dc0163 100644 --- a/FinanceAgent/tools/research_agent_tools.yaml +++ b/FinanceAgent/tools/research_agent_tools.yaml @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/FinanceAgent/tools/sum_agent_tools.py b/FinanceAgent/tools/sum_agent_tools.py new file mode 100644 index 0000000000..6c502c13cf --- /dev/null +++ b/FinanceAgent/tools/sum_agent_tools.py @@ -0,0 +1,113 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json + +import requests +from tools.redis_kv import RedisKVStore +from tools.utils import * + + +def get_summary_else_doc(query, company): + company = company.upper() + + # decide if company is in company list + company_list = get_company_list() + print(f"company_list {company_list}") + company = get_company_name_in_kb(company, company_list) + if "Cannot find" in company or "Database is empty" in company: + print(f"Company not found in knowledge base: {company}") + return company + print(f"Company {company}") + + # search most similar doc title + index_name = f"titles_{company}" + vector_store = get_vectorstore_titles(index_name) + k = 1 + docs = vector_store.similarity_search(query, k=k) + if docs: + doc = docs[0] + doc_title = doc.page_content + print(f"Most similar doc title: {doc_title}") + + kvstore = RedisKVStore(redis_uri=REDIS_URL_KV) + doc = kvstore.get(doc_title, f"full_doc_{company}") + doc_length = doc["doc_length"] + print(f"Doc length: {doc_length}") + if "summary" not in doc: + content = doc["full_doc"] + print(f"is_summary: {False}") + is_summary = False + return doc_title, content, is_summary + content = doc["summary"] + is_summary = True + print(f"Summary already exists in KV store: {is_summary}") + return doc_title, content, is_summary + + +def save_doc_summary(summary, doc_title, company): + """Adds a summary to the existing document in the key-value store. + + Args: + kvstore: The key-value store instance. + summary: The summary to be added. + doc_title: The title of the document. + company: The company associated with the document. + """ + kvstore = RedisKVStore(redis_uri=REDIS_URL_KV) + doc_dict = kvstore.get(doc_title, f"full_doc_{company}") + + # Add the summary to the dictionary + doc_dict["summary"] = summary + + # Save the updated value back to the store + kvstore.put(doc_title, doc_dict, collection=f"full_doc_{company}") + + +def summarize(doc_name, company): + ip_address = os.environ.get("ip_address") + # docsum_url = f"http://{ip_address}:9000/v1/docsum" + docsum_url = os.environ.get("DOCSUM_ENDPOINT") + print(f"Docsum Endpoint URL: {docsum_url}") + + doc_title, sum, is_summary = get_summary_else_doc(doc_name, company) + print(f"Summary or full doc from get_summary_else_doc: {sum[:100]} \n -------\n") + if not is_summary: + data = { + "messages": sum, + "max_tokens": 512, + "language": "en", + "stream": False, + "summary_type": "auto", + "chunk_size": 2000, + } + + headers = {"Content-Type": "application/json"} + try: + print("Computing Summary with OPEA DocSum...") + resp = requests.post(url=docsum_url, data=json.dumps(data), headers=headers) + ret = resp.text + resp.raise_for_status() # Raise an exception for unsuccessful HTTP status codes + except requests.exceptions.RequestException as e: + ret = f"An error occurred:{e}" + # save summary into db + print("Saving Summary into KV Store...") + save_doc_summary(ret, doc_title, company) + return ret + else: + return sum + + +if __name__ == "__main__": + company = "Gap" + year = "2024" + quarter = "Q4" + + # company="Costco" + # year="2025" + # quarter="Q2" + + print("testing summarize") + doc_name = f"{company} {year} {quarter} earning call" + summarize(doc_name, company) + print("=" * 50) diff --git a/FinanceAgent/tools/supervisor_agent_tools.yaml b/FinanceAgent/tools/supervisor_agent_tools.yaml index 9c9789674e..0b451557e3 100644 --- a/FinanceAgent/tools/supervisor_agent_tools.yaml +++ b/FinanceAgent/tools/supervisor_agent_tools.yaml @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + finqa_agent: description: answer financial questions about a company. callable_api: supervisor_tools.py:finqa_agent @@ -7,14 +10,17 @@ finqa_agent: description: should include company name and time, for example, which business unit had the highest growth for Microsoft in 2024. return_output: retrieved_data -# summarization_agent: -# description: generate high-lelvel summary of a financial document. -# callable_api: supervisor_tools.py:summarize_doc -# args_schema: -# doc_title: -# type: str -# description: the description of the document, should include company name and time, for example, Apple 2023 Q4 earnings call. -# return_output: summary +summarization_tool: + description: Searches KV store for summary, if it doesn't exist pulls full document and summarize it + callable_api: sum_agent_tools.py:summarize + args_schema: + doc_name: + type: str + description: Descriptive name of the document + company: + type: str + description: Name of the company document belongs to + return_output: summary # research_agent: # description: generate research report on a specified company with fundamentals analysis, sentiment analysis and risk analysis. diff --git a/FinanceAgent/tools/supervisor_tools.py b/FinanceAgent/tools/supervisor_tools.py index 68b9215142..5bddc16bcd 100644 --- a/FinanceAgent/tools/supervisor_tools.py +++ b/FinanceAgent/tools/supervisor_tools.py @@ -1,7 +1,12 @@ -import requests +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import os -def finqa_agent(query:str): +import requests + + +def finqa_agent(query: str): url = os.environ.get("WORKER_FINQA_AGENT_URL") print(url) proxies = {"http": ""} @@ -11,10 +16,8 @@ def finqa_agent(query:str): response = requests.post(url, json=payload, proxies=proxies) return response.json()["text"] -def summarize_doc(doc_title): - pass -def research_agent(company:str): +def research_agent(company: str): url = os.environ.get("WORKER_RESEARCH_AGENT_URL") print(url) proxies = {"http": ""} @@ -22,4 +25,4 @@ def research_agent(company:str): "messages": company, } response = requests.post(url, json=payload, proxies=proxies) - return response.json()["text"] \ No newline at end of file + return response.json()["text"] diff --git a/FinanceAgent/tools/utils.py b/FinanceAgent/tools/utils.py index fb0679f6e7..ba3820ca54 100644 --- a/FinanceAgent/tools/utils.py +++ b/FinanceAgent/tools/utils.py @@ -1,14 +1,18 @@ -from langchain_community.retrievers import BM25Retriever -from langchain_redis import RedisConfig, RedisVectorStore -from langchain_core.documents import Document -from tools.redis_kv import RedisKVStore +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import os -from openai import OpenAI + from langchain_community.embeddings import HuggingFaceBgeEmbeddings +from langchain_community.retrievers import BM25Retriever +from langchain_core.documents import Document from langchain_huggingface import HuggingFaceEndpointEmbeddings +from langchain_redis import RedisConfig, RedisVectorStore +from openai import OpenAI +from tools.redis_kv import RedisKVStore # Embedding model -EMBED_MODEL = os.getenv("EMBED_MODEL", "BAAI/bge-base-en-v1.5") +EMBED_MODEL = os.getenv("EMBED_MODEL", "BAAI/bge-base-en-v1.5") TEI_EMBEDDING_ENDPOINT = os.getenv("TEI_EMBEDDING_ENDPOINT", "") # Redis URL @@ -16,19 +20,19 @@ REDIS_URL_KV = os.getenv("REDIS_URL_KV", "redis://localhost:6380/") # LLM config -LLM_MODEL=os.getenv("model", "meta-llama/Llama-3.3-70B-Instruct") -LLM_ENDPOINT=os.getenv("llm_endpoint_url", "http://localhost:8086") +LLM_MODEL = os.getenv("model", "meta-llama/Llama-3.3-70B-Instruct") +LLM_ENDPOINT = os.getenv("llm_endpoint_url", "http://localhost:8086") print(f"LLM endpoint: {LLM_ENDPOINT}") MAX_TOKENS = 1024 TEMPERATURE = 0.2 -COMPANY_NAME_PROMPT="""\ +COMPANY_NAME_PROMPT = """\ Here is the list of company names in the knowledge base: {company_list} This is the company of interest: {company} -Determine if the company of interest is the same as any of the companies in the knowledge base. +Determine if the company of interest is the same as any of the companies in the knowledge base. If yes, map the company of interest to the company name in the knowledge base. Output the company name in {{}}. Example: {{3M}}. If none of the companies in the knowledge base match the company of interest, output "NONE". """ @@ -42,6 +46,7 @@ Now take a deep breath and think step by step to answer the question. Wrap your final answer in {{}}. Example: {{The company has a revenue of $100 million.}} """ + def get_embedder(): if TEI_EMBEDDING_ENDPOINT: # create embeddings using TEI endpoint service @@ -54,10 +59,9 @@ def get_embedder(): embedder = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL) return embedder + def generate_answer(prompt): - """ - Use vllm endpoint to generate the answer - """ + """Use vllm endpoint to generate the answer.""" # send request to vllm endpoint client = OpenAI( base_url=f"{LLM_ENDPOINT}/v1", @@ -70,12 +74,8 @@ def generate_answer(prompt): } completion = client.chat.completions.create( - model=LLM_MODEL, - messages=[ - {"role": "user", "content": prompt} - ], - **params - ) + model=LLM_MODEL, messages=[{"role": "user", "content": prompt}], **params + ) # get response response = completion.choices[0].message.content @@ -99,15 +99,16 @@ def get_company_list(): return company_list else: return [] - + + def get_company_name_in_kb(company, company_list): if not company_list: return "Database is empty." - + company = company.upper() if company in company_list: return company - + prompt = COMPANY_NAME_PROMPT.format(company_list=company_list, company=company) response = generate_answer(prompt) if "NONE" in response.upper(): @@ -119,6 +120,7 @@ def get_company_name_in_kb(company, company_list): else: return "Failed to parse LLM response." + def get_docs_matching_metadata(metadata, collection_name): """ metadata: ("company_year", "3M_2023") @@ -126,8 +128,8 @@ def get_docs_matching_metadata(metadata, collection_name): """ key = metadata[0] value = metadata[1] - kvstore= RedisKVStore(redis_uri=REDIS_URL_KV) - collection = kvstore.get_all(collection_name) # collection is a dict + kvstore = RedisKVStore(redis_uri=REDIS_URL_KV) + collection = kvstore.get_all(collection_name) # collection is a dict matching_docs = [] for idx in collection: @@ -138,7 +140,8 @@ def get_docs_matching_metadata(metadata, collection_name): matching_docs.append(doc) print(f"Number of docs found with search_metadata {metadata}: {len(matching_docs)}") return matching_docs - + + def convert_docs(docs): # docs: list of dicts converted_docs_content = [] @@ -146,24 +149,17 @@ def convert_docs(docs): for doc in docs: content = doc["content"] # convert content to Document object - metadata = {"type":"content",**doc["metadata"]} - converted_content = Document( - id = doc["metadata"]["doc_id"], - page_content=content, - metadata=metadata - ) - + metadata = {"type": "content", **doc["metadata"]} + converted_content = Document(id=doc["metadata"]["doc_id"], page_content=content, metadata=metadata) + # convert summary to Document object - metadata = {"type":"summary", "content":content, **doc["metadata"]} - converted_summary = Document( - id = doc["metadata"]["doc_id"], - page_content=doc["summary"], - metadata=metadata - ) + metadata = {"type": "summary", "content": content, **doc["metadata"]} + converted_summary = Document(id=doc["metadata"]["doc_id"], page_content=doc["summary"], metadata=metadata) converted_docs_content.append(converted_content) converted_docs_summary.append(converted_summary) return converted_docs_content, converted_docs_summary + def bm25_search(query, metadata, company, doc_type="chunks", k=10): collection_name = f"{doc_type}_{company}" print(f"Collection name: {collection_name}") @@ -189,21 +185,21 @@ def bm25_search(query, metadata, company, doc_type="chunks", k=10): def bm25_search_broad(query, company, year, quarter, k=10, doc_type="chunks"): # search with company filter, but query is query_company_quarter - metadata = ("company",f"{company}") + metadata = ("company", f"{company}") query1 = f"{query} {year} {quarter}" docs1 = bm25_search(query1, metadata, company, k=k, doc_type=doc_type) # search with metadata filters - metadata = ("company_year_quarter",f"{company}_{year}_{quarter}") + metadata = ("company_year_quarter", f"{company}_{year}_{quarter}") print(f"BM25: Searching for docs with metadata: {metadata}") docs = bm25_search(query, metadata, company, k=k, doc_type=doc_type) if not docs: print("BM25: No docs found with company, year and quarter filter, only search with company and year filter") - metadata = ("company_year",f"{company}_{year}") + metadata = ("company_year", f"{company}_{year}") docs = bm25_search(query, metadata, company, k=k, doc_type=doc_type) if not docs: print("BM25: No docs found with company and year filter, only search with company filter") - metadata = ("company",f"{company}") + metadata = ("company", f"{company}") docs = bm25_search(query, metadata, company, k=k, doc_type=doc_type) docs = docs + docs1 @@ -216,11 +212,13 @@ def bm25_search_broad(query, company, year, quarter, k=10, doc_type="chunks"): def set_filter(metadata_filter): # metadata_filter: tuple of (key, value) from redisvl.query.filter import Text + key = metadata_filter[0] value = metadata_filter[1] filter_condition = Text(key) == value return filter_condition + def similarity_search(vector_store, k, query, company, year, quarter=None): query1 = f"{query} {year} {quarter}" filter_condition = set_filter(("company", company)) @@ -229,17 +227,17 @@ def similarity_search(vector_store, k, query, company, year, quarter=None): filter_condition = set_filter(("company_year_quarter", f"{company}_{year}_{quarter}")) docs = vector_store.similarity_search(query, k=k, filter=filter_condition) - - if not docs: # if no relevant document found, relax the filter - print("No relevant document found with company, year and quarter filter, only search with comany and year") + + if not docs: # if no relevant document found, relax the filter + print("No relevant document found with company, year and quarter filter, only search with company and year") filter_condition = set_filter(("company_year", f"{company}_{year}")) docs = vector_store.similarity_search(query, k=k, filter=filter_condition) - - if not docs: # if no relevant document found, relax the filter - print("No relevant document found with company_year filter, only serach with company.....") + + if not docs: # if no relevant document found, relax the filter + print("No relevant document found with company_year filter, only search with company.....") filter_condition = set_filter(("company", company)) docs = vector_store.similarity_search(query, k=k, filter=filter_condition) - + print(f"Similarity search: Found {len(docs)} docs with filter and query: {query}") docs = docs + docs1 @@ -248,6 +246,7 @@ def similarity_search(vector_store, k, query, company, year, quarter=None): else: return docs + def get_index_name(doc_type: str, metadata: dict): company = metadata["company"] if doc_type == "chunks": @@ -262,6 +261,7 @@ def get_index_name(doc_type: str, metadata: dict): raise ValueError("doc_type should be either chunks, tables, titles, or full_doc.") return index_name + def get_content(doc): # doc can be converted doc # of saved doc in vector store @@ -273,20 +273,20 @@ def get_content(doc): content = doc.page_content else: print("Dense retriever doc...") - + doc_id = doc.metadata["doc_id"] # doc_summary=doc.page_content kvstore = RedisKVStore(redis_uri=REDIS_URL_KV) collection_name = get_index_name(doc.metadata["doc_type"], doc.metadata) result = kvstore.get(doc_id, collection_name) content = result["content"] - + # print(f"***Doc Metadata:\n{doc.metadata}") # print(f"***Content: {content[:100]}...") return content - + def get_unique_docs(docs): results = [] context = "" @@ -340,5 +340,3 @@ def get_vectorstore_titles(index_name): embedder = get_embedder() vector_store = RedisVectorStore(embedder, config=config) return vector_store - -