@@ -10,129 +10,103 @@ import (
1010 "net/http/httptest"
1111 "net/url"
1212 "strings"
13+ "sync"
1314 "testing"
1415 "time"
1516)
1617
18+ type proxyOpts struct {
19+ useTLS bool
20+ username string
21+ password string
22+ observe func (* http.Request )
23+ }
24+
1725// startProxy starts an HTTP or HTTPS CONNECT proxy on a random port.
18- // It returns the proxy URL and a channel that receives the protocol observed by
19- // the proxy handler for each CONNECT request .
20- func startProxy (t * testing.T , useTLS bool ) ( proxyURL * url.URL , obsCh <- chan string ) {
26+ // If opts.observe is set, it is called for each CONNECT request.
27+ // If opts.username is set, Proxy-Authorization is required .
28+ func startProxy (t * testing.T , opts proxyOpts ) * url.URL {
2129 t .Helper ()
2230
23- ch := make (chan string , 10 )
24-
2531 srv := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
26- select {
27- case ch <- r .Proto :
28- default :
32+ if opts .observe != nil {
33+ opts .observe (r )
2934 }
3035
3136 if r .Method != http .MethodConnect {
3237 http .Error (w , "expected CONNECT" , http .StatusMethodNotAllowed )
3338 return
3439 }
3540
36- destConn , err := net .DialTimeout ("tcp" , r .Host , 10 * time .Second )
37- if err != nil {
38- http .Error (w , err .Error (), http .StatusBadGateway )
39- return
40- }
41- defer destConn .Close ()
42-
43- hijacker , ok := w .(http.Hijacker )
44- if ! ok {
45- http .Error (w , "hijacking not supported" , http .StatusInternalServerError )
46- return
41+ if opts .username != "" {
42+ wantAuth := "Basic " + base64 .StdEncoding .EncodeToString ([]byte (opts .username + ":" + opts .password ))
43+ if r .Header .Get ("Proxy-Authorization" ) != wantAuth {
44+ http .Error (w , "proxy auth required" , http .StatusProxyAuthRequired )
45+ return
46+ }
4747 }
4848
49- w .WriteHeader (http .StatusOK )
50- clientConn , bufrw , err := hijacker .Hijack ()
51- if err != nil {
52- return
53- }
54- defer clientConn .Close ()
55-
56- done := make (chan struct {}, 2 )
57- // Read from bufrw (not clientConn) so any bytes already buffered
58- // by the server's bufio.Reader are forwarded to the destination.
59- go func () { io .Copy (destConn , bufrw ); done <- struct {}{} }()
60- go func () { io .Copy (clientConn , destConn ); done <- struct {}{} }()
61- <- done
62- // Close both sides so the remaining goroutine unblocks.
63- clientConn .Close ()
64- destConn .Close ()
65- <- done
49+ serveTunnel (w , r )
6650 }))
6751
68- if useTLS {
52+ if opts . useTLS {
6953 srv .StartTLS ()
7054 } else {
7155 srv .Start ()
7256 }
7357 t .Cleanup (srv .Close )
7458
7559 pURL , _ := url .Parse (srv .URL )
76- return pURL , ch
60+ if opts .username != "" {
61+ pURL .User = url .UserPassword (opts .username , opts .password )
62+ }
63+ return pURL
7764}
7865
79- // startProxyWithAuth is like startProxy but requires
80- // Proxy-Authorization with the given username and password.
81- func startProxyWithAuth (t * testing.T , useTLS bool , wantUser , wantPass string ) (proxyURL * url.URL ) {
82- t .Helper ()
83-
84- srv := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
85- if r .Method != http .MethodConnect {
86- http .Error (w , "expected CONNECT" , http .StatusMethodNotAllowed )
87- return
88- }
89-
90- authHeader := r .Header .Get ("Proxy-Authorization" )
91- wantAuth := "Basic " + base64 .StdEncoding .EncodeToString ([]byte (wantUser + ":" + wantPass ))
92- if authHeader != wantAuth {
93- http .Error (w , "proxy auth required" , http .StatusProxyAuthRequired )
94- return
95- }
96-
97- destConn , err := net .DialTimeout ("tcp" , r .Host , 10 * time .Second )
98- if err != nil {
99- http .Error (w , err .Error (), http .StatusBadGateway )
100- return
101- }
102- defer destConn .Close ()
66+ // serveTunnel implements the CONNECT tunnel: dials the target, hijacks the
67+ // client connection, and copies bytes bidirectionally.
68+ func serveTunnel (w http.ResponseWriter , r * http.Request ) {
69+ destConn , err := net .DialTimeout ("tcp" , r .Host , 10 * time .Second )
70+ if err != nil {
71+ http .Error (w , err .Error (), http .StatusBadGateway )
72+ return
73+ }
74+ defer destConn .Close ()
10375
104- hijacker , ok := w .(http.Hijacker )
105- if ! ok {
106- http .Error (w , "hijacking not supported" , http .StatusInternalServerError )
107- return
108- }
76+ hijacker , ok := w .(http.Hijacker )
77+ if ! ok {
78+ http .Error (w , "hijacking not supported" , http .StatusInternalServerError )
79+ return
80+ }
10981
110- w .WriteHeader (http .StatusOK )
111- clientConn , bufrw , err := hijacker .Hijack ()
112- if err != nil {
113- return
114- }
115- defer clientConn .Close ()
82+ w .WriteHeader (http .StatusOK )
83+ clientConn , bufrw , err := hijacker .Hijack ()
84+ if err != nil {
85+ return
86+ }
11687
117- done := make (chan struct {}, 2 )
118- go func () { io .Copy (destConn , bufrw ); done <- struct {}{} }()
119- go func () { io .Copy (clientConn , destConn ); done <- struct {}{} }()
120- <- done
88+ var wg sync.WaitGroup
89+ var once sync.Once
90+ closeBoth := func () {
12191 clientConn .Close ()
12292 destConn .Close ()
123- <- done
124- }))
125-
126- if useTLS {
127- srv .StartTLS ()
128- } else {
129- srv .Start ()
13093 }
131- t . Cleanup ( srv . Close )
94+ defer once . Do ( closeBoth )
13295
133- pURL , _ := url .Parse (srv .URL )
134- pURL .User = url .UserPassword (wantUser , wantPass )
135- return pURL
96+ wg .Add (2 )
97+ // Read from bufrw (not clientConn) so any bytes already buffered
98+ // by the server's bufio.Reader are forwarded to the destination.
99+ go func () {
100+ defer wg .Done ()
101+ io .Copy (destConn , bufrw )
102+ once .Do (closeBoth )
103+ }()
104+ go func () {
105+ defer wg .Done ()
106+ io .Copy (clientConn , destConn )
107+ once .Do (closeBoth )
108+ }()
109+ wg .Wait ()
136110}
137111
138112// newTestTransport creates a base transport suitable for proxy tests.
@@ -157,7 +131,19 @@ func startTargetServer(t *testing.T) *httptest.Server {
157131
158132func TestWithProxyTransport_HTTPProxy (t * testing.T ) {
159133 target := startTargetServer (t )
160- proxyURL , obsCh := startProxy (t , false )
134+
135+ var mu sync.Mutex
136+ var used bool
137+ var proto string
138+
139+ proxyURL := startProxy (t , proxyOpts {
140+ observe : func (r * http.Request ) {
141+ mu .Lock ()
142+ defer mu .Unlock ()
143+ used = true
144+ proto = r .Proto
145+ },
146+ })
161147
162148 transport := withProxyTransport (newTestTransport (), proxyURL , "" )
163149 t .Cleanup (transport .CloseIdleConnections )
@@ -180,19 +166,32 @@ func TestWithProxyTransport_HTTPProxy(t *testing.T) {
180166 t .Errorf ("expected body 'ok', got %q" , got )
181167 }
182168
183- select {
184- case proto := <- obsCh :
185- if proto != "HTTP/1.1" {
186- t .Errorf ("expected proxy to see HTTP/1.1 CONNECT, got %s" , proto )
187- }
188- case <- time .After (2 * time .Second ):
169+ mu .Lock ()
170+ defer mu .Unlock ()
171+ if ! used {
189172 t .Fatal ("proxy handler was never invoked" )
190173 }
174+ if proto != "HTTP/1.1" {
175+ t .Errorf ("expected proxy to see HTTP/1.1 CONNECT, got %s" , proto )
176+ }
191177}
192178
193179func TestWithProxyTransport_HTTPSProxy (t * testing.T ) {
194180 target := startTargetServer (t )
195- proxyURL , obsCh := startProxy (t , true )
181+
182+ var mu sync.Mutex
183+ var used bool
184+ var proto string
185+
186+ proxyURL := startProxy (t , proxyOpts {
187+ useTLS : true ,
188+ observe : func (r * http.Request ) {
189+ mu .Lock ()
190+ defer mu .Unlock ()
191+ used = true
192+ proto = r .Proto
193+ },
194+ })
196195
197196 transport := withProxyTransport (newTestTransport (), proxyURL , "" )
198197 t .Cleanup (transport .CloseIdleConnections )
@@ -215,21 +214,21 @@ func TestWithProxyTransport_HTTPSProxy(t *testing.T) {
215214 t .Errorf ("expected body 'ok', got %q" , got )
216215 }
217216
218- select {
219- case proto := <- obsCh :
220- if proto != "HTTP/1.1" {
221- t .Errorf ("expected proxy to see HTTP/1.1 CONNECT, got %s" , proto )
222- }
223- case <- time .After (2 * time .Second ):
217+ mu .Lock ()
218+ defer mu .Unlock ()
219+ if ! used {
224220 t .Fatal ("proxy handler was never invoked" )
225221 }
222+ if proto != "HTTP/1.1" {
223+ t .Errorf ("expected proxy to see HTTP/1.1 CONNECT, got %s" , proto )
224+ }
226225}
227226
228227func TestWithProxyTransport_ProxyAuth (t * testing.T ) {
229228 target := startTargetServer (t )
230229
231230 t .Run ("http proxy with auth" , func (t * testing.T ) {
232- proxyURL := startProxyWithAuth (t , false , "user" , "pass" )
231+ proxyURL := startProxy (t , proxyOpts { username : "user" , password : "pass" } )
233232 transport := withProxyTransport (newTestTransport (), proxyURL , "" )
234233 t .Cleanup (transport .CloseIdleConnections )
235234 client := & http.Client {Transport : transport , Timeout : 10 * time .Second }
@@ -249,7 +248,7 @@ func TestWithProxyTransport_ProxyAuth(t *testing.T) {
249248 })
250249
251250 t .Run ("https proxy with auth" , func (t * testing.T ) {
252- proxyURL := startProxyWithAuth (t , true , "user" , "s3cret" )
251+ proxyURL := startProxy (t , proxyOpts { useTLS : true , username : "user" , password : "s3cret" } )
253252 transport := withProxyTransport (newTestTransport (), proxyURL , "" )
254253 t .Cleanup (transport .CloseIdleConnections )
255254 client := & http.Client {Transport : transport , Timeout : 10 * time .Second }
@@ -273,7 +272,7 @@ func TestWithProxyTransport_HTTPSProxy_HTTP2ToOrigin(t *testing.T) {
273272 // Verify that when tunneling through an HTTPS proxy, the connection to
274273 // the origin target still negotiates HTTP/2 (not downgraded to HTTP/1.1).
275274 target := startTargetServer (t )
276- proxyURL , _ := startProxy (t , true )
275+ proxyURL := startProxy (t , proxyOpts { useTLS : true } )
277276
278277 transport := withProxyTransport (newTestTransport (), proxyURL , "" )
279278 t .Cleanup (transport .CloseIdleConnections )
@@ -322,7 +321,7 @@ func TestWithProxyTransport_HandshakeFailureClosesConn(t *testing.T) {
322321 close (connClosed )
323322 }()
324323
325- proxyURL , _ := startProxy (t , true )
324+ proxyURL := startProxy (t , proxyOpts { useTLS : true } )
326325 transport := withProxyTransport (newTestTransport (), proxyURL , "" )
327326 t .Cleanup (transport .CloseIdleConnections )
328327 client := & http.Client {Transport : transport , Timeout : 5 * time .Second }
0 commit comments