diff --git a/pkg/client/client.go b/pkg/client/client.go index a27d1842..3bca50e0 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -301,7 +301,7 @@ func newClientWithTokenSource( ) (C1Client, error) { uclient, err := uhttp.NewClient( ctx, - uhttp.WithTokenSource(tokenSrc), + uhttp.WithTokenSource(tokenSrc, tokenHost), uhttp.WithDebug(v.GetBool("debug")), uhttp.WithRequestSource(cmdName), ) diff --git a/pkg/uhttp/client.go b/pkg/uhttp/client.go index b70feb51..44502ce4 100644 --- a/pkg/uhttp/client.go +++ b/pkg/uhttp/client.go @@ -43,16 +43,19 @@ func WithLogger(log bool, logger *zap.Logger) Option { type tokenSourceOption struct { tokenSource oauth2.TokenSource + tokenHost string } func (t tokenSourceOption) Apply(c *Transport) { c.tokenSource = t.tokenSource + c.tokenHost = t.tokenHost } // WithTokenSource sets a token source option to the transport layer. -func WithTokenSource(tokenSource oauth2.TokenSource) Option { +func WithTokenSource(tokenSource oauth2.TokenSource, tokenHost string) Option { return tokenSourceOption{ tokenSource: tokenSource, + tokenHost: tokenHost, } } diff --git a/pkg/uhttp/transport.go b/pkg/uhttp/transport.go index 7b95a048..bf39320d 100644 --- a/pkg/uhttp/transport.go +++ b/pkg/uhttp/transport.go @@ -34,6 +34,7 @@ func NewTransport(ctx context.Context, options ...Option) (*Transport, error) { type Transport struct { userAgent string tokenSource oauth2.TokenSource + tokenHost string requestSource string tlsClientConfig *tls.Config roundTripper http.RoundTripper @@ -178,12 +179,17 @@ func (uat *debugTripper) RoundTrip(req *http.Request) (*http.Response, error) { type tokenSourceTripper struct { next http.RoundTripper tokenSource oauth2.TokenSource + tokenHost string } func (uts *tokenSourceTripper) RoundTrip(req *http.Request) (*http.Response, error) { if uts.tokenSource == nil { return uts.next.RoundTrip(req) } + if !strings.EqualFold(req.URL.Host, uts.tokenHost) { + req.Header.Del("Authorization") + return uts.next.RoundTrip(req) + } token, err := uts.tokenSource.Token() if err != nil { return nil, err @@ -216,7 +222,7 @@ func (t *Transport) make(ctx context.Context) (http.RoundTripper, error) { t.userAgent = fmt.Sprintf("%s cone", t.userAgent) rv = &debugTripper{next: rv, debug: t.debug} rv = &userAgentTripper{next: rv, userAgent: t.userAgent} - rv = &tokenSourceTripper{next: rv, tokenSource: t.tokenSource} + rv = &tokenSourceTripper{next: rv, tokenSource: t.tokenSource, tokenHost: t.tokenHost} rv = &requestSourceTripper{next: rv, requestSource: t.requestSource} return rv, nil } diff --git a/pkg/uhttp/transport_test.go b/pkg/uhttp/transport_test.go new file mode 100644 index 00000000..0e4b6cce --- /dev/null +++ b/pkg/uhttp/transport_test.go @@ -0,0 +1,154 @@ +package uhttp + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "golang.org/x/oauth2" +) + +func TestTokenSourceTripperKeepsAuthOnTrustedHostRedirect(t *testing.T) { + const bearerToken = "test-token" + + var startAuth string + var redirectedAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/start": + startAuth = r.Header.Get("Authorization") + http.Redirect(w, r, "/next", http.StatusFound) + case "/next": + redirectedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusNoContent) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + client, err := NewClient( + context.Background(), + WithTokenSource(staticBearerTokenSource(bearerToken), serverURL.Host), + ) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL+"/start", nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + + expectedAuth := "Bearer " + bearerToken + if startAuth != expectedAuth { + t.Fatalf("start Authorization header = %q, want %q", startAuth, expectedAuth) + } + if redirectedAuth != expectedAuth { + t.Fatalf("redirected Authorization header = %q, want %q", redirectedAuth, expectedAuth) + } +} + +func TestTokenSourceTripperSkipsAuthOnCrossHostRedirect(t *testing.T) { + const bearerToken = "test-token" + + var trustedAuth string + var redirectedAuth string + redirectedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + redirectedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusNoContent) + })) + defer redirectedServer.Close() + + trustedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + trustedAuth = r.Header.Get("Authorization") + http.Redirect(w, r, redirectedServer.URL, http.StatusFound) + })) + defer trustedServer.Close() + + trustedURL, err := url.Parse(trustedServer.URL) + if err != nil { + t.Fatal(err) + } + client, err := NewClient( + context.Background(), + WithTokenSource(staticBearerTokenSource(bearerToken), trustedURL.Host), + ) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, trustedServer.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + + expectedAuth := "Bearer " + bearerToken + if trustedAuth != expectedAuth { + t.Fatalf("trusted Authorization header = %q, want %q", trustedAuth, expectedAuth) + } + if redirectedAuth != "" { + t.Fatalf("cross-host redirected Authorization header = %q, want empty", redirectedAuth) + } +} + +func TestTokenSourceTripperRemovesAuthOnUntrustedHost(t *testing.T) { + var auth string + tripper := &tokenSourceTripper{ + next: roundTripFunc(func(req *http.Request) (*http.Response, error) { + auth = req.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusNoContent, + Body: http.NoBody, + Header: make(http.Header), + Request: req, + }, nil + }), + tokenSource: staticBearerTokenSource("fresh-token"), + tokenHost: "trusted.example", + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://other.example/path", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer copied-token") + + resp, err := tripper.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + + if auth != "" { + t.Fatalf("untrusted host Authorization header = %q, want empty", auth) + } +} + +func staticBearerTokenSource(accessToken string) oauth2.TokenSource { + return oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: accessToken, + }) +} + +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}