From a142f5daa9d4825783c5f61b24527f4178c88048 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Fri, 26 Jun 2026 05:49:17 +0000 Subject: [PATCH] fix: scope bearer auth to API host Only attach the OAuth bearer token when the outgoing request host matches the configured API host, and clear any copied Authorization header before untrusted redirect hops. Add uhttp coverage for same-host redirects, cross-host redirects, and copied Authorization headers on untrusted hosts. Co-authored-by: c1-squire-dev[bot] --- pkg/client/client.go | 2 +- pkg/uhttp/client.go | 5 +- pkg/uhttp/transport.go | 8 +- pkg/uhttp/transport_test.go | 154 ++++++++++++++++++++++++++++++++++++ 4 files changed, 166 insertions(+), 3 deletions(-) create mode 100644 pkg/uhttp/transport_test.go diff --git a/pkg/client/client.go b/pkg/client/client.go index a27d1842..3bca50e0 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -301,7 +301,7 @@ func newClientWithTokenSource( ) (C1Client, error) { uclient, err := uhttp.NewClient( ctx, - uhttp.WithTokenSource(tokenSrc), + uhttp.WithTokenSource(tokenSrc, tokenHost), uhttp.WithDebug(v.GetBool("debug")), uhttp.WithRequestSource(cmdName), ) diff --git a/pkg/uhttp/client.go b/pkg/uhttp/client.go index b70feb51..44502ce4 100644 --- a/pkg/uhttp/client.go +++ b/pkg/uhttp/client.go @@ -43,16 +43,19 @@ func WithLogger(log bool, logger *zap.Logger) Option { type tokenSourceOption struct { tokenSource oauth2.TokenSource + tokenHost string } func (t tokenSourceOption) Apply(c *Transport) { c.tokenSource = t.tokenSource + c.tokenHost = t.tokenHost } // WithTokenSource sets a token source option to the transport layer. -func WithTokenSource(tokenSource oauth2.TokenSource) Option { +func WithTokenSource(tokenSource oauth2.TokenSource, tokenHost string) Option { return tokenSourceOption{ tokenSource: tokenSource, + tokenHost: tokenHost, } } diff --git a/pkg/uhttp/transport.go b/pkg/uhttp/transport.go index 7b95a048..bf39320d 100644 --- a/pkg/uhttp/transport.go +++ b/pkg/uhttp/transport.go @@ -34,6 +34,7 @@ func NewTransport(ctx context.Context, options ...Option) (*Transport, error) { type Transport struct { userAgent string tokenSource oauth2.TokenSource + tokenHost string requestSource string tlsClientConfig *tls.Config roundTripper http.RoundTripper @@ -178,12 +179,17 @@ func (uat *debugTripper) RoundTrip(req *http.Request) (*http.Response, error) { type tokenSourceTripper struct { next http.RoundTripper tokenSource oauth2.TokenSource + tokenHost string } func (uts *tokenSourceTripper) RoundTrip(req *http.Request) (*http.Response, error) { if uts.tokenSource == nil { return uts.next.RoundTrip(req) } + if !strings.EqualFold(req.URL.Host, uts.tokenHost) { + req.Header.Del("Authorization") + return uts.next.RoundTrip(req) + } token, err := uts.tokenSource.Token() if err != nil { return nil, err @@ -216,7 +222,7 @@ func (t *Transport) make(ctx context.Context) (http.RoundTripper, error) { t.userAgent = fmt.Sprintf("%s cone", t.userAgent) rv = &debugTripper{next: rv, debug: t.debug} rv = &userAgentTripper{next: rv, userAgent: t.userAgent} - rv = &tokenSourceTripper{next: rv, tokenSource: t.tokenSource} + rv = &tokenSourceTripper{next: rv, tokenSource: t.tokenSource, tokenHost: t.tokenHost} rv = &requestSourceTripper{next: rv, requestSource: t.requestSource} return rv, nil } diff --git a/pkg/uhttp/transport_test.go b/pkg/uhttp/transport_test.go new file mode 100644 index 00000000..0e4b6cce --- /dev/null +++ b/pkg/uhttp/transport_test.go @@ -0,0 +1,154 @@ +package uhttp + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "golang.org/x/oauth2" +) + +func TestTokenSourceTripperKeepsAuthOnTrustedHostRedirect(t *testing.T) { + const bearerToken = "test-token" + + var startAuth string + var redirectedAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/start": + startAuth = r.Header.Get("Authorization") + http.Redirect(w, r, "/next", http.StatusFound) + case "/next": + redirectedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusNoContent) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + client, err := NewClient( + context.Background(), + WithTokenSource(staticBearerTokenSource(bearerToken), serverURL.Host), + ) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL+"/start", nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + + expectedAuth := "Bearer " + bearerToken + if startAuth != expectedAuth { + t.Fatalf("start Authorization header = %q, want %q", startAuth, expectedAuth) + } + if redirectedAuth != expectedAuth { + t.Fatalf("redirected Authorization header = %q, want %q", redirectedAuth, expectedAuth) + } +} + +func TestTokenSourceTripperSkipsAuthOnCrossHostRedirect(t *testing.T) { + const bearerToken = "test-token" + + var trustedAuth string + var redirectedAuth string + redirectedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + redirectedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusNoContent) + })) + defer redirectedServer.Close() + + trustedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + trustedAuth = r.Header.Get("Authorization") + http.Redirect(w, r, redirectedServer.URL, http.StatusFound) + })) + defer trustedServer.Close() + + trustedURL, err := url.Parse(trustedServer.URL) + if err != nil { + t.Fatal(err) + } + client, err := NewClient( + context.Background(), + WithTokenSource(staticBearerTokenSource(bearerToken), trustedURL.Host), + ) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, trustedServer.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + + expectedAuth := "Bearer " + bearerToken + if trustedAuth != expectedAuth { + t.Fatalf("trusted Authorization header = %q, want %q", trustedAuth, expectedAuth) + } + if redirectedAuth != "" { + t.Fatalf("cross-host redirected Authorization header = %q, want empty", redirectedAuth) + } +} + +func TestTokenSourceTripperRemovesAuthOnUntrustedHost(t *testing.T) { + var auth string + tripper := &tokenSourceTripper{ + next: roundTripFunc(func(req *http.Request) (*http.Response, error) { + auth = req.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusNoContent, + Body: http.NoBody, + Header: make(http.Header), + Request: req, + }, nil + }), + tokenSource: staticBearerTokenSource("fresh-token"), + tokenHost: "trusted.example", + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://other.example/path", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer copied-token") + + resp, err := tripper.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + + if auth != "" { + t.Fatalf("untrusted host Authorization header = %q, want empty", auth) + } +} + +func staticBearerTokenSource(accessToken string) oauth2.TokenSource { + return oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: accessToken, + }) +} + +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}