Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bin/chat-chainlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
config: Config | None = Config.from_yaml()

profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
llm_graph = AgentGraph(profiles)
models_config = config.models if config else None
llm_graph = AgentGraph(profiles, models_config=models_config)

POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB")
POSTGRES_USER = os.getenv("POSTGRES_USER")
Expand Down
8 changes: 8 additions & 0 deletions config_default.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# yaml-language-server: $schema=./.config.schema.yaml

models:
llm:
provider: openai
model: gpt-4o-mini
embedding:
provider: openai
model: text-embedding-3-large

profiles:
- React-to-Me

Expand Down
18 changes: 15 additions & 3 deletions src/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from agent.models import get_embedding, get_llm
from agent.profiles import ProfileName, create_profile_graphs
from agent.profiles.base import InputState, OutputState
from util.config_yml.models import EmbeddingConfig, LLMConfig, ModelsConfig
from util.logging import logging

LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
Expand All @@ -28,10 +29,21 @@ class AgentGraph:
def __init__(
self,
profiles: list[ProfileName],
models_config: ModelsConfig | None = None,
) -> None:
# Get base models
llm: BaseChatModel = get_llm("openai", "gpt-4o-mini")
embedding: Embeddings = get_embedding("openai", "text-embedding-3-large")
# Get base models from config (with backward-compatible defaults)
if models_config is None:
models_config = ModelsConfig()

llm_cfg: LLMConfig = models_config.llm
emb_cfg: EmbeddingConfig = models_config.embedding

llm: BaseChatModel = get_llm(
llm_cfg.provider, llm_cfg.model, base_url=llm_cfg.base_url
)
embedding: Embeddings = get_embedding(
emb_cfg.provider, emb_cfg.model, device=emb_cfg.device
)

self.uncompiled_graph: dict[str, StateGraph] = create_profile_graphs(
profiles, llm, embedding
Expand Down
6 changes: 5 additions & 1 deletion src/util/config_yml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import Self

import yaml
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, ConfigDict, ValidationError

from agent.profiles import ProfileName
from util.config_yml.features import Feature, Features
from util.config_yml.messages import Message, TriggerEvent
from util.config_yml.models import EmbeddingConfig, LLMConfig, ModelsConfig
from util.config_yml.usage_limits import MessageRate, UsageLimits
from util.config_yml.user_matching import match_user
from util.logging import logging
Expand All @@ -16,8 +17,11 @@


class Config(BaseModel):
model_config = ConfigDict(extra="ignore")

features: Features
messages: dict[str, Message]
models: ModelsConfig = ModelsConfig()
profiles: list[ProfileName]
usage_limits: UsageLimits

Expand Down
20 changes: 20 additions & 0 deletions src/util/config_yml/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Literal

from pydantic import BaseModel


class LLMConfig(BaseModel):
provider: Literal["openai", "ollama"] | str = "openai"
model: str = "gpt-4o-mini"
base_url: str | None = None


class EmbeddingConfig(BaseModel):
provider: Literal["openai", "huggingfacehub", "huggingfacelocal"] | str = "openai"
model: str = "text-embedding-3-large"
device: str | None = "cpu"


class ModelsConfig(BaseModel):
llm: LLMConfig = LLMConfig()
embedding: EmbeddingConfig = EmbeddingConfig()