diff --git a/MultimodalQnA/ui/gradio/conversation.py b/MultimodalQnA/ui/gradio/conversation.py index 678f7872c2..36e0754953 100644 --- a/MultimodalQnA/ui/gradio/conversation.py +++ b/MultimodalQnA/ui/gradio/conversation.py @@ -2,11 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses + from enum import Enum, auto -from typing import Dict, List +from pathlib import Path +from typing import Dict, List, Any, Literal -from PIL import Image -from utils import convert_audio_to_base64, get_b64_frame_from_timestamp +from utils import convert_audio_to_base64, get_b64_frame_from_timestamp, GRADIO_IMAGE_FORMATS, GRADIO_AUDIO_FORMATS class SeparatorStyle(Enum): @@ -21,8 +22,7 @@ class Conversation: system: str roles: List[str] - messages: List[List[str]] - image_query_files: Dict[int, str] + chatbot_history: List[Dict[str, Any]] offset: int sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "\n" @@ -42,66 +42,44 @@ def _template_caption(self): out = f"The caption associated with the image is '{self.caption}'. " return out - def get_prompt(self): - messages = self.messages - if len(messages) > 1 and messages[1][1] is None: - # Need to do RAG. If the query is text, prompt is the query only - if self.audio_query_file: - ret = [{"role": "user", "content": [{"type": "audio", "audio": self.get_b64_audio_query()}]}] - elif 0 in self.image_query_files: - b64_image = get_b64_frame_from_timestamp(self.image_query_files[0], 0) - ret = [ - { - "role": "user", - "content": [ - {"type": "text", "text": messages[0][1]}, - {"type": "image_url", "image_url": {"url": b64_image}}, - ], - } - ] - else: - ret = messages[0][1] - else: - # No need to do RAG. 
def get_prompt(self, is_very_first_query):
    """Build the chat-completion style message list sent to the backend.

    Walks ``self.chatbot_history`` in order and merges consecutive records
    from the same role into one message whose ``content`` is a list of typed
    parts (``text``, ``image_url`` or ``audio``).

    Args:
        is_very_first_query: Kept for the caller's signature; this
            implementation does not read it.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dicts. The list always
        begins with a user message, even when the history is empty.
    """
    messages = [{"role": "user", "content": []}]
    # True until the first assistant record is seen, i.e. while the records
    # still belong to the opening user turn.
    in_first_user_turn = True
    is_image_query = False
    retrieved_image_attached = False

    for record in self.chatbot_history:
        role = record["role"]
        content = record["content"]

        if role == "assistant":
            in_first_user_turn = False
        # Open a new message whenever the speaker changes; records with any
        # other role are folded into the current message.
        if role in ("user", "assistant") and messages[-1]["role"] != role:
            messages.append({"role": role, "content": []})

        parts = messages[-1]["content"]

        if isinstance(content, str):
            # Text in the opening user turn carries the caption template. The
            # joining space only has an effect on a first image-only query,
            # whose text is empty.
            if in_first_user_turn:
                content += " " + self._template_caption()
            parts.append({"type": "text", "text": content})
        elif isinstance(content, dict) and "path" in content:
            media_path = content["path"]
            # Compare case-insensitively so uploads such as "photo.PNG" or
            # "clip.WAV" are not silently left out of the prompt.
            extension = Path(media_path).suffix.lower()
            if extension in GRADIO_IMAGE_FORMATS:
                is_image_query = True
                parts.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": get_b64_frame_from_timestamp(media_path, 0)},
                    }
                )
            elif extension in GRADIO_AUDIO_FORMATS:
                parts.append({"type": "audio", "audio": convert_audio_to_base64(media_path)})

        # Include the image from the assistant's response in the opening user
        # turn, unless the user's own query already supplied an image.
        # NOTE(review): the collapsed source does not show whether this check
        # sat inside the loop; it is kept per-record (as the previous
        # implementation attached the frame to the first message) and guarded
        # so the image is attached at most once -- confirm against the file.
        if in_first_user_turn and not is_image_query and not retrieved_image_attached and self.image:
            parts.append(
                {
                    "type": "image_url",
                    "image_url": {"url": get_b64_frame_from_timestamp(self.image, 0)},
                }
            )
            retrieved_image_attached = True

    return messages
self.image_query_files: - import base64 - from io import BytesIO - - image = Image.open(self.image_query_files[i]) - max_hw, min_hw = max(image.size), min(image.size) - aspect_ratio = max_hw / min_hw - max_len, min_len = 800, 400 - shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) - longest_edge = int(shortest_edge * aspect_ratio) - W, H = image.size - if H > W: - H, W = longest_edge, shortest_edge - else: - H, W = shortest_edge, longest_edge - image = image.resize((W, H)) - buffered = BytesIO() - if image.format not in ["JPEG", "JPG"]: - image = image.convert("RGB") - image.save(buffered, format="JPEG") - img_b64_str = base64.b64encode(buffered.getvalue()).decode() - img_str = f'user upload image' - msg = img_str + msg.replace("", "").strip() - ret.append([msg, None]) - - else: - ret.append([msg, None]) - else: - ret[-1][-1] = msg - return ret - + return self.chatbot_history + def copy(self): return Conversation( system=self.system, roles=self.roles, - messages=[[x, y] for x, y in self.messages], - image_query_files=self.image_query_files, + chatbot_history=self.chatbot_history, offset=self.offset, sep_style=self.sep_style, sep=self.sep, @@ -192,7 +115,7 @@ def dict(self): return { "system": self.system, "roles": self.roles, - "messages": self.messages, + "chatbot_history": self.chatbot_history, "offset": self.offset, "sep": self.sep, "time_of_frame_ms": self.time_of_frame_ms, @@ -209,8 +132,7 @@ def dict(self): multimodalqna_conv = Conversation( system="", roles=("user", "assistant"), - messages=(), - image_query_files={}, + chatbot_history=[], offset=0, sep_style=SeparatorStyle.SINGLE, sep="\n", diff --git a/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py b/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py index cb992dd990..1b5b00b5ca 100644 --- a/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py +++ b/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import glob import 
os import shutil import time @@ -14,7 +15,16 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from gradio_pdf import PDF -from utils import build_logger, make_temp_image, server_error_msg, split_video +from utils import ( + build_logger, + convert_base64_to_audio, + GRADIO_AUDIO_FORMATS, + GRADIO_IMAGE_FORMATS, + make_temp_image, + server_error_msg, + split_video, + TMP_DIR +) IMAGE_FORMATS = ['.png', '.gif', '.jpg', '.jpeg'] AUDIO_FORMATS = ['.wav', '.mp3'] @@ -59,79 +69,99 @@ def clear_history(state, request: gr.Request): if state.pdf and os.path.exists(state.pdf): os.remove(state.pdf) state = multimodalqna_conv.copy() - video = gr.Video(height=512, width=512, elem_id="video", visible=True, label="Media") - image = gr.Image(height=512, width=512, elem_id="image", visible=False, label="Media") - pdf = PDF(height=512, elem_id="pdf", interactive=False, visible=False, label="Media") - return (state, state.to_gradio_chatbot(), {"text": "", "files": []}, None, video, image, pdf) + (disable_btn,) * 1 - - -def add_text(state, textbox, audio, request: gr.Request): - text = textbox["text"] - logger.info(f"add_text. ip: {request.client.host}. 
def add_text(state, multimodal_textbox, request: gr.Request):
    """Record the user's multimodal submission in the conversation state.

    Args:
        state: Conversation state; its ``chatbot_history``, ``skip_next``,
            ``image_query_file`` and ``audio_query_file`` are updated here.
        multimodal_textbox: Value of the multimodal textbox, a mapping with
            a ``"text"`` entry and a ``"files"`` list of uploaded paths.
        request: The Gradio request, used only for logging the client host.

    Returns:
        A tuple ``(state, chatbot history, textbox update, button update)``.
        When nothing was submitted, ``state.skip_next`` is set and the
        textbox and button are left unchanged.
    """
    # Tolerate a missing/None text or file list instead of failing on .strip().
    text = (multimodal_textbox.get("text") or "").strip()
    files = multimodal_textbox.get("files") or []

    if not text and not files:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 1

    text = text[:2000]  # Hard cut-off
    state.skip_next = False

    image_file, audio_file = None, None
    if files:
        # os.path is used because `os` is imported by this module; the
        # extension is lower-cased so "IMG.PNG" / "voice.WAV" are recognised.
        first_extension = os.path.splitext(files[0])[1].lower()
        if first_extension in GRADIO_IMAGE_FORMATS:
            image_file = files[0]
        if first_extension in GRADIO_AUDIO_FORMATS or len(files) > 1:
            # Per the original author's note, the last file is guaranteed to
            # be the recorded audio when several files are submitted.
            audio_file = files[-1]

    # Add to chatbot history: media records first, then the (possibly empty)
    # text record, all under the user role.
    user_role = state.roles[0]
    if image_file:
        state.image_query_file = image_file
        state.chatbot_history.append({"role": user_role, "content": {"path": image_file}})
    if audio_file:
        state.audio_query_file = audio_file
        state.chatbot_history.append({"role": user_role, "content": {"path": audio_file}})
    state.chatbot_history.append({"role": user_role, "content": text})

    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    return (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(value=None),
    ) + (disable_btn,) * 1
(disable_btn,) * 1 + if logflag: logger.info(f"==== request ====\n{pload}") logger.info(f"==== url request ====\n{gateway_addr}") - state.messages[-1][-1] = "▌" - - yield (state, state.to_gradio_chatbot(), state.split_video, state.image, state.pdf) + (disable_btn,) * 1 - try: response = requests.post( url, @@ -148,9 +178,15 @@ def http_bot(state, request: gr.Request): choice = response["choices"][-1] metadata = choice["metadata"] message = choice["message"]["content"] + audio_response = None + if audio_response_toggler: + if choice["message"]["audio"]: + audio_response = choice["message"]["audio"]["data"] + if ( is_very_first_query and not state.video_file + and metadata and "source_video" in metadata and not state.time_of_frame_ms and "time_of_frame_ms" in metadata @@ -183,32 +219,34 @@ def http_bot(state, request: gr.Request): print(f"pdf {state.video_file} does not exist in UI host!") output_pdf_path = None state.pdf = output_pdf_path - else: raise requests.exceptions.RequestException + except requests.exceptions.RequestException as e: - state.messages[-1][-1] = server_error_msg + if logflag: + logger.info(f"Request Exception occured:\n{str(e)}") + + gr.Error("Request exception occurred. 
See logs for details.") + yield (state, state.to_gradio_chatbot(), None, None, None) + (enable_btn,) return - - state.messages[-1][-1] = message - - if is_audio_query: - state.messages[-2][-1] = metadata.get("audio", "--transcribed audio not available--") - state.audio_query_file = None - + + if audio_response: + state.chatbot_history[-1]["content"] = {'path': convert_base64_to_audio(audio_response)} + else: + state.chatbot_history[-1]["content"] = message + yield ( state, state.to_gradio_chatbot(), gr.Video(state.split_video, visible=state.split_video is not None), gr.Image(state.image, visible=state.image is not None), - PDF(state.pdf, visible=state.pdf is not None, interactive=False, starting_page=int(state.time_of_frame_ms)), + PDF(state.pdf, visible=state.pdf is not None, interactive=False, starting_page=int(state.time_of_frame_ms) if state.time_of_frame_ms else 0), ) + (enable_btn,) * 1 - logger.info(f"{state.messages[-1][-1]}") + logger.info(f"{state.chatbot_history[-1]['content']}") return - def ingest_gen_transcript(filepath, filetype, request: gr.Request): yield ( gr.Textbox(visible=True, value=f"Please wait while your uploaded {filetype} is ingested into the database...") @@ -636,30 +674,24 @@ def verify_audio_caption_type(file, request: gr.Request): with gr.Blocks() as qna: state = gr.State(multimodalqna_conv.copy()) - with gr.Row(): + with gr.Row(equal_height=True): with gr.Column(scale=2): - video = gr.Video(height=512, width=512, elem_id="video", visible=True, label="Media") - image = gr.Image(height=512, width=512, elem_id="image", visible=False, label="Media") - pdf = PDF(height=512, elem_id="pdf", interactive=False, visible=False, label="Media") + video = gr.Video(elem_id="video", visible=True, label="Media") + image = gr.Image(elem_id="image", visible=False, label="Media") + pdf = PDF(elem_id="pdf", interactive=False, visible=False, label="Media") with gr.Column(scale=9): - chatbot = gr.Chatbot(elem_id="chatbot", label="MultimodalQnA Chatbot", 
height=390) - with gr.Row(): + chatbot = gr.Chatbot(elem_id="chatbot", label="MultimodalQnA Chatbot", type="messages") + with gr.Row(equal_height=True): with gr.Column(scale=8): - with gr.Tabs(): - with gr.TabItem("Text & Image Query"): - textbox = gr.MultimodalTextbox( - show_label=False, container=True, submit_btn=False, file_types=["image"] - ) - with gr.TabItem("Audio Query"): - audio = gr.Audio( - type="filepath", - sources=["microphone", "upload"], - show_label=False, - container=False, - ) - with gr.Column(scale=1, min_width=100): + multimodal_textbox = gr.MultimodalTextbox( + show_label=False, + file_types=GRADIO_IMAGE_FORMATS + GRADIO_AUDIO_FORMATS, + sources=["microphone", "upload"], + placeholder="Text, Image & Audio Query" + ) + with gr.Column(scale=1, min_width=150): with gr.Row(): - submit_btn = gr.Button(value="Send", variant="primary", interactive=True) + audio_response_toggler = gr.Checkbox(label="Audio Responses", container=False) with gr.Row(elem_id="buttons") as button_row: clear_btn = gr.Button(value="🗑️ Clear", interactive=False) @@ -668,19 +700,21 @@ def verify_audio_caption_type(file, request: gr.Request): [ state, ], - [state, chatbot, textbox, audio, video, image, pdf, clear_btn], + [state, chatbot, multimodal_textbox, video, image, pdf, clear_btn], ) - submit_btn.click( + multimodal_textbox.submit( add_text, - [state, textbox, audio], - [state, chatbot, textbox, audio, clear_btn], + [state, multimodal_textbox], + [state, chatbot, multimodal_textbox, clear_btn] ).then( http_bot, - [ - state, - ], - [state, chatbot, video, image, pdf, clear_btn], + [state, audio_response_toggler], + [state, chatbot, video, image, pdf, clear_btn] + ).then( + lambda: gr.MultimodalTextbox(interactive=True), + None, + [multimodal_textbox] ) with gr.Blocks() as vector_store: diff --git a/MultimodalQnA/ui/gradio/requirements.txt b/MultimodalQnA/ui/gradio/requirements.txt index 7c9814d696..80fc5a0dcc 100644 --- a/MultimodalQnA/ui/gradio/requirements.txt +++ 
def convert_audio_to_base64(audio_path):
    """Return the contents of the file at *audio_path* as a base64 string.

    The file is read in binary mode and closed before returning; the encoded
    bytes are decoded to a UTF-8 ``str``.
    """
    with open(audio_path, "rb") as audio_file:
        return base64.b64encode(audio_file.read()).decode("utf-8")


def convert_base64_to_audio(b64_str, tmp_dir=None, suffix=".wav"):
    """Decode base64 audio data, save it to a temporary file, return its path.

    Args:
        b64_str: Base64-encoded audio data.
        tmp_dir: Directory for the file; defaults to the module's ``TMP_DIR``.
        suffix: File-name suffix for the created file; defaults to ``".wav"``.

    Returns:
        Path of the newly created file. The file is not deleted on close, so
        the caller is responsible for removing it.
    """
    # Decode first so that malformed input does not leave an empty file behind.
    audio_data = base64.b64decode(b64_str)
    target_dir = TMP_DIR if tmp_dir is None else tmp_dir

    with tempfile.NamedTemporaryFile(dir=target_dir, delete=False, suffix=suffix) as temp_audio:
        temp_audio.write(audio_data)
        return temp_audio.name