diff --git a/bin/chat-chainlit.py b/bin/chat-chainlit.py
index fa4faf6..60e0f95 100644
--- a/bin/chat-chainlit.py
+++ b/bin/chat-chainlit.py
@@ -20,7 +20,8 @@
 config: Config | None = Config.from_yaml()
 profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
 
-llm_graph = AgentGraph(profiles)
+models_config = config.models if config else None
+llm_graph = AgentGraph(profiles, models_config=models_config)
 
 POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB")
 POSTGRES_USER = os.getenv("POSTGRES_USER")
diff --git a/config_default.yml b/config_default.yml
index e53055a..e2b8eda 100644
--- a/config_default.yml
+++ b/config_default.yml
@@ -1,5 +1,13 @@
 # yaml-language-server: $schema=./.config.schema.yaml
 
+models:
+  llm:
+    provider: openai
+    model: gpt-4o-mini
+  embedding:
+    provider: openai
+    model: text-embedding-3-large
+
 profiles:
   - React-to-Me
 
diff --git a/src/agent/graph.py b/src/agent/graph.py
index 012df27..b972e4c 100644
--- a/src/agent/graph.py
+++ b/src/agent/graph.py
@@ -16,6 +16,7 @@
 from agent.models import get_embedding, get_llm
 from agent.profiles import ProfileName, create_profile_graphs
 from agent.profiles.base import InputState, OutputState
+from util.config_yml.models import EmbeddingConfig, LLMConfig, ModelsConfig
 from util.logging import logging
 
 LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
@@ -28,10 +29,21 @@ class AgentGraph:
     def __init__(
         self,
         profiles: list[ProfileName],
+        models_config: ModelsConfig | None = None,
     ) -> None:
-        # Get base models
-        llm: BaseChatModel = get_llm("openai", "gpt-4o-mini")
-        embedding: Embeddings = get_embedding("openai", "text-embedding-3-large")
+        # Get base models from config (with backward-compatible defaults)
+        if models_config is None:
+            models_config = ModelsConfig()
+
+        llm_cfg: LLMConfig = models_config.llm
+        emb_cfg: EmbeddingConfig = models_config.embedding
+
+        llm: BaseChatModel = get_llm(
+            llm_cfg.provider, llm_cfg.model, base_url=llm_cfg.base_url
+        )
+        embedding: Embeddings = get_embedding(
+            emb_cfg.provider, emb_cfg.model, device=emb_cfg.device
+        )
 
         self.uncompiled_graph: dict[str, StateGraph] = create_profile_graphs(
             profiles, llm, embedding
diff --git a/src/util/config_yml/__init__.py b/src/util/config_yml/__init__.py
index e6d57e9..fa460c2 100644
--- a/src/util/config_yml/__init__.py
+++ b/src/util/config_yml/__init__.py
@@ -2,11 +2,12 @@
 from typing import Self
 
 import yaml
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, ConfigDict, ValidationError
 
 from agent.profiles import ProfileName
 from util.config_yml.features import Feature, Features
 from util.config_yml.messages import Message, TriggerEvent
+from util.config_yml.models import EmbeddingConfig, LLMConfig, ModelsConfig
 from util.config_yml.usage_limits import MessageRate, UsageLimits
 from util.config_yml.user_matching import match_user
 from util.logging import logging
@@ -16,8 +17,11 @@
 
 
 class Config(BaseModel):
+    model_config = ConfigDict(extra="ignore")
+
     features: Features
     messages: dict[str, Message]
+    models: ModelsConfig = ModelsConfig()
     profiles: list[ProfileName]
     usage_limits: UsageLimits
 
diff --git a/src/util/config_yml/models.py b/src/util/config_yml/models.py
new file mode 100644
index 0000000..8df314a
--- /dev/null
+++ b/src/util/config_yml/models.py
@@ -0,0 +1,20 @@
+from typing import Literal
+
+from pydantic import BaseModel
+
+
+class LLMConfig(BaseModel):
+    provider: Literal["openai", "ollama"] | str = "openai"
+    model: str = "gpt-4o-mini"
+    base_url: str | None = None
+
+
+class EmbeddingConfig(BaseModel):
+    provider: Literal["openai", "huggingfacehub", "huggingfacelocal"] | str = "openai"
+    model: str = "text-embedding-3-large"
+    device: str | None = "cpu"
+
+
+class ModelsConfig(BaseModel):
+    llm: LLMConfig = LLMConfig()
+    embedding: EmbeddingConfig = EmbeddingConfig()