Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func newClientWithTokenSource(
) (C1Client, error) {
uclient, err := uhttp.NewClient(
ctx,
uhttp.WithTokenSource(tokenSrc),
uhttp.WithTokenSource(tokenSrc, tokenHost),
uhttp.WithDebug(v.GetBool("debug")),
uhttp.WithRequestSource(cmdName),
)
Expand Down
5 changes: 4 additions & 1 deletion pkg/uhttp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,19 @@ func WithLogger(log bool, logger *zap.Logger) Option {

type tokenSourceOption struct {
tokenSource oauth2.TokenSource
tokenHost string
}

func (t tokenSourceOption) Apply(c *Transport) {
c.tokenSource = t.tokenSource
c.tokenHost = t.tokenHost
}

// WithTokenSource sets a token source option to the transport layer.
func WithTokenSource(tokenSource oauth2.TokenSource) Option {
func WithTokenSource(tokenSource oauth2.TokenSource, tokenHost string) Option {
return tokenSourceOption{
tokenSource: tokenSource,
tokenHost: tokenHost,
}
}

Expand Down
8 changes: 7 additions & 1 deletion pkg/uhttp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func NewTransport(ctx context.Context, options ...Option) (*Transport, error) {
type Transport struct {
userAgent string
tokenSource oauth2.TokenSource
tokenHost string
requestSource string
tlsClientConfig *tls.Config
roundTripper http.RoundTripper
Expand Down Expand Up @@ -178,12 +179,17 @@ func (uat *debugTripper) RoundTrip(req *http.Request) (*http.Response, error) {
type tokenSourceTripper struct {
next http.RoundTripper
tokenSource oauth2.TokenSource
tokenHost string
}

func (uts *tokenSourceTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if uts.tokenSource == nil {
return uts.next.RoundTrip(req)
}
if !strings.EqualFold(req.URL.Host, uts.tokenHost) {
req.Header.Del("Authorization")
return uts.next.RoundTrip(req)
}
token, err := uts.tokenSource.Token()
if err != nil {
return nil, err
Expand Down Expand Up @@ -216,7 +222,7 @@ func (t *Transport) make(ctx context.Context) (http.RoundTripper, error) {
t.userAgent = fmt.Sprintf("%s cone", t.userAgent)
rv = &debugTripper{next: rv, debug: t.debug}
rv = &userAgentTripper{next: rv, userAgent: t.userAgent}
rv = &tokenSourceTripper{next: rv, tokenSource: t.tokenSource}
rv = &tokenSourceTripper{next: rv, tokenSource: t.tokenSource, tokenHost: t.tokenHost}
rv = &requestSourceTripper{next: rv, requestSource: t.requestSource}
return rv, nil
}
Expand Down
154 changes: 154 additions & 0 deletions pkg/uhttp/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package uhttp

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"golang.org/x/oauth2"
)

func TestTokenSourceTripperKeepsAuthOnTrustedHostRedirect(t *testing.T) {
const bearerToken = "test-token"

var startAuth string
var redirectedAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/start":
startAuth = r.Header.Get("Authorization")
http.Redirect(w, r, "/next", http.StatusFound)
case "/next":
redirectedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusNoContent)
default:
http.NotFound(w, r)
}
}))
defer server.Close()

serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
client, err := NewClient(
context.Background(),
WithTokenSource(staticBearerTokenSource(bearerToken), serverURL.Host),
)
if err != nil {
t.Fatal(err)
}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL+"/start", nil)
if err != nil {
t.Fatal(err)
}
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp.Body.Close() }()

expectedAuth := "Bearer " + bearerToken
if startAuth != expectedAuth {
t.Fatalf("start Authorization header = %q, want %q", startAuth, expectedAuth)
}
if redirectedAuth != expectedAuth {
t.Fatalf("redirected Authorization header = %q, want %q", redirectedAuth, expectedAuth)
}
}

func TestTokenSourceTripperSkipsAuthOnCrossHostRedirect(t *testing.T) {
const bearerToken = "test-token"

var trustedAuth string
var redirectedAuth string
redirectedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusNoContent)
}))
defer redirectedServer.Close()

trustedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
trustedAuth = r.Header.Get("Authorization")
http.Redirect(w, r, redirectedServer.URL, http.StatusFound)
}))
defer trustedServer.Close()

trustedURL, err := url.Parse(trustedServer.URL)
if err != nil {
t.Fatal(err)
}
client, err := NewClient(
context.Background(),
WithTokenSource(staticBearerTokenSource(bearerToken), trustedURL.Host),
)
if err != nil {
t.Fatal(err)
}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, trustedServer.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp.Body.Close() }()

expectedAuth := "Bearer " + bearerToken
if trustedAuth != expectedAuth {
t.Fatalf("trusted Authorization header = %q, want %q", trustedAuth, expectedAuth)
}
if redirectedAuth != "" {
t.Fatalf("cross-host redirected Authorization header = %q, want empty", redirectedAuth)
}
}

func TestTokenSourceTripperRemovesAuthOnUntrustedHost(t *testing.T) {
var auth string
tripper := &tokenSourceTripper{
next: roundTripFunc(func(req *http.Request) (*http.Response, error) {
auth = req.Header.Get("Authorization")
return &http.Response{
StatusCode: http.StatusNoContent,
Body: http.NoBody,
Header: make(http.Header),
Request: req,
}, nil
}),
tokenSource: staticBearerTokenSource("fresh-token"),
tokenHost: "trusted.example",
}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://other.example/path", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer copied-token")

resp, err := tripper.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp.Body.Close() }()

if auth != "" {
t.Fatalf("untrusted host Authorization header = %q, want empty", auth)
}
}

func staticBearerTokenSource(accessToken string) oauth2.TokenSource {
return oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: accessToken,
})
}

type roundTripFunc func(req *http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}