Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/api/chat/create_conversation_message_stream_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
}
}

openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider, customModel)
if err != nil {
return s.sendStreamError(stream, err)
}
Expand All @@ -347,7 +347,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
for i, bsonMsg := range conversation.InappChatHistory {
protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg)
}
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider, modelSlug)
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider, modelSlug, customModel)
if err != nil {
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
return
Expand Down
42 changes: 24 additions & 18 deletions internal/api/mapper/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@ func MapProtoSettingsToModel(settings *userv1.Settings) *models.Settings {
}

customModels[i] = models.CustomModel{
Id: id,
Slug: m.Slug,
Name: m.Name,
BaseUrl: m.BaseUrl,
APIKey: m.ApiKey,
ContextWindow: m.ContextWindow,
MaxOutput: m.MaxOutput,
InputPrice: m.InputPrice,
OutputPrice: m.OutputPrice,
Id: id,
Slug: m.Slug,
Name: m.Name,
BaseUrl: m.BaseUrl,
APIKey: m.ApiKey,
ContextWindow: m.ContextWindow,
MaxOutput: m.MaxOutput,
InputPrice: m.InputPrice,
OutputPrice: m.OutputPrice,
Temperature: m.Temperature,
ParallelToolCalls: m.ParallelToolCalls,
Store: m.Store,
}
}

Expand All @@ -47,15 +50,18 @@ func MapModelSettingsToProto(settings *models.Settings) *userv1.Settings {
customModels := make([]*userv1.CustomModel, len(settings.CustomModels))
for i, m := range settings.CustomModels {
customModels[i] = &userv1.CustomModel{
Id: m.Id.Hex(),
Slug: m.Slug,
Name: m.Name,
BaseUrl: m.BaseUrl,
ApiKey: m.APIKey,
ContextWindow: m.ContextWindow,
MaxOutput: m.MaxOutput,
InputPrice: m.InputPrice,
OutputPrice: m.OutputPrice,
Id: m.Id.Hex(),
Slug: m.Slug,
Name: m.Name,
BaseUrl: m.BaseUrl,
ApiKey: m.APIKey,
ContextWindow: m.ContextWindow,
MaxOutput: m.MaxOutput,
InputPrice: m.InputPrice,
OutputPrice: m.OutputPrice,
Temperature: m.Temperature,
ParallelToolCalls: m.ParallelToolCalls,
Store: m.Store,
}
}

Expand Down
21 changes: 12 additions & 9 deletions internal/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ package models
import "go.mongodb.org/mongo-driver/v2/bson"

// CustomModel describes a user-configured LLM endpoint persisted in MongoDB.
// It carries the connection details (base URL, API key), display metadata,
// token limits and pricing, plus per-request tuning knobs (temperature,
// parallel tool calls, and the OpenAI "store" flag).
//
// NOTE(review): the diff residue showed the pre-change field list duplicated
// above the post-change one; this is the reconstructed post-change struct
// (old 9 fields plus Temperature, ParallelToolCalls, Store).
type CustomModel struct {
	Id      bson.ObjectID `bson:"_id"`
	Slug    string        `bson:"slug"`     // model identifier sent to the provider
	Name    string        `bson:"name"`     // human-readable display name
	BaseUrl string        `bson:"base_url"` // provider endpoint base URL
	APIKey  string        `bson:"api_key"`

	// Token limits and pricing; prices are presumably per-million-token
	// integer units — TODO confirm against billing code.
	ContextWindow int32 `bson:"context_window"`
	MaxOutput     int32 `bson:"max_output"`
	InputPrice    int32 `bson:"input_price"`
	OutputPrice   int32 `bson:"output_price"`

	// Request tuning options forwarded to the chat-completion params
	// (see getDefaultParamsV2): Store is only sent when true, since some
	// providers reject the parameter entirely.
	Temperature       float32 `bson:"temperature"`
	ParallelToolCalls bool    `bson:"parallel_tool_calls"`
	Store             bool    `bson:"store"`
}

type Settings struct {
Expand Down
8 changes: 4 additions & 4 deletions internal/services/toolkit/client/completion_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
// 1. The full chat history sent to the language model (including any tool call results).
// 2. The incremental chat history visible to the user (including tool call results and assistant responses).
// 3. An error, if any occurred during the process.
func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) {
openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider)
func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, error) {
openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider, customModel)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -54,7 +54,7 @@ func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, mes
// - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop.
// - If no tool calls are needed, it appends the assistant's response and exits the loop.
// - Finally, it returns the updated chat histories and any error encountered.
func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) {
func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, error) {
openaiChatHistory := messages
inappChatHistory := AppChatHistory{}

Expand All @@ -66,7 +66,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
}()

oaiClient := a.GetOpenAIClient(llmProvider)
params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, llmProvider.IsCustomModel)
params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, customModel)

for {
params.Messages = openaiChatHistory
Expand Down
2 changes: 1 addition & 1 deletion internal/services/toolkit/client/get_citation_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userI
_, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{
openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."),
openai.UserMessage(message),
}, llmProvider)
}, llmProvider, nil)

if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions internal/services/toolkit/client/get_conversation_title_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/samber/lo"
)

func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig, modelSlug string) (string, error) {
func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig, modelSlug string, customModel *models.CustomModel) (string, error) {
messages := lo.Map(inappChatHistory, func(message *chatv2.Message, _ int) string {
if _, ok := message.Payload.MessageType.(*chatv2.MessagePayload_Assistant); ok {
return fmt.Sprintf("Assistant: %s", message.Payload.GetAssistant().GetContent())
Expand All @@ -38,7 +38,7 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor
_, resp, err := a.ChatCompletionV2(ctx, modelToUse, OpenAIChatHistory{
openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."),
openai.UserMessage(message),
}, llmProvider)
}, llmProvider, customModel)
if err != nil {
return "", err
}
Expand Down
24 changes: 16 additions & 8 deletions internal/services/toolkit/client/utils_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"paperdebugger/internal/libs/cfg"
"paperdebugger/internal/libs/db"
"paperdebugger/internal/libs/logger"
"paperdebugger/internal/models"
"paperdebugger/internal/services"
"paperdebugger/internal/services/toolkit/registry"
filetools "paperdebugger/internal/services/toolkit/tools/files"
Expand Down Expand Up @@ -53,7 +54,7 @@ func appendAssistantTextResponseV2(openaiChatHistory *OpenAIChatHistory, inappCh
})
}

func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, isCustomModel bool) openaiv3.ChatCompletionNewParams {
func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, customModel *models.CustomModel) openaiv3.ChatCompletionNewParams {
var reasoningModels = []string{
"gpt-5",
"gpt-5-mini",
Expand All @@ -67,15 +68,22 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2,
"codex-mini-latest",
}

// Other model providers generally do not support the Store param
if isCustomModel {
return openaiv3.ChatCompletionNewParams{
Model: modelSlug,
Temperature: openaiv3.Float(0.7),
MaxCompletionTokens: openaiv3.Int(4000),
if customModel != nil {
params := openaiv3.ChatCompletionNewParams{
Model: customModel.Slug,
Temperature: openaiv3.Float(float64(customModel.Temperature)),
MaxCompletionTokens: openaiv3.Int(int64(customModel.MaxOutput)),
Tools: toolRegistry.GetTools(),
ParallelToolCalls: openaiv3.Bool(true),
ParallelToolCalls: openaiv3.Bool(customModel.ParallelToolCalls),
}

// Store param should only be included if it is true
// Some providers like Gemini might not support the param at all even if false
if customModel.Store {
params.Store = openaiv3.Bool(customModel.Store)
}

return params
}

for _, model := range reasoningModels {
Expand Down
56 changes: 42 additions & 14 deletions pkg/gen/api/user/v1/user.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions proto/user/v1/user.proto
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ message CustomModel {
int32 max_output = 7;
int32 input_price = 8;
int32 output_price = 9;
float temperature = 10;
bool parallel_tool_calls = 11;
bool store = 12;
}

message Settings {
Expand Down
Loading
Loading