Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/app/core/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Provider(str, Enum):
AWS = "aws"
LANGFUSE = "langfuse"
GOOGLE = "google"
SARVAMAI = "sarvamai"


@dataclass
Expand All @@ -32,6 +33,7 @@ class ProviderConfig:
required_fields=["secret_key", "public_key", "host"]
),
Provider.GOOGLE: ProviderConfig(required_fields=["api_key"]),
Provider.SARVAMAI: ProviderConfig(required_fields=["api_key"]),
}


Expand Down
13 changes: 7 additions & 6 deletions backend/app/models/llm/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TextLLMParams(SQLModel):
description="Reasoning configuration or instructions",
)
temperature: float | None = Field(
default=None,
default=0.1,
ge=0.0,
le=2.0,
)
Expand All @@ -35,17 +35,18 @@ class TextLLMParams(SQLModel):

class STTLLMParams(SQLModel):
model: str
instructions: str
instructions: str | None = None
input_language: str | None = None
output_language: str | None = None
response_format: Literal["text"] | None = Field(
None,
description="Currently supports text type",
)
temperature: float | None = Field(
default=0.2,
default=None,
ge=0.0,
le=2.0,
description="Temperature parameter (not supported by all STT providers)",
)


Expand Down Expand Up @@ -190,7 +191,7 @@ class NativeCompletionConfig(SQLModel):
Supports any LLM provider's native API format.
"""

provider: Literal["openai-native", "google-native"] = Field(
provider: Literal["openai-native", "google-native", "sarvamai-native"] = Field(
...,
description="Native provider type (e.g., openai-native)",
)
Expand All @@ -210,8 +211,8 @@ class KaapiCompletionConfig(SQLModel):
Supports multiple providers: OpenAI, Claude, Gemini, etc.
"""

provider: Literal["openai", "google"] = Field(
..., description="LLM provider (openai)"
provider: Literal["openai", "google", "sarvamai"] = Field(
..., description="LLM provider (openai, google, sarvamai)"
)

type: Literal["text", "stt", "tts"] = Field(
Expand Down
97 changes: 96 additions & 1 deletion backend/app/services/llm/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,98 @@ def map_kaapi_to_google_params(kaapi_params: dict) -> tuple[dict, list[str]]:
return google_params, warnings


def _map_sarvam_tts_params(kaapi_params: dict, model: str) -> tuple[dict, list[str]]:
    """Map TTSLLMParams fields (voice, language, response_format) to SarvamAI TTS params.

    Args:
        kaapi_params: Dictionary with standardized Kaapi parameters.
        model: Already-validated model identifier.

    Returns:
        Tuple of (SarvamAI TTS parameters, warnings). Returns ({}, [error])
        when the required 'language' parameter is missing.
    """
    language = kaapi_params.get("language")
    if not language:
        return {}, ["Missing required 'language' parameter for TTS"]

    params = {
        "model": model,
        "speaker": kaapi_params["voice"],
        "target_language_code": language,
    }

    response_format = kaapi_params.get("response_format")
    if response_format:
        # Map audio format to SarvamAI codec; unknown formats fall back to wav.
        format_mapping = {"mp3": "mp3", "wav": "wav", "ogg": "ogg"}
        params["output_audio_codec"] = format_mapping.get(response_format, "wav")

    return params, []


def _map_sarvam_stt_params(kaapi_params: dict, model: str) -> tuple[dict, list[str]]:
    """Map STTLLMParams fields (input/output language) to SarvamAI STT params.

    Selects "translate" mode when the caller asks for English (en-IN) output
    from a different input language; otherwise uses "transcribe". Emits a
    warning for each Kaapi STT parameter SarvamAI does not support.

    Args:
        kaapi_params: Dictionary with standardized Kaapi parameters.
        model: Already-validated model identifier.

    Returns:
        Tuple of (SarvamAI STT parameters, warnings).
    """
    params = {"model": model}
    warnings = []

    input_language = kaapi_params.get("input_language")
    output_language = kaapi_params.get("output_language")

    # "auto" is Kaapi's auto-detect marker; SarvamAI spells it "unknown".
    if input_language == "auto":
        params["language_code"] = "unknown"
    elif input_language:
        params["language_code"] = input_language

    # No explicit output language means "same as input" (pure transcription).
    if output_language is None:
        output_language = input_language

    mode = "transcribe"
    if output_language == "en-IN" and input_language != output_language:
        mode = "translate"
    params["mode"] = mode

    # Warn about Kaapi STT parameters SarvamAI has no equivalent for.
    if kaapi_params.get("instructions"):
        warnings.append(
            "Parameter 'instructions' is not supported by SarvamAI STT and was ignored"
        )
    if kaapi_params.get("temperature") is not None:
        warnings.append(
            "Parameter 'temperature' is not supported by SarvamAI STT and was ignored"
        )
    if kaapi_params.get("response_format"):
        warnings.append(
            "Parameter 'response_format' is not supported by SarvamAI STT and was ignored"
        )

    return params, warnings


def map_kaapi_to_sarvam_params(kaapi_params: dict) -> tuple[dict, list[str]]:
    """Map Kaapi-abstracted parameters to SarvamAI API parameters.

    Handles both STTLLMParams and TTSLLMParams.

    STTLLMParams: model, instructions, input_language, output_language, response_format, temperature
    TTSLLMParams: model, voice, language, response_format

    Dispatch is inferred from the parameters present: 'voice' marks TTS;
    'input_language' or 'output_language' marks STT. If neither marker is
    present, only the model is passed through.

    Args:
        kaapi_params: Dictionary with standardized Kaapi parameters

    Returns:
        Tuple of:
            - Dictionary of SarvamAI API parameters
            - List of warnings for unsupported parameters
    """
    # Model is required for all completion types.
    model = kaapi_params.get("model")
    if not model:
        return {}, ["Missing required 'model' parameter"]

    if kaapi_params.get("voice") is not None:
        return _map_sarvam_tts_params(kaapi_params, model)

    if (
        kaapi_params.get("input_language") is not None
        or kaapi_params.get("output_language") is not None
    ):
        return _map_sarvam_stt_params(kaapi_params, model)

    # Neither TTS nor STT markers present: pass the model through unchanged.
    return {"model": model}, []


def transform_kaapi_config_to_native(
kaapi_config: KaapiCompletionConfig,
) -> tuple[NativeCompletionConfig, list[str]]:
"""Transform Kaapi completion config to native provider config with mapped parameters.

Supports OpenAI and Google AI providers.
    Supports OpenAI, Google AI, and Sarvam AI providers.

Args:
kaapi_config: KaapiCompletionConfig with abstracted parameters
Expand Down Expand Up @@ -175,4 +261,13 @@ def transform_kaapi_config_to_native(
warnings,
)

if kaapi_config.provider == "sarvamai":
mapped_params, warnings = map_kaapi_to_sarvam_params(kaapi_config.params)
return (
NativeCompletionConfig(
provider="sarvamai-native", params=mapped_params, type=kaapi_config.type
),
warnings,
)

raise ValueError(f"Unsupported provider: {kaapi_config.provider}")
3 changes: 3 additions & 0 deletions backend/app/services/llm/providers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from app.services.llm.providers.base import BaseProvider
from app.services.llm.providers.oai import OpenAIProvider
from app.services.llm.providers.gai import GoogleAIProvider
from app.services.llm.providers.sai import SarvamAIProvider

logger = logging.getLogger(__name__)

Expand All @@ -16,13 +17,15 @@ class LLMProvider:
# Future constants for native providers:
# CLAUDE_NATIVE = "claude-native"
GOOGLE_NATIVE = "google-native"
SARVAMAI_NATIVE = "sarvamai-native"

_registry: dict[str, type[BaseProvider]] = {
OPENAI_NATIVE: OpenAIProvider,
OPENAI: OpenAIProvider,
# Future native providers:
# CLAUDE_NATIVE: ClaudeProvider,
GOOGLE_NATIVE: GoogleAIProvider,
SARVAMAI_NATIVE: SarvamAIProvider,
}

@classmethod
Expand Down
Loading