Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions pkg/github/ui_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ import (
"github.com/shurcooL/githubv4"
)

// uiGetMaxPages bounds how many pages each ui_get pagination loop will fetch.
// ui_get backs a synchronous UI picker (dropdowns for labels/assignees/etc. in
// the MCP App issue/PR write surfaces), so responsiveness matters more than
// completeness. At PerPage 100 this caps a call at ~1000 items and a bounded
// number of API round-trips, keeping latency predictable on very large
// repos/orgs. Results past the cap are truncated and surfaced via a "has_more"
// flag, which is acceptable because the picker pairs truncation with typeahead.
const uiGetMaxPages = 10

// UIGet creates a tool to fetch UI data for MCP Apps.
func UIGet(t translations.TranslationHelperFunc) inventory.ServerTool {
st := NewTool(
Expand Down Expand Up @@ -131,7 +140,8 @@ func uiGetLabels(ctx context.Context, deps ToolDependencies, args map[string]any

labels := make([]map[string]any, 0)
var totalCount int
for {
hasMore := false
for page := 1; ; page++ {
if err := client.Query(ctx, &query, vars); err != nil {
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to list labels", err), nil, nil
}
Expand All @@ -147,12 +157,17 @@ func uiGetLabels(ctx context.Context, deps ToolDependencies, args map[string]any
if !query.Repository.Labels.PageInfo.HasNextPage {
break
}
if page >= uiGetMaxPages {
hasMore = true
break
}
vars["cursor"] = githubv4.NewString(query.Repository.Labels.PageInfo.EndCursor)
}

response := map[string]any{
"labels": labels,
"totalCount": totalCount,
"has_more": hasMore,
}

out, err := json.Marshal(response)
Expand All @@ -176,8 +191,9 @@ func uiGetAssignees(ctx context.Context, deps ToolDependencies, args map[string]

opts := &github.ListOptions{PerPage: 100}
var allAssignees []*github.User
hasMore := false

for {
for page := 1; ; page++ {
assignees, resp, err := client.Issues.ListAssignees(ctx, owner, repo, opts)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list assignees", resp, err), nil, nil
Expand All @@ -189,6 +205,10 @@ func uiGetAssignees(ctx context.Context, deps ToolDependencies, args map[string]
if resp.NextPage == 0 {
break
}
if page >= uiGetMaxPages {
hasMore = true
break
}
opts.Page = resp.NextPage
}

Expand All @@ -203,6 +223,7 @@ func uiGetAssignees(ctx context.Context, deps ToolDependencies, args map[string]
out, err := json.Marshal(map[string]any{
"assignees": result,
"totalCount": len(result),
"has_more": hasMore,
})
if err != nil {
return utils.NewToolResultErrorFromErr("failed to marshal assignees", err), nil, nil
Expand All @@ -228,7 +249,8 @@ func uiGetMilestones(ctx context.Context, deps ToolDependencies, args map[string
}

var allMilestones []*github.Milestone
for {
hasMore := false
for page := 1; ; page++ {
milestones, resp, err := client.Issues.ListMilestones(ctx, owner, repo, opts)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list milestones", resp, err), nil, nil
Expand All @@ -240,6 +262,10 @@ func uiGetMilestones(ctx context.Context, deps ToolDependencies, args map[string
if resp.NextPage == 0 {
break
}
if page >= uiGetMaxPages {
hasMore = true
break
}
opts.Page = resp.NextPage
}

Expand All @@ -262,6 +288,7 @@ func uiGetMilestones(ctx context.Context, deps ToolDependencies, args map[string
out, err := json.Marshal(map[string]any{
"milestones": result,
"totalCount": len(result),
"has_more": hasMore,
})
if err != nil {
return utils.NewToolResultErrorFromErr("failed to marshal milestones", err), nil, nil
Expand Down Expand Up @@ -314,7 +341,8 @@ func uiGetBranches(ctx context.Context, deps ToolDependencies, args map[string]a
}

var allBranches []*github.Branch
for {
hasMore := false
for page := 1; ; page++ {
branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list branches", resp, err), nil, nil
Expand All @@ -326,6 +354,10 @@ func uiGetBranches(ctx context.Context, deps ToolDependencies, args map[string]a
if resp.NextPage == 0 {
break
}
if page >= uiGetMaxPages {
hasMore = true
break
}
opts.Page = resp.NextPage
}

Expand All @@ -337,6 +369,7 @@ func uiGetBranches(ctx context.Context, deps ToolDependencies, args map[string]a
r, err := json.Marshal(map[string]any{
"branches": minimalBranches,
"totalCount": len(minimalBranches),
"has_more": hasMore,
})
if err != nil {
return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil
Expand Down Expand Up @@ -446,7 +479,8 @@ func uiGetReviewers(ctx context.Context, deps ToolDependencies, args map[string]
ListOptions: github.ListOptions{PerPage: 100},
}
var allCollaborators []*github.User
for {
hasMore := false
for page := 1; ; page++ {
collaborators, resp, err := client.Repositories.ListCollaborators(ctx, owner, repo, collaboratorOpts)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list reviewers", resp, err), nil, nil
Expand All @@ -458,12 +492,16 @@ func uiGetReviewers(ctx context.Context, deps ToolDependencies, args map[string]
if resp.NextPage == 0 {
break
}
if page >= uiGetMaxPages {
hasMore = true
break
}
collaboratorOpts.Page = resp.NextPage
}

teamOpts := &github.ListOptions{PerPage: 100}
var allTeams []*github.Team
for {
for page := 1; ; page++ {
teams, resp, err := client.Repositories.ListTeams(ctx, owner, repo, teamOpts)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list reviewer teams", resp, err), nil, nil
Expand All @@ -475,6 +513,10 @@ func uiGetReviewers(ctx context.Context, deps ToolDependencies, args map[string]
if resp.NextPage == 0 {
break
}
if page >= uiGetMaxPages {
hasMore = true
break
}
teamOpts.Page = resp.NextPage
}

Expand Down Expand Up @@ -503,6 +545,7 @@ func uiGetReviewers(ctx context.Context, deps ToolDependencies, args map[string]
"users": users,
"teams": teams,
"totalCount": len(users) + len(teams),
"has_more": hasMore,
})
if err != nil {
return utils.NewToolResultErrorFromErr("failed to marshal reviewers", err), nil, nil
Expand Down
142 changes: 142 additions & 0 deletions pkg/github/ui_tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package github
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

Expand All @@ -17,6 +20,78 @@ import (
"github.com/stretchr/testify/require"
)

// recorderTransport routes HTTP requests through an in-process handler, mirroring
// internal/githubv4mock's own transport. We need it because githubv4mock keys its
// matchers by query string, so it cannot model a multi-page labels query: every
// page issues the identical query and differs only by the $cursor variable. This
// transport lets a single handler answer each page dynamically.
type recorderTransport struct{ handler http.Handler }

func (rt recorderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
rec := httptest.NewRecorder()
rt.handler.ServeHTTP(rec, req)
return rec.Result(), nil
}

// alwaysHasNextPageLabelsClient returns a GraphQL client whose labels query always
// reports another page, advancing the cursor on each call. It exercises uiGetLabels'
// page cap: the loop fetches one label per page until it stops at uiGetMaxPages with
// has_more=true. totalCount is reported as a large server-side count so the test can
// confirm it stays the full repo count even when results are truncated.
func alwaysHasNextPageLabelsClient(t *testing.T) *http.Client {
t.Helper()
var calls int
mux := http.NewServeMux()
mux.HandleFunc("/graphql", func(w http.ResponseWriter, _ *http.Request) {
calls++
resp := map[string]any{
"data": map[string]any{
"repository": map[string]any{
"labels": map[string]any{
"nodes": []any{
map[string]any{
"id": fmt.Sprintf("label-%d", calls),
"name": fmt.Sprintf("label-%d", calls),
"color": "ededed",
"description": "",
},
},
"totalCount": 9999,
"pageInfo": map[string]any{
"hasNextPage": true,
"endCursor": fmt.Sprintf("cursor-%d", calls),
},
},
},
},
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
})
return &http.Client{Transport: recorderTransport{handler: mux}}
}

// alwaysNextPageHandler returns a REST handler that always advertises another page
// via the Link header, regardless of the page requested. It drives a pagination loop
// purely off the page cap so tests can assert ui_get stops at uiGetMaxPages and sets
// has_more=true. The same body is returned for every page, so the number of items
// collected equals the number of pages fetched.
func alwaysNextPageHandler(t *testing.T, body any) http.HandlerFunc {
t.Helper()
return func(w http.ResponseWriter, r *http.Request) {
page := 1
if p := r.URL.Query().Get("page"); p != "" {
if parsed, err := strconv.Atoi(p); err == nil {
page = parsed
}
}
w.Header().Set("Link", fmt.Sprintf(`<https://api.github.com/next?page=%d>; rel="next"`, page+1))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(body)
}
}

func Test_UIGet(t *testing.T) {
// Verify tool definition
serverTool := UIGet(translations.NullTranslationHelper)
Expand Down Expand Up @@ -95,6 +170,7 @@ func Test_UIGet(t *testing.T) {
require.NoError(t, json.Unmarshal([]byte(responseText), &response))
assert.Contains(t, response, "assignees")
assert.Contains(t, response, "totalCount")
assert.Equal(t, false, response["has_more"], "results within the page cap should not be truncated")
},
},
{
Expand All @@ -113,6 +189,7 @@ func Test_UIGet(t *testing.T) {
require.NoError(t, json.Unmarshal([]byte(responseText), &response))
assert.Contains(t, response, "branches")
assert.Contains(t, response, "totalCount")
assert.Equal(t, false, response["has_more"], "results within the page cap should not be truncated")
},
},
{
Expand Down Expand Up @@ -228,6 +305,7 @@ func Test_UIGet(t *testing.T) {
require.Len(t, labels, 1)
assert.Equal(t, "bug", labels[0].(map[string]any)["name"])
assert.Equal(t, float64(1), response["totalCount"])
assert.Equal(t, false, response["has_more"], "results within the page cap should not be truncated")
},
},
{
Expand Down Expand Up @@ -300,6 +378,70 @@ func Test_UIGet(t *testing.T) {
assert.Equal(t, "docs", teams[0].(map[string]any)["slug"])
assert.Equal(t, "owner", teams[0].(map[string]any)["org"])
assert.Equal(t, float64(2), response["totalCount"])
assert.Equal(t, false, response["has_more"], "results within the page cap should not be truncated")
},
},
{
name: "branches pagination stops at the page cap",
mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{
"GET /repos/owner/repo/branches": alwaysNextPageHandler(t, []*github.Branch{{Name: github.Ptr("feature")}}),
}),
requestArgs: map[string]any{
"method": "branches",
"owner": "owner",
"repo": "repo",
},
expectError: false,
validateResult: func(t *testing.T, responseText string) {
var response map[string]any
require.NoError(t, json.Unmarshal([]byte(responseText), &response))
branches, ok := response["branches"].([]any)
require.True(t, ok, "branches should be a list")
assert.Len(t, branches, uiGetMaxPages, "loop should stop at the page cap")
assert.Equal(t, float64(uiGetMaxPages), response["totalCount"], "totalCount should be the bounded count")
assert.Equal(t, true, response["has_more"], "truncated results should set has_more")
},
},
{
name: "labels pagination stops at the page cap",
mockedGQLClient: alwaysHasNextPageLabelsClient(t),
requestArgs: map[string]any{
"method": "labels",
"owner": "owner",
"repo": "repo",
},
expectError: false,
validateResult: func(t *testing.T, responseText string) {
var response map[string]any
require.NoError(t, json.Unmarshal([]byte(responseText), &response))
labels, ok := response["labels"].([]any)
require.True(t, ok, "labels should be a list")
assert.Len(t, labels, uiGetMaxPages, "loop should stop at the page cap")
assert.Equal(t, true, response["has_more"], "truncated results should set has_more")
// totalCount stays the server-reported full count, so it can exceed
// the number of labels returned once results are truncated.
assert.Equal(t, float64(9999), response["totalCount"])
},
},
{
name: "reviewers pagination stops at the page cap",
mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{
"GET /repos/owner/repo/collaborators": alwaysNextPageHandler(t, []*github.User{{Login: github.Ptr("octocat")}}),
"GET /repos/owner/repo/teams": mockResponse(t, http.StatusOK, mockReviewerTeams),
}),
requestArgs: map[string]any{
"method": "reviewers",
"owner": "owner",
"repo": "repo",
},
expectError: false,
validateResult: func(t *testing.T, responseText string) {
var response map[string]any
require.NoError(t, json.Unmarshal([]byte(responseText), &response))
users, ok := response["users"].([]any)
require.True(t, ok, "users should be a list")
assert.Len(t, users, uiGetMaxPages, "collaborators loop should stop at the page cap")
assert.Equal(t, true, response["has_more"], "truncating either loop should set has_more")
},
},
{
Expand Down
Loading