|
| 1 | +# Copyright (C) 2024 Intel Corporation |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import json |
| 5 | +import os |
| 6 | +from os import environ |
| 7 | +from typing import Any, Dict, List, Optional, Union |
| 8 | + |
| 9 | +import intel_extension_for_pytorch as ipex |
| 10 | +import numpy as np |
| 11 | +import torch |
| 12 | +from mosec import Server, Worker |
| 13 | +from mosec.mixin import TypedMsgPackMixin |
| 14 | +from msgspec import Struct |
| 15 | +from sentence_transformers import CrossEncoder |
| 16 | +from torch.utils.data import DataLoader |
| 17 | +from tqdm.autonotebook import tqdm, trange |
| 18 | + |
| 19 | +DEFAULT_MODEL = "/root/bge-reranker-large" |
| 20 | + |
| 21 | + |
class MyCrossEncoder(CrossEncoder):
    """CrossEncoder whose model is IPEX-optimized, JIT-traced and frozen.

    The constructor forwards every argument to ``CrossEncoder.__init__`` and
    then replaces ``self.model`` with a traced, frozen TorchScript module.
    ``predict`` is re-implemented here so that it reads the ``"logits"`` entry
    from the traced module's output.
    """

    def __init__(
        self,
        model_name: str,
        num_labels: Optional[int] = None,
        max_length: Optional[int] = None,
        device: Optional[str] = None,
        tokenizer_args: Optional[Dict] = None,
        automodel_args: Optional[Dict] = None,
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
        local_files_only: bool = False,
        default_activation_function=None,
        classifier_dropout: Optional[float] = None,
    ) -> None:
        """Load the cross-encoder, then optimize, trace and freeze its model.

        All parameters are passed through unchanged (positionally, in this
        order) to ``CrossEncoder.__init__``.
        """
        super().__init__(
            model_name,
            num_labels,
            max_length,
            device,
            tokenizer_args,
            automodel_args,
            trust_remote_code,
            revision,
            local_files_only,
            default_activation_function,
            classifier_dropout,
        )
        # jit trace model
        self.model = ipex.optimize(self.model, dtype=torch.float32)
        vocab_size = self.model.config.vocab_size
        batch_size = 16
        seq_length = 512
        # Dummy example inputs used only for tracing: random ids below
        # vocab_size and a tensor of ones (randint(1, 2) can only yield 1).
        # NOTE(review): presumably these map to input_ids / attention_mask in
        # the model's positional order -- confirm against the model's forward.
        d = torch.randint(vocab_size, size=[batch_size, seq_length])
        m = torch.randint(1, 2, size=[batch_size, seq_length])
        # strict=False is required for the trace to return a dict-like output
        # that predict() can index with "logits".
        self.model = torch.jit.trace(self.model, [d, m], check_trace=False, strict=False)
        self.model = torch.jit.freeze(self.model)

    def predict(
        self,
        sentences: List[List[str]],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
        num_workers: int = 0,
        activation_fct=None,
        apply_softmax=False,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
    ) -> Union[List[float], np.ndarray, torch.Tensor]:
        """Score sentence pairs with the traced model.

        Args:
            sentences: A list of ``[text_a, text_b]`` pairs, or a single pair
                given as a flat list of strings.
            batch_size: Number of pairs per DataLoader batch.
            show_progress_bar: Wrap the batch iterator in ``tqdm`` when truthy.
            num_workers: Passed to the DataLoader.
            activation_fct: Applied to the logits; falls back to
                ``self.default_activation_function`` when ``None``.
            apply_softmax: Apply softmax over dim 1 when the logits have more
                than one entry per pair.
            convert_to_numpy: Return a NumPy array (ignored when
                ``convert_to_tensor`` is set).
            convert_to_tensor: Return a stacked ``torch.Tensor``.

        Returns:
            The scores as a tensor, a NumPy array, or a list of tensors,
            depending on the conversion flags. For a single pair the first
            element is returned. Empty input yields an empty container of the
            requested kind.
        """
        # Empty input previously raised IndexError on sentences[0] below;
        # return an empty result of the requested container type instead.
        if len(sentences) == 0:
            if convert_to_tensor:
                return torch.empty(0)
            if convert_to_numpy:
                return np.asarray([])
            return []

        input_was_string = False
        if isinstance(sentences[0], str):  # Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True

        inp_dataloader = DataLoader(
            sentences,
            batch_size=batch_size,
            collate_fn=self.smart_batching_collate_text_only,
            num_workers=num_workers,
            shuffle=False,
        )

        iterator = inp_dataloader
        if show_progress_bar:
            iterator = tqdm(inp_dataloader, desc="Batches")

        if activation_fct is None:
            activation_fct = self.default_activation_function

        pred_scores = []
        self.model.eval()
        self.model.to(self._target_device)
        with torch.no_grad():
            for features in iterator:
                model_predictions = self.model(**features)
                logits = activation_fct(model_predictions["logits"])

                if apply_softmax and len(logits[0]) > 1:
                    logits = torch.nn.functional.softmax(logits, dim=1)
                pred_scores.extend(logits)

        # With a single label each row holds one value; unwrap it.
        if self.config.num_labels == 1:
            pred_scores = [score[0] for score in pred_scores]

        if convert_to_tensor:
            pred_scores = torch.stack(pred_scores)
        elif convert_to_numpy:
            pred_scores = np.asarray([score.cpu().detach().numpy() for score in pred_scores])

        if input_was_string:
            pred_scores = pred_scores[0]

        return pred_scores
| 116 | + |
| 117 | + |
class Request(Struct, kw_only=True):
    """Declared schema of one rerank request: a query plus a list of documents.

    NOTE(review): ``MosecReranker.forward`` in this file reads each request
    with ``d["query"]`` and ``d["texts"]``; the ``docs`` field declared here
    does not match the ``"texts"`` key -- confirm which name is intended.
    """

    # Query text.
    query: str
    # Candidate document texts.
    docs: List[str]
| 121 | + |
| 122 | + |
class Response(Struct, kw_only=True):
    """Scores returned for a single request."""

    # One float per document of the request, in the request's document order
    # (the worker's forward slices the flat score list per request).
    scores: List[float]
| 125 | + |
| 126 | + |
def float_handler(o):
    """Render a float as a fixed-point string with ten decimal places.

    Intended as a ``default=`` hook for ``json.dumps``: any value that is not
    a ``float`` is rejected with ``TypeError``.
    """
    if not isinstance(o, float):
        raise TypeError("Not serializable")
    return f"{o:.10f}"
| 131 | + |
| 132 | + |
class MosecReranker(Worker):
    """mosec worker that scores (query, document) pairs with ``MyCrossEncoder``."""

    def __init__(self):
        """Load the model named by ``MODEL_NAME`` (default: ``DEFAULT_MODEL``)."""
        super().__init__()
        self.model_name = environ.get("MODEL_NAME", DEFAULT_MODEL)
        self.model = MyCrossEncoder(self.model_name)

    def serialize(self, data: Response) -> bytes:
        """Encode one response as a JSON ranking, highest score first.

        Returns:
            UTF-8 JSON bytes of a list of ``{"index": int, "score": str}``
            objects, where ``index`` is the position of the score in
            ``data.scores`` and ``score`` is formatted with ten decimals.
        """
        # Sort (position, score) pairs rather than looking positions up with
        # list.index(): index() returns the first match, so tied scores would
        # all report the same position. sorted() is stable, so ties keep
        # their original relative order.
        ranked = sorted(enumerate(data.scores), key=lambda pair: pair[1], reverse=True)
        res = [{"index": i, "score": "{:.10f}".format(s)} for i, s in ranked]
        return json.dumps(res, default=float_handler).encode("utf-8")

    def forward(self, data: List[Request]) -> List[Response]:
        """Score every document of every request in one model call.

        Each request is read with ``d["query"]`` and ``d["texts"]``.
        NOTE(review): that is mapping-style access with a ``"texts"`` key,
        while the ``Request`` struct declares a ``docs`` field -- confirm the
        intended payload shape.

        Returns:
            One ``Response`` per request, holding the scores of that
            request's documents in their original order.
        """
        sentence_pairs = []
        inputs_lens = []
        for d in data:
            texts = d["texts"]
            inputs_lens.append(len(texts))
            sentence_pairs.extend([d["query"], doc] for doc in texts)

        # Nothing to score (no requests, or every request has no documents):
        # skip the model call, which cannot handle an empty pair list.
        if not sentence_pairs:
            return [Response(scores=[]) for _ in data]

        scores = self.model.predict(sentence_pairs)
        scores = scores.tolist()

        # Split the flat score list back into per-request chunks.
        resp = []
        cur_idx = 0
        for lens in inputs_lens:
            resp.append(Response(scores=scores[cur_idx : cur_idx + lens]))
            cur_idx += lens

        return resp
| 165 | + |
| 166 | + |
if __name__ == "__main__":
    # Batching settings come from the environment, with built-in fallbacks.
    MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", 128))
    MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", 10))

    # Single-stage server: register the reranker worker and start serving.
    server = Server()
    server.append_worker(
        MosecReranker,
        max_batch_size=MAX_BATCH_SIZE,
        max_wait_time=MAX_WAIT_TIME,
    )
    server.run()
0 commit comments