diff --git a/MultimodalQnA/ui/gradio/conversation.py b/MultimodalQnA/ui/gradio/conversation.py
index 678f7872c2..36e0754953 100644
--- a/MultimodalQnA/ui/gradio/conversation.py
+++ b/MultimodalQnA/ui/gradio/conversation.py
@@ -2,11 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
+
from enum import Enum, auto
-from typing import Dict, List
+from pathlib import Path
+from typing import Dict, List, Any, Literal
-from PIL import Image
-from utils import convert_audio_to_base64, get_b64_frame_from_timestamp
+from utils import convert_audio_to_base64, get_b64_frame_from_timestamp, GRADIO_IMAGE_FORMATS, GRADIO_AUDIO_FORMATS
class SeparatorStyle(Enum):
@@ -21,8 +22,7 @@ class Conversation:
system: str
roles: List[str]
- messages: List[List[str]]
- image_query_files: Dict[int, str]
+ chatbot_history: List[Dict[str, Any]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "\n"
@@ -42,66 +42,44 @@ def _template_caption(self):
out = f"The caption associated with the image is '{self.caption}'. "
return out
- def get_prompt(self):
- messages = self.messages
- if len(messages) > 1 and messages[1][1] is None:
- # Need to do RAG. If the query is text, prompt is the query only
- if self.audio_query_file:
- ret = [{"role": "user", "content": [{"type": "audio", "audio": self.get_b64_audio_query()}]}]
- elif 0 in self.image_query_files:
- b64_image = get_b64_frame_from_timestamp(self.image_query_files[0], 0)
- ret = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": messages[0][1]},
- {"type": "image_url", "image_url": {"url": b64_image}},
- ],
- }
- ]
- else:
- ret = messages[0][1]
- else:
- # No need to do RAG. Thus, prompt of chatcompletion format
- conv_dict = []
- if self.sep_style == SeparatorStyle.SINGLE:
- for i, (role, message) in enumerate(messages):
- if message:
- dic = {"role": role}
- content = [{"type": "text", "text": message}]
- # There might be audio
- if self.audio_query_file:
- content.append({"type": "audio", "audio": self.get_b64_audio_query()})
- # There might be a returned item from the first query
- if i == 0 and self.time_of_frame_ms and self.video_file:
- base64_frame = (
- self.base64_frame
- if self.base64_frame
- else get_b64_frame_from_timestamp(self.video_file, self.time_of_frame_ms)
- )
- if base64_frame is None:
- base64_frame = ""
- # Include the original caption for the returned image/video
- if self.caption and content[0]["type"] == "text":
- content[0]["text"] = content[0]["text"] + " " + self._template_caption()
- content.append({"type": "image_url", "image_url": {"url": base64_frame}})
- # There might be a query image
- if i in self.image_query_files:
- content.append(
- {
- "type": "image_url",
- "image_url": {"url": get_b64_frame_from_timestamp(self.image_query_files[i], 0)},
- }
- )
- dic["content"] = content
- conv_dict.append(dic)
- else:
- raise ValueError(f"Invalid style: {self.sep_style}")
- ret = conv_dict
- return ret
-
- def append_message(self, role, message):
- self.messages.append([role, message])
+ def get_prompt(self, is_very_first_query):
+ conv_dict = [{'role': 'user', 'content': []}]
+ caption_flag = True
+ is_image_query = False
+
+ for record in self.chatbot_history:
+ role = record['role']
+ content = record['content']
+
+ if role == 'user':
+ # Check if last entry of conv_dict has role user
+ if conv_dict[-1]['role'] != 'user':
+ conv_dict.append({'role': 'user', 'content': []})
+ elif role == 'assistant':
+ caption_flag = False
+ # Check if last entry of conv_dict has role assistant
+ if conv_dict[-1]['role'] != 'assistant':
+ conv_dict.append({'role': 'assistant', 'content': []})
+
+ # Add content to the last conv_dict record. The single space has only effect on first image-only
+ # query for the similarity search results to get expected response.
+ if isinstance(content, str):
+ if caption_flag:
+ content += " " + self._template_caption()
+ conv_dict[-1]['content'].append({'type': 'text', 'text': content})
+
+ if isinstance(content, dict) and 'path' in content:
+ if Path(content['path']).suffix in GRADIO_IMAGE_FORMATS:
+ is_image_query = True
+ conv_dict[-1]['content'].append({'type': 'image_url', 'image_url': {'url': get_b64_frame_from_timestamp(content['path'], 0)}})
+ if Path(content['path']).suffix in GRADIO_AUDIO_FORMATS:
+ conv_dict[-1]['content'].append({'type': 'audio', 'audio': convert_audio_to_base64(content['path'])})
+
+ # include the image from the assistant's response given the user's is not a image query
+ if not is_image_query and caption_flag and self.image:
+ conv_dict[-1]['content'].append({'type': 'image_url', 'image_url': {'url': get_b64_frame_from_timestamp(self.image, 0)}})
+
+ return conv_dict
def get_b64_image(self):
b64_img = None
@@ -118,68 +96,13 @@ def get_b64_audio_query(self):
return b64_audio
def to_gradio_chatbot(self):
- ret = []
- for i, (role, msg) in enumerate(self.messages[self.offset :]):
- if i % 2 == 0:
- if type(msg) is tuple:
- import base64
- from io import BytesIO
-
- msg, image, image_process_mode = msg
- max_hw, min_hw = max(image.size), min(image.size)
- aspect_ratio = max_hw / min_hw
- max_len, min_len = 800, 400
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
- longest_edge = int(shortest_edge * aspect_ratio)
- W, H = image.size
- if H > W:
- H, W = longest_edge, shortest_edge
- else:
- H, W = shortest_edge, longest_edge
- image = image.resize((W, H))
- buffered = BytesIO()
- image.save(buffered, format="JPEG")
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
- img_str = f'
'
- msg = img_str + msg.replace("", "").strip()
- ret.append([msg, None])
- elif i in self.image_query_files:
- import base64
- from io import BytesIO
-
- image = Image.open(self.image_query_files[i])
- max_hw, min_hw = max(image.size), min(image.size)
- aspect_ratio = max_hw / min_hw
- max_len, min_len = 800, 400
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
- longest_edge = int(shortest_edge * aspect_ratio)
- W, H = image.size
- if H > W:
- H, W = longest_edge, shortest_edge
- else:
- H, W = shortest_edge, longest_edge
- image = image.resize((W, H))
- buffered = BytesIO()
- if image.format not in ["JPEG", "JPG"]:
- image = image.convert("RGB")
- image.save(buffered, format="JPEG")
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
- img_str = f'
'
- msg = img_str + msg.replace("", "").strip()
- ret.append([msg, None])
-
- else:
- ret.append([msg, None])
- else:
- ret[-1][-1] = msg
- return ret
-
+ return self.chatbot_history
+
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
- messages=[[x, y] for x, y in self.messages],
- image_query_files=self.image_query_files,
+ chatbot_history=self.chatbot_history,
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
@@ -192,7 +115,7 @@ def dict(self):
return {
"system": self.system,
"roles": self.roles,
- "messages": self.messages,
+ "chatbot_history": self.chatbot_history,
"offset": self.offset,
"sep": self.sep,
"time_of_frame_ms": self.time_of_frame_ms,
@@ -209,8 +132,7 @@ def dict(self):
multimodalqna_conv = Conversation(
system="",
roles=("user", "assistant"),
- messages=(),
- image_query_files={},
+ chatbot_history=[],
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="\n",
diff --git a/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py b/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py
index cb992dd990..1b5b00b5ca 100644
--- a/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py
+++ b/MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
+import glob
import os
import shutil
import time
@@ -14,7 +15,16 @@
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from gradio_pdf import PDF
-from utils import build_logger, make_temp_image, server_error_msg, split_video
+from utils import (
+ build_logger,
+ convert_base64_to_audio,
+ GRADIO_AUDIO_FORMATS,
+ GRADIO_IMAGE_FORMATS,
+ make_temp_image,
+ server_error_msg,
+ split_video,
+ TMP_DIR
+)
IMAGE_FORMATS = ['.png', '.gif', '.jpg', '.jpeg']
AUDIO_FORMATS = ['.wav', '.mp3']
@@ -59,79 +69,99 @@ def clear_history(state, request: gr.Request):
if state.pdf and os.path.exists(state.pdf):
os.remove(state.pdf)
state = multimodalqna_conv.copy()
- video = gr.Video(height=512, width=512, elem_id="video", visible=True, label="Media")
- image = gr.Image(height=512, width=512, elem_id="image", visible=False, label="Media")
- pdf = PDF(height=512, elem_id="pdf", interactive=False, visible=False, label="Media")
- return (state, state.to_gradio_chatbot(), {"text": "", "files": []}, None, video, image, pdf) + (disable_btn,) * 1
-
-
-def add_text(state, textbox, audio, request: gr.Request):
- text = textbox["text"]
- logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
- if audio:
- state.audio_query_file = audio
- state.append_message(state.roles[0], "--input placeholder--")
- state.append_message(state.roles[1], None)
- state.skip_next = False
- return (state, state.to_gradio_chatbot(), None, None) + (disable_btn,) * 1
- # If it is a image query
- elif textbox["files"]:
- image_file = textbox["files"][0]
- state.image_query_files[len(state.messages)] = image_file
- state.append_message(state.roles[0], text)
- state.append_message(state.roles[1], None)
- state.skip_next = False
- return (state, state.to_gradio_chatbot(), None, None) + (disable_btn,) * 1
- elif len(text) <= 0:
+ state.chatbot_history = []
+ for file in glob.glob(os.path.join(TMP_DIR, "*.wav")):
+ os.remove(file) # This removes all chatbot assistant's voice response files
+ video = gr.Video(value=None, elem_id="video", visible=True, label="Media")
+ image = gr.Image(value=None, elem_id="image", visible=False, label="Media")
+ pdf = PDF(value=None, elem_id="pdf", interactive=False, visible=False, label="Media")
+ return (state, state.to_gradio_chatbot(), None, video, image, pdf) + (disable_btn,) * 1
+
+
+def add_text(state, multimodal_textbox, request: gr.Request):
+ text = multimodal_textbox["text"]
+ files = multimodal_textbox["files"]
+
+ image_file, audio_file = None, None
+
+ text = text.strip()
+
+ if not text and not files:
state.skip_next = True
- return (state, state.to_gradio_chatbot(), None, None) + (no_change_btn,) * 1
+ return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 1
text = text[:2000] # Hard cut-off
- state.append_message(state.roles[0], text)
- state.append_message(state.roles[1], None)
state.skip_next = False
-
- return (state, state.to_gradio_chatbot(), None, None) + (disable_btn,) * 1
+
+ if files:
+ if Path(files[0]).suffix in GRADIO_IMAGE_FORMATS:
+ image_file = files[0]
+ if Path(files[0]).suffix in GRADIO_AUDIO_FORMATS or len(files) > 1:
+ audio_file = files[-1] # Guaranteed that last file would be recorded audio
+
+ # Add to chatbot history
+ if image_file:
+ state.image_query_file = image_file
+ state.chatbot_history.append({
+ "role": state.roles[0],
+ "content": {"path": image_file}
+ })
+ if audio_file:
+ state.audio_query_file = audio_file
+ state.chatbot_history.append({
+ "role": state.roles[0],
+ "content": {"path": audio_file}
+ })
+
+ state.chatbot_history.append({
+ "role": state.roles[0],
+ "content": text
+ })
+
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+
+ return (
+ state,
+ state.to_gradio_chatbot(),
+ gr.MultimodalTextbox(value=None)
+ ) + (disable_btn,) * 1
-def http_bot(state, request: gr.Request):
+def http_bot(state, audio_response_toggler, request: gr.Request):
global gateway_addr
logger.info(f"http_bot. ip: {request.client.host}")
url = gateway_addr
- is_very_first_query = False
- is_audio_query = state.audio_query_file is not None
+
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot(), None, None, None) + (no_change_btn,) * 1
return
-
- if len(state.messages) == state.offset + 2:
- # First round of conversation
- is_very_first_query = True
- new_state = multimodalqna_conv.copy()
- new_state.append_message(new_state.roles[0], state.messages[-2][1])
- new_state.append_message(new_state.roles[1], None)
- new_state.audio_query_file = state.audio_query_file
- new_state.image_query_files = state.image_query_files
- state = new_state
-
+
+ is_very_first_query = all(True if h['role'] == 'user' else False for h in state.chatbot_history)
+
# Construct prompt
- prompt = state.get_prompt()
+ prompt = state.get_prompt(is_very_first_query)
+
+ modalities = ["text", "audio"] if audio_response_toggler else ["text"]
# Make requests
pload = {
"messages": prompt,
+ "modalities": modalities
}
-
+
+ state.chatbot_history.append({
+ "role": state.roles[1],
+ "content": "▌"
+ })
+
+ yield (state, state.to_gradio_chatbot(), state.split_video, state.image, state.pdf) + (disable_btn,) * 1
+
if logflag:
logger.info(f"==== request ====\n{pload}")
logger.info(f"==== url request ====\n{gateway_addr}")
- state.messages[-1][-1] = "▌"
-
- yield (state, state.to_gradio_chatbot(), state.split_video, state.image, state.pdf) + (disable_btn,) * 1
-
try:
response = requests.post(
url,
@@ -148,9 +178,15 @@ def http_bot(state, request: gr.Request):
choice = response["choices"][-1]
metadata = choice["metadata"]
message = choice["message"]["content"]
+ audio_response = None
+ if audio_response_toggler:
+ if choice["message"]["audio"]:
+ audio_response = choice["message"]["audio"]["data"]
+
if (
is_very_first_query
and not state.video_file
+ and metadata
and "source_video" in metadata
and not state.time_of_frame_ms
and "time_of_frame_ms" in metadata
@@ -183,32 +219,34 @@ def http_bot(state, request: gr.Request):
print(f"pdf {state.video_file} does not exist in UI host!")
output_pdf_path = None
state.pdf = output_pdf_path
-
else:
raise requests.exceptions.RequestException
+
except requests.exceptions.RequestException as e:
- state.messages[-1][-1] = server_error_msg
+ if logflag:
+ logger.info(f"Request Exception occured:\n{str(e)}")
+
+ gr.Error("Request exception occurred. See logs for details.")
+
yield (state, state.to_gradio_chatbot(), None, None, None) + (enable_btn,)
return
-
- state.messages[-1][-1] = message
-
- if is_audio_query:
- state.messages[-2][-1] = metadata.get("audio", "--transcribed audio not available--")
- state.audio_query_file = None
-
+
+ if audio_response:
+ state.chatbot_history[-1]["content"] = {'path': convert_base64_to_audio(audio_response)}
+ else:
+ state.chatbot_history[-1]["content"] = message
+
yield (
state,
state.to_gradio_chatbot(),
gr.Video(state.split_video, visible=state.split_video is not None),
gr.Image(state.image, visible=state.image is not None),
- PDF(state.pdf, visible=state.pdf is not None, interactive=False, starting_page=int(state.time_of_frame_ms)),
+ PDF(state.pdf, visible=state.pdf is not None, interactive=False, starting_page=int(state.time_of_frame_ms) if state.time_of_frame_ms else 0),
) + (enable_btn,) * 1
- logger.info(f"{state.messages[-1][-1]}")
+ logger.info(f"{state.chatbot_history[-1]['content']}")
return
-
def ingest_gen_transcript(filepath, filetype, request: gr.Request):
yield (
gr.Textbox(visible=True, value=f"Please wait while your uploaded {filetype} is ingested into the database...")
@@ -636,30 +674,24 @@ def verify_audio_caption_type(file, request: gr.Request):
with gr.Blocks() as qna:
state = gr.State(multimodalqna_conv.copy())
- with gr.Row():
+ with gr.Row(equal_height=True):
with gr.Column(scale=2):
- video = gr.Video(height=512, width=512, elem_id="video", visible=True, label="Media")
- image = gr.Image(height=512, width=512, elem_id="image", visible=False, label="Media")
- pdf = PDF(height=512, elem_id="pdf", interactive=False, visible=False, label="Media")
+ video = gr.Video(elem_id="video", visible=True, label="Media")
+ image = gr.Image(elem_id="image", visible=False, label="Media")
+ pdf = PDF(elem_id="pdf", interactive=False, visible=False, label="Media")
with gr.Column(scale=9):
- chatbot = gr.Chatbot(elem_id="chatbot", label="MultimodalQnA Chatbot", height=390)
- with gr.Row():
+ chatbot = gr.Chatbot(elem_id="chatbot", label="MultimodalQnA Chatbot", type="messages")
+ with gr.Row(equal_height=True):
with gr.Column(scale=8):
- with gr.Tabs():
- with gr.TabItem("Text & Image Query"):
- textbox = gr.MultimodalTextbox(
- show_label=False, container=True, submit_btn=False, file_types=["image"]
- )
- with gr.TabItem("Audio Query"):
- audio = gr.Audio(
- type="filepath",
- sources=["microphone", "upload"],
- show_label=False,
- container=False,
- )
- with gr.Column(scale=1, min_width=100):
+ multimodal_textbox = gr.MultimodalTextbox(
+ show_label=False,
+ file_types=GRADIO_IMAGE_FORMATS + GRADIO_AUDIO_FORMATS,
+ sources=["microphone", "upload"],
+ placeholder="Text, Image & Audio Query"
+ )
+ with gr.Column(scale=1, min_width=150):
with gr.Row():
- submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
+ audio_response_toggler = gr.Checkbox(label="Audio Responses", container=False)
with gr.Row(elem_id="buttons") as button_row:
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
@@ -668,19 +700,21 @@ def verify_audio_caption_type(file, request: gr.Request):
[
state,
],
- [state, chatbot, textbox, audio, video, image, pdf, clear_btn],
+ [state, chatbot, multimodal_textbox, video, image, pdf, clear_btn],
)
- submit_btn.click(
+ multimodal_textbox.submit(
add_text,
- [state, textbox, audio],
- [state, chatbot, textbox, audio, clear_btn],
+ [state, multimodal_textbox],
+ [state, chatbot, multimodal_textbox, clear_btn]
).then(
http_bot,
- [
- state,
- ],
- [state, chatbot, video, image, pdf, clear_btn],
+ [state, audio_response_toggler],
+ [state, chatbot, video, image, pdf, clear_btn]
+ ).then(
+ lambda: gr.MultimodalTextbox(interactive=True),
+ None,
+ [multimodal_textbox]
)
with gr.Blocks() as vector_store:
diff --git a/MultimodalQnA/ui/gradio/requirements.txt b/MultimodalQnA/ui/gradio/requirements.txt
index 7c9814d696..80fc5a0dcc 100644
--- a/MultimodalQnA/ui/gradio/requirements.txt
+++ b/MultimodalQnA/ui/gradio/requirements.txt
@@ -1,4 +1,4 @@
-gradio==5.11.0
+gradio==5.17.1
gradio_pdf==0.0.20
moviepy==1.0.3
numpy==1.26.4
diff --git a/MultimodalQnA/ui/gradio/utils.py b/MultimodalQnA/ui/gradio/utils.py
index c22d102a5a..6e4e5ba764 100644
--- a/MultimodalQnA/ui/gradio/utils.py
+++ b/MultimodalQnA/ui/gradio/utils.py
@@ -7,16 +7,21 @@
import os
import shutil
import sys
+import tempfile
from pathlib import Path
import cv2
from moviepy.video.io.VideoFileClip import VideoFileClip
LOGDIR = "."
+TMP_DIR = "/tmp"
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+GRADIO_IMAGE_FORMATS = [".jpeg", ".png", ".jpg", ".gif"]
+GRADIO_AUDIO_FORMATS = [".wav", ".mp3",]
+
handler = None
save_log = False
@@ -186,3 +191,15 @@ def convert_audio_to_base64(audio_path):
"""Convert .wav file to base64 string."""
encoded_string = base64.b64encode(open(audio_path, "rb").read())
return encoded_string.decode("utf-8")
+
+def convert_base64_to_audio(b64_str):
+ """Decodes the base64 encoded audio data and returns a saved filepath."""
+
+ audio_data = base64.b64decode(b64_str)
+
+ # Create a temporary file
+ with tempfile.NamedTemporaryFile(dir=TMP_DIR, delete=False, suffix=".wav") as temp_audio:
+ temp_audio.write(audio_data)
+ temp_audio_path = temp_audio.name # Store the path
+
+ return temp_audio_path