Skip to content

Commit b778865

Browse files
committed
fix(keycardai-mcp): support x-forwarded-port header
1 parent 5d729ae commit b778865

6 files changed

Lines changed: 205 additions & 6 deletions

File tree

packages/mcp/src/keycardai/mcp/server/handlers/metadata.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from keycardai.oauth.types.oauth import GrantType, TokenEndpointAuthMethod
1212

13+
from ..shared.starlette import get_base_url
14+
1315

1416
class InferredProtectedResourceMetadata(ProtectedResourceMetadata):
1517
"""Extended ProtectedResourceMetadata that allows resource to be inferred from request."""
@@ -56,6 +58,7 @@ def _strip_zone_id_from_path(zone_id: str, path: str) -> str:
5658
return path[len(zone_id):]
5759
return path
5860

61+
5962
def _create_resource_url(base_url: str | AnyHttpUrl, path: str) -> AnyHttpUrl:
6063
base_url_str = str(base_url).rstrip("/")
6164
if path and not path.startswith("/"):
@@ -80,13 +83,17 @@ def wrapper(request: Request) -> Response:
8083
# Create a copy of the metadata to avoid mutating the original
8184
request_metadata = metadata.model_copy(deep=True)
8285
path = _remove_well_known_prefix(request.url.path)
86+
87+
# Get proxy-aware base URL for correct scheme handling
88+
base_url = get_base_url(request)
89+
8390
if enable_multi_zone:
8491
zone_id = _get_zone_id_from_path(path)
8592
if zone_id:
8693
request_metadata.authorization_servers = [ _create_zone_scoped_authorization_server_url(zone_id, request_metadata.authorization_servers[0]) ]
8794

88-
request_metadata.resource = _create_resource_url(request.base_url, path)
89-
request_metadata.jwks_uri = _create_jwks_uri(str(request.base_url))
95+
request_metadata.resource = _create_resource_url(base_url, path)
96+
request_metadata.jwks_uri = _create_jwks_uri(base_url)
9097
request_metadata.client_id = str(request_metadata.resource)
9198
request_metadata.client_name = "MCP Server"
9299
request_metadata.token_endpoint_auth_method = TokenEndpointAuthMethod.PRIVATE_KEY_JWT
@@ -96,7 +103,7 @@ def wrapper(request: Request) -> Response:
96103
mcp_version = request.headers.get("mcp-protocol-version")
97104
# TODO: what is the reason for this?
98105
if mcp_version == "2025-03-26":
99-
json["authorization_servers"] = [ request.base_url ]
106+
json["authorization_servers"] = [ base_url ]
100107
return Response(content=request_metadata.model_dump_json(exclude_none=True), status_code=200)
101108
return wrapper
102109

@@ -115,7 +122,8 @@ def wrapper(request: Request) -> Response:
115122
resp = client.get(f"{actual_issuer}/.well-known/oauth-authorization-server")
116123
resp.raise_for_status()
117124
authorization_server_metadata = resp.json()
118-
authorization_server_metadata["authorization_endpoint"] = f"{request.base_url}{authorization_server_metadata['authorization_endpoint']}"
125+
base_url = get_base_url(request)
126+
authorization_server_metadata["authorization_endpoint"] = f"{base_url}{authorization_server_metadata['authorization_endpoint']}"
119127
return Response(content=json.dumps(authorization_server_metadata), status_code=200)
120128
except httpx.HTTPStatusError as e:
121129
# Return the same status code as the upstream server

packages/mcp/src/keycardai/mcp/server/middleware/bearer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from starlette.types import ASGIApp
88

99
from ..auth.verifier import TokenVerifier
10+
from ..shared.starlette import get_base_url
1011

1112

1213
def _get_oauth_protected_resource_url(request: Request) -> str:
1314
path = request.url.path.lstrip("/").rstrip("/")
14-
base_url = str(request.base_url).rstrip("/")
15+
base_url = get_base_url(request)
1516
return str(AnyHttpUrl(f"{base_url}/.well-known/oauth-protected-resource/{path}"))
1617

1718
def _get_bearer_token(request: Request) -> str | None:
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""Shared utilities for Starlette/FastAPI applications."""
2+
3+
from pydantic import AnyHttpUrl
4+
from starlette.requests import Request
5+
6+
"""Supported protocols for the base URL."""
7+
SUPPORTED_PROTOCOLS = ["http", "https"]
8+
9+
10+
def get_base_url(request: Request) -> str:
11+
"""Get the correct base URL considering proxy headers like X-Forwarded-Proto."""
12+
request_base_url = AnyHttpUrl(str(request.base_url))
13+
proto = request.headers.get("x-forwarded-proto") or request_base_url.scheme
14+
if proto not in SUPPORTED_PROTOCOLS:
15+
proto = "https"
16+
17+
if request_base_url.port not in [443, 80]:
18+
base_url = f"{proto}://{request_base_url.host}:{request_base_url.port}"
19+
else:
20+
base_url = f"{proto}://{request_base_url.host}"
21+
22+
return base_url

packages/mcp/tests/keycardai/mcp/server/handlers/test_metadata.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import pytest
77
from pydantic import AnyHttpUrl
8+
from starlette.datastructures import URL
9+
from starlette.requests import Request
810

911
from keycardai.mcp.server.handlers.metadata import (
1012
_create_resource_url,
@@ -14,6 +16,7 @@
1416
_remove_well_known_prefix,
1517
_strip_zone_id_from_path,
1618
)
19+
from keycardai.mcp.server.shared.starlette import get_base_url
1720

1821

1922
class TestIsAuthorizationServerZoneScoped:
@@ -399,5 +402,86 @@ def test_case_sensitivity(self):
399402
assert result == "zone123/api/v1"
400403

401404

405+
class TestGetBaseUrl:
406+
"""Test get_base_url function."""
407+
408+
def _create_mock_request(self, base_url: str, headers: dict[str, str] | None = None) -> Request:
409+
"""Create a mock request with specified base URL and headers."""
410+
if headers is None:
411+
headers = {}
412+
413+
parsed_url = URL(base_url)
414+
# Create a minimal ASGI scope for testing
415+
scope = {
416+
"type": "http",
417+
"method": "GET",
418+
"scheme": parsed_url.scheme,
419+
"server": (parsed_url.hostname, parsed_url.port or (443 if parsed_url.scheme == "https" else 80)),
420+
"path": parsed_url.path or "/",
421+
"query_string": b"",
422+
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
423+
}
424+
return Request(scope)
425+
426+
def test_no_proxy_headers(self):
427+
"""Test with no proxy headers - should return original base URL."""
428+
request = self._create_mock_request("http://example.com")
429+
result = get_base_url(request)
430+
assert result == "http://example.com"
431+
432+
def test_with_x_forwarded_proto_https(self):
433+
"""Test with X-Forwarded-Proto header indicating HTTPS."""
434+
headers = {"x-forwarded-proto": "https"}
435+
request = self._create_mock_request("http://example.com", headers)
436+
result = get_base_url(request)
437+
assert result == "https://example.com"
438+
439+
def test_with_x_forwarded_proto_http(self):
440+
"""Test with X-Forwarded-Proto header indicating HTTP."""
441+
headers = {"x-forwarded-proto": "http"}
442+
request = self._create_mock_request("https://example.com", headers)
443+
result = get_base_url(request)
444+
assert result == "http://example.com"
445+
446+
def test_with_port_number(self):
447+
"""Test with port number in base URL."""
448+
headers = {"x-forwarded-proto": "https"}
449+
request = self._create_mock_request("http://example.com:8080", headers)
450+
result = get_base_url(request)
451+
assert result == "https://example.com:8080"
452+
453+
def test_with_path_in_base_url(self):
454+
"""Test with path in base URL - path should be ignored for base URL."""
455+
headers = {"x-forwarded-proto": "https"}
456+
request = self._create_mock_request("http://example.com/api/v1", headers)
457+
result = get_base_url(request)
458+
assert result == "https://example.com"
459+
460+
def test_case_insensitive_header(self):
461+
"""Test that header matching is case insensitive (Starlette handles this)."""
462+
headers = {"X-Forwarded-Proto": "https"}
463+
request = self._create_mock_request("http://example.com", headers)
464+
result = get_base_url(request)
465+
assert result == "https://example.com"
466+
467+
def test_trailing_slash_handling(self):
468+
"""Test that trailing slashes are properly handled."""
469+
headers = {"x-forwarded-proto": "https"}
470+
request = self._create_mock_request("http://example.com/", headers)
471+
result = get_base_url(request)
472+
assert result == "https://example.com"
473+
474+
def test_aws_app_runner_scenario(self):
475+
"""Test the specific AWS App Runner scenario from the issue."""
476+
headers = {
477+
"host": "ppxrhd2bw4.us-east-1.awsapprunner.com",
478+
"x-forwarded-proto": "https",
479+
"x-forwarded-for": "92.238.31.228"
480+
}
481+
request = self._create_mock_request("http://ppxrhd2bw4.us-east-1.awsapprunner.com", headers)
482+
result = get_base_url(request)
483+
assert result == "https://ppxrhd2bw4.us-east-1.awsapprunner.com"
484+
485+
402486
if __name__ == "__main__":
403487
pytest.main([__file__])

packages/mcp/tests/keycardai/mcp/server/middleware/test_bearer.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,51 @@ def test_case_sensitivity_of_auth_scheme(self):
246246

247247
result = _get_bearer_token(request)
248248
assert result == expected_token, f"Failed for auth scheme: {auth_scheme}"
249+
250+
251+
class TestGetBaseUrlMiddleware:
252+
"""Tests for get_base_url function in middleware."""
253+
254+
def _create_mock_request(self, base_url: str, path: str = "/", headers: dict[str, str] | None = None) -> Request:
255+
"""Create a mock request with specified base URL, path, and headers."""
256+
if headers is None:
257+
headers = {}
258+
259+
# Create a minimal ASGI scope for testing
260+
scope = {
261+
"type": "http",
262+
"method": "GET",
263+
"scheme": URL(base_url).scheme,
264+
"server": (URL(base_url).hostname, URL(base_url).port or (443 if URL(base_url).scheme == "https" else 80)),
265+
"path": path,
266+
"query_string": b"",
267+
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
268+
}
269+
return Request(scope)
270+
271+
def test_proxy_aware_url_in_oauth_resource_url(self):
272+
"""Test that _get_oauth_protected_resource_url uses proxy-aware base URL."""
273+
headers = {"x-forwarded-proto": "https"}
274+
request = self._create_mock_request("http://example.com", "/api/resource", headers)
275+
276+
result = _get_oauth_protected_resource_url(request)
277+
assert result == "https://example.com/.well-known/oauth-protected-resource/api/resource"
278+
279+
def test_no_proxy_headers_in_oauth_resource_url(self):
280+
"""Test _get_oauth_protected_resource_url without proxy headers."""
281+
request = self._create_mock_request("http://example.com", "/api/resource")
282+
283+
result = _get_oauth_protected_resource_url(request)
284+
assert result == "http://example.com/.well-known/oauth-protected-resource/api/resource"
285+
286+
def test_aws_app_runner_scenario_in_middleware(self):
287+
"""Test the AWS App Runner scenario in middleware context."""
288+
headers = {
289+
"host": "ppxrhd2bw4.us-east-1.awsapprunner.com",
290+
"x-forwarded-proto": "https",
291+
"x-forwarded-for": "92.238.31.228"
292+
}
293+
request = self._create_mock_request("http://ppxrhd2bw4.us-east-1.awsapprunner.com", "/zone123/api", headers)
294+
295+
result = _get_oauth_protected_resource_url(request)
296+
assert result == "https://ppxrhd2bw4.us-east-1.awsapprunner.com/.well-known/oauth-protected-resource/zone123/api"

uv.lock

Lines changed: 37 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)