Skip to content

Commit a58ca4a

Browse files
authored
Add a new reranking based on mosec. (#210)
Signed-off-by: Jincheng Miao <jincheng.miao@intel.com>
1 parent 2e3c032 commit a58ca4a

9 files changed

Lines changed: 430 additions & 0 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# build reranking Mosec endpoint docker image
2+
3+
```
4+
docker build --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy -t reranking-langchain-mosec:latest -f comps/reranks/langchain-mosec/mosec-docker/Dockerfile .
5+
```
6+
7+
# build reranking microservice docker image
8+
9+
```
10+
docker build --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy -t opea/reranking-langchain-mosec:latest -f comps/reranks/langchain-mosec/docker/Dockerfile .
11+
```
12+
13+
# launch Mosec endpoint docker container
14+
15+
```
16+
docker run -d --name="reranking-langchain-mosec-endpoint" -p 6001:8000 reranking-langchain-mosec:latest
17+
```
18+
19+
# launch embedding microservice docker container
20+
21+
```
22+
export MOSEC_RERANKING_ENDPOINT=http://127.0.0.1:6001
23+
docker run -d --name="reranking-langchain-mosec-server" -e http_proxy=$http_proxy -e https_proxy=$https_proxy -p 6000:8000 --ipc=host -e MOSEC_RERANKING_ENDPOINT=$MOSEC_RERANKING_ENDPOINT opea/reranking-langchain-mosec:latest
24+
```
25+
26+
# run client test
27+
28+
```
29+
curl http://localhost:6000/v1/reranking \
30+
-X POST \
31+
-d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}]}' \
32+
-H 'Content-Type: application/json'
33+
```
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
# Copyright (C) 2024 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
FROM langchain/langchain:latest
6+
7+
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
8+
libgl1-mesa-glx \
9+
libjemalloc-dev \
10+
vim
11+
12+
RUN useradd -m -s /bin/bash user && \
13+
mkdir -p /home/user && \
14+
chown -R user /home/user/
15+
16+
USER user
17+
18+
COPY comps /home/user/comps
19+
20+
RUN pip install --no-cache-dir --upgrade pip && \
21+
pip install --no-cache-dir -r /home/user/comps/reranks/langchain-mosec/requirements.txt
22+
23+
ENV PYTHONPATH=$PYTHONPATH:/home/user
24+
25+
WORKDIR /home/user/comps/reranks/langchain-mosec
26+
27+
ENTRYPOINT ["python", "reranking_mosec_xeon.py"]
28+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
version: "3.8"
5+
6+
services:
7+
reranking:
8+
image: opea/reranking-langchain-mosec:latest
9+
container_name: reranking-langchain-mosec-server
10+
ports:
11+
- "6000:8000"
12+
ipc: host
13+
environment:
14+
http_proxy: ${http_proxy}
15+
https_proxy: ${https_proxy}
16+
MOSEC_RERANKING_ENDPOINT: ${MOSEC_RERANKING_ENDPOINT}
17+
LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY}
18+
restart: unless-stopped
19+
20+
networks:
21+
default:
22+
driver: bridge
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
From ubuntu:22.04
5+
ARG DEBIAN_FRONTEND=noninteractive
6+
7+
ENV GLIBC_TUNABLES glibc.cpu.x86_shstk=permissive
8+
9+
COPY comps /root/comps
10+
11+
RUN apt update && apt install -y python3 python3-pip
12+
RUN pip3 install torch==2.2.2 torchvision --trusted-host download.pytorch.org --index-url https://download.pytorch.org/whl/cpu
13+
RUN pip3 install intel-extension-for-pytorch==2.2.0
14+
RUN pip3 install transformers sentence-transformers
15+
RUN pip3 install llmspec mosec
16+
17+
RUN cd /root/ && export HF_ENDPOINT=https://hf-mirror.com && huggingface-cli download --resume-download BAAI/bge-reranker-large --local-dir /root/bge-reranker-large
18+
19+
ENV EMB_MODEL="/root/bge-reranker-large/"
20+
21+
WORKDIR /root/comps/reranks/langchain-mosec/mosec-docker
22+
23+
CMD ["python3", "server-ipex.py"]
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import json
5+
import os
6+
from os import environ
7+
from typing import Any, Dict, List, Optional, Union
8+
9+
import intel_extension_for_pytorch as ipex
10+
import numpy as np
11+
import torch
12+
from mosec import Server, Worker
13+
from mosec.mixin import TypedMsgPackMixin
14+
from msgspec import Struct
15+
from sentence_transformers import CrossEncoder
16+
from torch.utils.data import DataLoader
17+
from tqdm.autonotebook import tqdm, trange
18+
19+
DEFAULT_MODEL = "/root/bge-reranker-large"
20+
21+
22+
class MyCrossEncoder(CrossEncoder):
23+
def __init__(
24+
self,
25+
model_name: str,
26+
num_labels: int = None,
27+
max_length: int = None,
28+
device: str = None,
29+
tokenizer_args: Dict = None,
30+
automodel_args: Dict = None,
31+
trust_remote_code: bool = False,
32+
revision: Optional[str] = None,
33+
local_files_only: bool = False,
34+
default_activation_function=None,
35+
classifier_dropout: float = None,
36+
) -> None:
37+
super().__init__(
38+
model_name,
39+
num_labels,
40+
max_length,
41+
device,
42+
tokenizer_args,
43+
automodel_args,
44+
trust_remote_code,
45+
revision,
46+
local_files_only,
47+
default_activation_function,
48+
classifier_dropout,
49+
)
50+
# jit trace model
51+
self.model = ipex.optimize(self.model, dtype=torch.float32)
52+
vocab_size = self.model.config.vocab_size
53+
batch_size = 16
54+
seq_length = 512
55+
d = torch.randint(vocab_size, size=[batch_size, seq_length])
56+
# t = torch.randint(0, 1, size=[batch_size, seq_length])
57+
m = torch.randint(1, 2, size=[batch_size, seq_length])
58+
self.model = torch.jit.trace(self.model, [d, m], check_trace=False, strict=False)
59+
self.model = torch.jit.freeze(self.model)
60+
61+
def predict(
62+
self,
63+
sentences: List[List[str]],
64+
batch_size: int = 32,
65+
show_progress_bar: bool = None,
66+
num_workers: int = 0,
67+
activation_fct=None,
68+
apply_softmax=False,
69+
convert_to_numpy: bool = True,
70+
convert_to_tensor: bool = False,
71+
) -> Union[List[float], np.ndarray, torch.Tensor]:
72+
input_was_string = False
73+
if isinstance(sentences[0], str): # Cast an individual sentence to a list with length 1
74+
sentences = [sentences]
75+
input_was_string = True
76+
77+
inp_dataloader = DataLoader(
78+
sentences,
79+
batch_size=batch_size,
80+
collate_fn=self.smart_batching_collate_text_only,
81+
num_workers=num_workers,
82+
shuffle=False,
83+
)
84+
85+
iterator = inp_dataloader
86+
if show_progress_bar:
87+
iterator = tqdm(inp_dataloader, desc="Batches")
88+
89+
if activation_fct is None:
90+
activation_fct = self.default_activation_function
91+
92+
pred_scores = []
93+
self.model.eval()
94+
self.model.to(self._target_device)
95+
with torch.no_grad():
96+
for features in iterator:
97+
model_predictions = self.model(**features)
98+
logits = activation_fct(model_predictions["logits"])
99+
100+
if apply_softmax and len(logits[0]) > 1:
101+
logits = torch.nn.functional.softmax(logits, dim=1)
102+
pred_scores.extend(logits)
103+
104+
if self.config.num_labels == 1:
105+
pred_scores = [score[0] for score in pred_scores]
106+
107+
if convert_to_tensor:
108+
pred_scores = torch.stack(pred_scores)
109+
elif convert_to_numpy:
110+
pred_scores = np.asarray([score.cpu().detach().numpy() for score in pred_scores])
111+
112+
if input_was_string:
113+
pred_scores = pred_scores[0]
114+
115+
return pred_scores
116+
117+
118+
class Request(Struct, kw_only=True):
119+
query: str
120+
docs: List[str]
121+
122+
123+
class Response(Struct, kw_only=True):
124+
scores: List[float]
125+
126+
127+
def float_handler(o):
128+
if isinstance(o, float):
129+
return format(o, ".10f")
130+
raise TypeError("Not serializable")
131+
132+
133+
class MosecReranker(Worker):
134+
def __init__(self):
135+
self.model_name = environ.get("MODEL_NAME", DEFAULT_MODEL)
136+
self.model = MyCrossEncoder(self.model_name)
137+
138+
def serialize(self, data: Response) -> bytes:
139+
sorted_list = sorted(data.scores, reverse=True)
140+
index_sorted = [data.scores.index(i) for i in sorted_list]
141+
res = []
142+
for i, s in zip(index_sorted, sorted_list):
143+
tmp = {"index": i, "score": "{:.10f}".format(s)}
144+
res.append(tmp)
145+
return json.dumps(res, default=float_handler).encode("utf-8")
146+
147+
def forward(self, data: List[Request]) -> List[Response]:
148+
sentence_pairs = []
149+
inputs_lens = []
150+
for d in data:
151+
inputs_lens.append(len(d["texts"]))
152+
tmp = [[d["query"], doc] for doc in d["texts"]]
153+
sentence_pairs.extend(tmp)
154+
155+
scores = self.model.predict(sentence_pairs)
156+
scores = scores.tolist()
157+
158+
resp = []
159+
cur_idx = 0
160+
for lens in inputs_lens:
161+
resp.append(Response(scores=scores[cur_idx : cur_idx + lens]))
162+
cur_idx += lens
163+
164+
return resp
165+
166+
167+
if __name__ == "__main__":
168+
MAX_BATCH_SIZE = int(os.environ.get("MAX_BATCH_SIZE", 128))
169+
MAX_WAIT_TIME = int(os.environ.get("MAX_WAIT_TIME", 10))
170+
server = Server()
171+
server.append_worker(MosecReranker, max_batch_size=MAX_BATCH_SIZE, max_wait_time=MAX_WAIT_TIME)
172+
server.run()
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
docarray[full]
2+
fastapi
3+
langchain
4+
langchain_community
5+
openai
6+
opentelemetry-api
7+
opentelemetry-exporter-otlp
8+
opentelemetry-sdk
9+
shortuuid
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Copyright 2024 MOSEC Authors
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import json
19+
import os
20+
import re
21+
import time
22+
23+
import requests
24+
from langchain_core.prompts import ChatPromptTemplate
25+
from langsmith import traceable
26+
27+
from comps import (
28+
LLMParamsDoc,
29+
SearchedDoc,
30+
ServiceType,
31+
opea_microservices,
32+
register_microservice,
33+
register_statistics,
34+
statistics_dict,
35+
)
36+
37+
38+
@register_microservice(
39+
name="opea_service@reranking_mosec_xeon",
40+
service_type=ServiceType.RERANK,
41+
endpoint="/v1/reranking",
42+
host="0.0.0.0",
43+
port=8000,
44+
input_datatype=SearchedDoc,
45+
output_datatype=LLMParamsDoc,
46+
)
47+
@traceable(run_type="llm")
48+
@register_statistics(names=["opea_service@reranking_mosec_xeon"])
49+
def reranking(input: SearchedDoc) -> LLMParamsDoc:
50+
print("reranking input: ", input)
51+
start = time.time()
52+
docs = [doc.text for doc in input.retrieved_docs]
53+
url = mosec_reranking_endpoint + "/inference"
54+
data = {"query": input.initial_query, "texts": docs}
55+
headers = {"Content-Type": "application/json"}
56+
response = requests.post(url, data=json.dumps(data), headers=headers)
57+
response_data = response.json()
58+
best_response = max(response_data, key=lambda response: response["score"])
59+
doc = input.retrieved_docs[best_response["index"]]
60+
if doc.text and len(re.findall("[\u4E00-\u9FFF]", doc.text)) / len(doc.text) >= 0.3:
61+
# chinese context
62+
template = "仅基于以下背景回答问题:\n{context}\n问题: {question}"
63+
else:
64+
template = """Answer the question based only on the following context:
65+
{context}
66+
Question: {question}
67+
"""
68+
prompt = ChatPromptTemplate.from_template(template)
69+
final_prompt = prompt.format(context=doc.text, question=input.initial_query)
70+
statistics_dict["opea_service@reranking_mosec_xeon"].append_latency(time.time() - start, None)
71+
return LLMParamsDoc(query=final_prompt.strip())
72+
73+
74+
if __name__ == "__main__":
75+
mosec_reranking_endpoint = os.getenv("MOSEC_RERANKING_ENDPOINT", "http://localhost:8080")
76+
opea_microservices["opea_service@reranking_mosec_xeon"].start()

0 commit comments

Comments
 (0)