Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 119 additions & 73 deletions include/fastmcpp/client/client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,13 @@ class Client
response = invoke_request();
}

const auto& response_body = unwrap_rpc_result(response);

// Optional server-side progress events
if (options.progress_handler && response.contains("progress") &&
response["progress"].is_array())
if (options.progress_handler && response_body.contains("progress") &&
response_body["progress"].is_array())
{
for (const auto& p : response["progress"])
for (const auto& p : response_body["progress"])
{
float value = p.value("progress", 0.0f);
std::optional<float> total = std::nullopt;
Expand All @@ -301,9 +303,9 @@ class Client
}

// Notification forwarding (sampling/elicitation/roots) if provided by server
if (response.contains("notifications") && response["notifications"].is_array())
if (response_body.contains("notifications") && response_body["notifications"].is_array())
{
for (const auto& n : response["notifications"])
for (const auto& n : response_body["notifications"])
{
if (!n.contains("method"))
continue;
Expand Down Expand Up @@ -385,16 +387,18 @@ class Client
TaskStatus get_task_status(const std::string& task_id)
{
fastmcpp::Json response = call("tasks/get", {{"taskId", task_id}});
const auto& body = unwrap_rpc_result(response);
TaskStatus status;
from_json(response, status);
from_json(body, status);
return status;
}

/// Retrieve raw task result via MCP 'tasks/result' (tool/prompt/resource specific).
/// Callers are responsible for parsing into appropriate result type.
fastmcpp::Json get_task_result_raw(const std::string& task_id)
{
return call("tasks/result", {{"taskId", task_id}});
fastmcpp::Json response = call("tasks/result", {{"taskId", task_id}});
return unwrap_rpc_result(response);
}

/// List tasks via MCP 'tasks/list'. Returns raw JSON as provided by server.
Expand All @@ -406,16 +410,18 @@ class Client
params["cursor"] = *cursor;
if (limit > 0)
params["limit"] = limit;
return call("tasks/list", params);
fastmcpp::Json response = call("tasks/list", params);
return unwrap_rpc_result(response);
}

/// Cancel a background task via MCP 'tasks/cancel'. Returns final task status.
/// @throws fastmcpp::Error if task does not exist or server returns error
TaskStatus cancel_task(const std::string& task_id)
{
fastmcpp::Json response = call("tasks/cancel", {{"taskId", task_id}});
const auto& body = unwrap_rpc_result(response);
TaskStatus status;
from_json(response, status);
from_json(body, status);
return status;
}

Expand Down Expand Up @@ -705,9 +711,10 @@ class Client
void poll_notifications()
{
auto response = call("notifications/poll", fastmcpp::Json::object());
if (!response.contains("notifications") || !response["notifications"].is_array())
const auto& body = unwrap_rpc_result(response);
if (!body.contains("notifications") || !body["notifications"].is_array())
return;
for (const auto& n : response["notifications"])
for (const auto& n : body["notifications"])
{
if (!n.contains("method"))
continue;
Expand Down Expand Up @@ -916,36 +923,66 @@ class Client
return value;
}

const fastmcpp::Json& unwrap_rpc_result(const fastmcpp::Json& response)
{
if (!response.is_object())
return response;

if (response.contains("error"))
{
if (response["error"].is_object())
{
const auto& error = response["error"];
std::string message = error.value("message", "Unknown JSON-RPC error");
if (error.contains("code") && error["code"].is_number_integer())
{
throw fastmcpp::Error("JSON-RPC error (" +
std::to_string(error["code"].get<int>()) + "): " +
message);
}
throw fastmcpp::Error("JSON-RPC error: " + message);
}
throw fastmcpp::Error("JSON-RPC error: " + response["error"].dump());
}

if (response.contains("result"))
return response["result"];

return response;
}

ListToolsResult parse_list_tools_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
ListToolsResult result;
if (response.contains("tools"))
for (const auto& t : response["tools"])
if (body.contains("tools"))
for (const auto& t : body["tools"])
result.tools.push_back(t.get<ToolInfo>());
if (response.contains("nextCursor"))
result.nextCursor = response["nextCursor"].get<std::string>();
if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("nextCursor"))
result.nextCursor = body["nextCursor"].get<std::string>();
if (body.contains("_meta"))
result._meta = body["_meta"];
return result;
}

CallToolResult parse_call_tool_result(const fastmcpp::Json& response,
const std::string& tool_name)
{
const auto& body = unwrap_rpc_result(response);

CallToolResult result;
result.isError = response.value("isError", false);
result.isError = body.value("isError", false);

if (!response.contains("content"))
if (!body.contains("content"))
throw fastmcpp::ValidationError("tools/call response missing content");

if (response.contains("content"))
for (const auto& c : response["content"])
if (body.contains("content"))
for (const auto& c : body["content"])
result.content.push_back(parse_content_block(c));

if (response.contains("structuredContent"))
if (body.contains("structuredContent"))
{
result.structuredContent = response["structuredContent"];
result.structuredContent = body["structuredContent"];
// Try to provide a convenient data view similar to Python
auto structured = *result.structuredContent;
auto it = tool_output_schemas_.find(tool_name);
Expand Down Expand Up @@ -1011,31 +1048,33 @@ class Client
}
}

if (response.contains("_meta"))
result.meta = response["_meta"];
if (body.contains("_meta"))
result.meta = body["_meta"];

return result;
}

ListResourcesResult parse_list_resources_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
ListResourcesResult result;
if (response.contains("resources"))
for (const auto& r : response["resources"])
if (body.contains("resources"))
for (const auto& r : body["resources"])
result.resources.push_back(r.get<ResourceInfo>());
if (response.contains("nextCursor"))
result.nextCursor = response["nextCursor"].get<std::string>();
if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("nextCursor"))
result.nextCursor = body["nextCursor"].get<std::string>();
if (body.contains("_meta"))
result._meta = body["_meta"];
return result;
}

ListResourceTemplatesResult parse_list_resource_templates_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
ListResourceTemplatesResult result;
if (response.contains("resourceTemplates"))
if (body.contains("resourceTemplates"))
{
for (const auto& r : response["resourceTemplates"])
for (const auto& r : body["resourceTemplates"])
{
ResourceTemplate rt;
rt.uriTemplate = r.at("uriTemplate").get<std::string>();
Expand Down Expand Up @@ -1071,45 +1110,48 @@ class Client
result.resourceTemplates.push_back(rt);
}
}
if (response.contains("nextCursor"))
result.nextCursor = response["nextCursor"].get<std::string>();
if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("nextCursor"))
result.nextCursor = body["nextCursor"].get<std::string>();
if (body.contains("_meta"))
result._meta = body["_meta"];
return result;
}

ReadResourceResult parse_read_resource_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
ReadResourceResult result;
if (response.contains("contents"))
for (const auto& c : response["contents"])
if (body.contains("contents"))
for (const auto& c : body["contents"])
result.contents.push_back(parse_resource_content(c));
if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("_meta"))
result._meta = body["_meta"];
return result;
}

ListPromptsResult parse_list_prompts_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
ListPromptsResult result;
if (response.contains("prompts"))
for (const auto& p : response["prompts"])
if (body.contains("prompts"))
for (const auto& p : body["prompts"])
result.prompts.push_back(p.get<PromptInfo>());
if (response.contains("nextCursor"))
result.nextCursor = response["nextCursor"].get<std::string>();
if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("nextCursor"))
result.nextCursor = body["nextCursor"].get<std::string>();
if (body.contains("_meta"))
result._meta = body["_meta"];
return result;
}

GetPromptResult parse_get_prompt_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
GetPromptResult result;
if (response.contains("description"))
result.description = response["description"].get<std::string>();
if (response.contains("messages"))
if (body.contains("description"))
result.description = body["description"].get<std::string>();
if (body.contains("messages"))
{
for (const auto& m : response["messages"])
for (const auto& m : body["messages"])
{
PromptMessage msg;
std::string role = m.at("role").get<std::string>();
Expand All @@ -1136,37 +1178,39 @@ class Client
result.messages.push_back(msg);
}
}
if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("_meta"))
result._meta = body["_meta"];
return result;
}

CompleteResult parse_complete_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
CompleteResult result;
if (response.contains("completion"))
if (body.contains("completion"))
{
const auto& c = response["completion"];
const auto& c = body["completion"];
if (c.contains("values"))
for (const auto& v : c["values"])
result.completion.values.push_back(v.get<std::string>());
if (c.contains("total"))
result.completion.total = c["total"].get<int>();
result.completion.hasMore = c.value("hasMore", false);
}
if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("_meta"))
result._meta = body["_meta"];
return result;
}

InitializeResult parse_initialize_result(const fastmcpp::Json& response)
{
const auto& body = unwrap_rpc_result(response);
InitializeResult result;
result.protocolVersion = response.value("protocolVersion", "2024-11-05");
result.protocolVersion = body.value("protocolVersion", "2024-11-05");

if (response.contains("capabilities"))
if (body.contains("capabilities"))
{
const auto& caps = response["capabilities"];
const auto& caps = body["capabilities"];
if (caps.contains("experimental"))
result.capabilities.experimental = caps["experimental"];
if (caps.contains("logging"))
Expand All @@ -1185,17 +1229,17 @@ class Client
result.capabilities.extensions = caps["extensions"];
}

if (response.contains("serverInfo"))
if (body.contains("serverInfo"))
{
result.serverInfo.name = response["serverInfo"].value("name", "unknown");
result.serverInfo.version = response["serverInfo"].value("version", "unknown");
result.serverInfo.name = body["serverInfo"].value("name", "unknown");
result.serverInfo.version = body["serverInfo"].value("version", "unknown");
}

if (response.contains("instructions"))
result.instructions = response["instructions"].get<std::string>();
if (body.contains("instructions"))
result.instructions = body["instructions"].get<std::string>();

if (response.contains("_meta"))
result._meta = response["_meta"];
if (body.contains("_meta"))
result._meta = body["_meta"];

return result;
}
Expand Down Expand Up @@ -1588,18 +1632,19 @@ inline std::shared_ptr<ResourceTask> Client::read_resource_task(const std::strin
payload["_meta"] = *propagated_meta;

auto response = call("resources/read", payload);
const auto& body = unwrap_rpc_result(response);

if (response.contains("_meta") && response["_meta"].contains("modelcontextprotocol.io/task"))
if (body.contains("_meta") && body["_meta"].contains("modelcontextprotocol.io/task"))
{
const auto& task_obj = response["_meta"]["modelcontextprotocol.io/task"];
const auto& task_obj = body["_meta"]["modelcontextprotocol.io/task"];
if (task_obj.contains("taskId"))
{
std::string task_id = task_obj["taskId"].get<std::string>();
return std::make_shared<ResourceTask>(this, std::move(task_id), uri, std::nullopt);
}
}

ReadResourceResult result = parse_read_resource_result(response);
ReadResourceResult result = parse_read_resource_result(body);
return std::make_shared<ResourceTask>(this, std::string{}, uri, std::move(result.contents));
}

Expand All @@ -1625,18 +1670,19 @@ Client::get_prompt_task(const std::string& name, const fastmcpp::Json& arguments
payload["_meta"] = *propagated_meta;

auto response = call("prompts/get", payload);
const auto& body = unwrap_rpc_result(response);

if (response.contains("_meta") && response["_meta"].contains("modelcontextprotocol.io/task"))
if (body.contains("_meta") && body["_meta"].contains("modelcontextprotocol.io/task"))
{
const auto& task_obj = response["_meta"]["modelcontextprotocol.io/task"];
const auto& task_obj = body["_meta"]["modelcontextprotocol.io/task"];
if (task_obj.contains("taskId"))
{
std::string task_id = task_obj["taskId"].get<std::string>();
return std::make_shared<PromptTask>(this, std::move(task_id), name, std::nullopt);
}
}

GetPromptResult result = parse_get_prompt_result(response);
GetPromptResult result = parse_get_prompt_result(body);
return std::make_shared<PromptTask>(this, std::string{}, name, std::move(result));
}

Expand Down
Loading
Loading