diff --git a/mcp/auth_error.go b/mcp/auth_error.go new file mode 100644 index 0000000..bf758f7 --- /dev/null +++ b/mcp/auth_error.go @@ -0,0 +1,19 @@ +package mcp + +import "fmt" + +// AuthError wraps a 401 from a remote MCP server with structured metadata +// for fc-safari to translate into an authorization_required tool result. +type AuthError struct { + StatusCode int + WWWAuthenticate string + ResourceMetadataURL string + Underlying error +} + +func (e *AuthError) Error() string { + return fmt.Sprintf("mcp auth error: status=%d www-authenticate=%q resource-metadata-url=%q", + e.StatusCode, e.WWWAuthenticate, e.ResourceMetadataURL) +} + +func (e *AuthError) Unwrap() error { return e.Underlying } diff --git a/mcp/transport.go b/mcp/transport.go index 3bc7e46..ed98190 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -64,6 +64,15 @@ type headerTransport struct { base http.RoundTripper } +// newHeaderTransport creates a headerTransport with the given static headers +// and http.DefaultTransport as the base. Used by tests and NewSSETransport. +func newHeaderTransport(headers map[string]string) *headerTransport { + return &headerTransport{ + headers: headers, + base: http.DefaultTransport, + } +} + func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { for k, v := range t.headers { req.Header.Set(k, v) @@ -111,5 +120,32 @@ func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { "status", resp.StatusCode, ) + // Surface 401 as a structured AuthError so fc-safari can route it to the + // per-user auth flow instead of treating it as an opaque failure string. + if resp.StatusCode == 401 { + www := resp.Header.Get("WWW-Authenticate") + return resp, &AuthError{ + StatusCode: resp.StatusCode, + WWWAuthenticate: www, + ResourceMetadataURL: parseResourceMetadataURL(www), + } + } + return resp, nil } + +// parseResourceMetadataURL extracts the `resource_metadata="..."` param +// from a Bearer challenge per RFC 6750 ยง3. +func parseResourceMetadataURL(wwwAuth string) string { + const key = `resource_metadata="` + i := strings.Index(wwwAuth, key) + if i < 0 { + return "" + } + rest := wwwAuth[i+len(key):] + j := strings.IndexByte(rest, '"') + if j < 0 { + return "" + } + return rest[:j] +} diff --git a/mcp/transport_auth_test.go b/mcp/transport_auth_test.go new file mode 100644 index 0000000..70a7ed7 --- /dev/null +++ b/mcp/transport_auth_test.go @@ -0,0 +1,38 @@ +package mcp + +import ( + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRoundTripWrapsAuthError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token", resource_metadata="https://ex/.well-known/oauth-protected-resource"`) + w.WriteHeader(401) + w.Write([]byte("unauth")) //nolint:errcheck + })) + defer ts.Close() + + transport := newHeaderTransport(map[string]string{}) + req, _ := http.NewRequest("GET", ts.URL, nil) + _, err := transport.RoundTrip(req) + if err == nil { + t.Fatal("expected error for 401") + } + var ae *AuthError + if !errors.As(err, &ae) { + t.Fatalf("expected *AuthError, got %T: %v", err, err) + } + if ae.StatusCode != 401 { + t.Fatalf("status: got %d", ae.StatusCode) + } + if !strings.Contains(ae.WWWAuthenticate, "invalid_token") { + t.Fatalf("WWWAuthenticate not captured: %q", ae.WWWAuthenticate) + } + if !strings.Contains(ae.ResourceMetadataURL, "/.well-known/") { + t.Fatalf("ResourceMetadataURL not captured: %q", ae.ResourceMetadataURL) + } +} diff --git a/protocol/messages.go b/protocol/messages.go index 696bb75..f5be7db 100644 --- a/protocol/messages.go +++ b/protocol/messages.go @@ -399,6 +399,18 @@ type MCPResultPayload struct { Success bool `json:"success"` Result json.RawMessage `json:"result,omitempty"` Error string `json:"error,omitempty"` + + // P1: structured auth-error metadata (additive; old senders omit). + ErrorType string `json:"error_type,omitempty"` // "auth_error" | "" + AuthMetadata *AuthErrorMeta `json:"auth_metadata,omitempty"` // non-nil iff ErrorType=="auth_error" +} + +// AuthErrorMeta carries the 401 challenge details so fc-safari can initiate +// the per-user OAuth flow without re-parsing opaque error strings. +type AuthErrorMeta struct { + StatusCode int `json:"status_code,omitempty"` + WWWAuthenticate string `json:"www_authenticate,omitempty"` + ResourceMetadataURL string `json:"resource_metadata_url,omitempty"` } // KnowledgeFile is a single file entry in a stage_knowledge_files request. diff --git a/ws/handler.go b/ws/handler.go index 7d8a06f..a5b2d67 100644 --- a/ws/handler.go +++ b/ws/handler.go @@ -3,12 +3,14 @@ package ws import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "sync" "time" "github.com/flashcatcloud/flashduty-runner/environment" + "github.com/flashcatcloud/flashduty-runner/mcp" "github.com/flashcatcloud/flashduty-runner/protocol" ) @@ -311,6 +313,22 @@ func (h *Handler) sendMCPResult(callID string, success bool, result any, errMsg }) } +// buildMCPResultPayload constructs a failure payload, enriching it with +// structured AuthError metadata when the error chain contains *mcp.AuthError. +func buildMCPResultPayload(callID string, err error) protocol.MCPResultPayload { + p := protocol.MCPResultPayload{CallID: callID, Success: false, Error: err.Error()} + var ae *mcp.AuthError + if errors.As(err, &ae) { + p.ErrorType = "auth_error" + p.AuthMetadata = &protocol.AuthErrorMeta{ + StatusCode: ae.StatusCode, + WWWAuthenticate: ae.WWWAuthenticate, + ResourceMetadataURL: ae.ResourceMetadataURL, + } + } + return p +} + func (h *Handler) sendPayload(msgType protocol.MessageType, payload any) { if h.client == nil { slog.Error("client not set, cannot send message", "type", msgType) @@ -364,7 +382,8 @@ func (h *Handler) handleMCPCall(ctx context.Context, msg *protocol.Message) erro }, logger) if err != nil { logger.Error("mcp call failed", "error", err) - h.sendMCPResult(payload.CallID, false, nil, err.Error()) + p := buildMCPResultPayload(payload.CallID, err) + h.sendPayload(protocol.MessageTypeMCPResult, p) } else { logger.Info("mcp call completed") h.sendMCPResult(payload.CallID, true, result, "")