Skip to content

Combine tools #349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
243 changes: 243 additions & 0 deletions pkg/github/notifications.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
package github

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"

"github.com/github/github-mcp-server/pkg/translations"
"github.com/google/go-github/v69/github"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// getNotifications creates a tool to list notifications for the current user.
func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("get_notifications",
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
mcp.WithBoolean("all",
mcp.Description("If true, show notifications marked as read. Default: false"),
),
mcp.WithBoolean("participating",
mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"),
),
mcp.WithString("since",
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
),
mcp.WithString("before",
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
),
mcp.WithNumber("per_page",
mcp.Description("Results per page (max 100). Default: 30"),
),
mcp.WithNumber("page",
mcp.Description("Page number of the results to fetch. Default: 1"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

// Extract optional parameters with defaults
all, err := OptionalBoolParamWithDefault(request, "all", false)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

participating, err := OptionalBoolParamWithDefault(request, "participating", false)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

since, err := OptionalStringParamWithDefault(request, "since", "")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

before, err := OptionalStringParam(request, "before")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

page, err := OptionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

// Build options
opts := &github.NotificationListOptions{
All: all,
Participating: participating,
ListOptions: github.ListOptions{
Page: page,
PerPage: perPage,
},
}

// Parse time parameters if provided
if since != "" {
sinceTime, err := time.Parse(time.RFC3339, since)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil
}
opts.Since = sinceTime
}

if before != "" {
beforeTime, err := time.Parse(time.RFC3339, before)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil
}
opts.Before = beforeTime
}

// Call GitHub API
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
if err != nil {
return nil, fmt.Errorf("failed to get notifications: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil
}

// Marshal response to JSON
r, err := json.Marshal(notifications)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}

// ManageNotifications creates a tool to manage notifications (mark as read, mark all as read, or mark as done).
func ManageNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("manage_notifications",
mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATIONS_DESCRIPTION", "Manage notifications (mark as read, mark all as read, or mark as done)")),
mcp.WithString("action",
mcp.Required(),
mcp.Description("The action to perform: 'mark_read', 'mark_all_read', or 'mark_done'"),
),
mcp.WithString("threadID",
mcp.Description("The ID of the notification thread (required for 'mark_read' and 'mark_done')"),
),
mcp.WithString("lastReadAt",
mcp.Description("Describes the last point that notifications were checked (optional, for 'mark_all_read'). Default: Now"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

action, err := requiredParam[string](request, "action")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

switch action {
case "mark_read":
Copy link
Preview

Copilot AI Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'mark_read' action does not call any API to mark a notification as read; it directly returns a success message. Please add the appropriate API call to actually mark the notification as read.

Copilot uses AI. Check for mistakes.

threadID, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

resp, err := client.Activity.MarkThreadRead(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
}

return mcp.NewToolResultText("Notification marked as read"), nil

case "mark_done":
threadIDStr, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
if err != nil {
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
}

resp, err := client.Activity.MarkThreadDone(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

return mcp.NewToolResultText("Notification marked as done"), nil

case "mark_all_read":
lastReadAt, err := OptionalStringParam(request, "lastReadAt")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

var markReadOptions github.Timestamp
if lastReadAt != "" {
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
}
markReadOptions = github.Timestamp{
Time: lastReadTime,
}
}

resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
if err != nil {
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

return mcp.NewToolResultText("All notifications marked as read"), nil

default:
return mcp.NewToolResultError("Invalid action: must be 'mark_read', 'mark_all_read', or 'mark_done'"), nil
}
}
}
41 changes: 41 additions & 0 deletions pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,47 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e
return v, nil
}

// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
// similar to optionalParam, but it also takes a default value.
func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) {
v, err := OptionalParam[bool](r, p)
if err != nil {
return false, err
}
if !v {
return d, nil
}
return v, nil
}

// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request.
// It does the following checks:
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
// 2. If it is present, it checks if the parameter is of the expected type and returns it
func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) {
v, err := OptionalParam[string](r, p)
if err != nil {
return "", err
}
if v == "" {
return "", nil
}
return v, nil
}

// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
// similar to optionalParam, but it also takes a default value.
func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) {
v, err := OptionalParam[string](r, p)
if err != nil {
return "", err
}
if v == "" {
return d, nil
}
return v, nil
}

// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request.
// It does the following checks:
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
Expand Down
10 changes: 10 additions & 0 deletions pkg/github/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)),
toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)),
)

notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
AddReadTools(
toolsets.NewServerTool(GetNotifications(getClient, t)),
).
AddWriteTools(
toolsets.NewServerTool(ManageNotifications(getClient, t)),
)

// Keep experiments alive so the system doesn't error out when it's always enabled
experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet")

Expand All @@ -88,6 +97,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
tsg.AddToolset(pullRequests)
tsg.AddToolset(codeSecurity)
tsg.AddToolset(secretProtection)
tsg.AddToolset(notifications)
tsg.AddToolset(experiments)
// Enable the requested features

Expand Down