Skip to content

Commit 9526566

Browse files
committed
fix(keycardai-oauth): update test to expect OAuthProtocolError for structured error bodies
1 parent 4184da7 commit 9526566

2 files changed

Lines changed: 29 additions & 15 deletions

File tree

packages/mcp/tests/keycardai/mcp/server/auth/test_provider.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ async def test_function_called_without_context_value(self, auth_provider_config,
122122
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
123123
if access_ctx.has_error():
124124
error = access_ctx.get_error()
125-
return {"error": error["error"]}
125+
return {"error": error["message"]}
126126
return {"success": True}
127127

128128
# Call without providing ctx parameter - should cause TypeError due to missing required argument
@@ -141,7 +141,7 @@ async def test_function_called_with_context_as_none(self, auth_provider_config,
141141
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
142142
if access_ctx.has_error():
143143
error = access_ctx.get_error()
144-
return {"error": error["error"]}
144+
return {"error": error["message"]}
145145
return {"success": True}
146146

147147
# Call with ctx=None - should cause error
@@ -161,7 +161,7 @@ async def test_function_called_with_context_via_positional_args(self, auth_provi
161161
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
162162
if access_ctx.has_error():
163163
error = access_ctx.get_error()
164-
return {"error": error["error"]}
164+
return {"error": error["message"]}
165165
return {"success": True, "user_id": user_id}
166166

167167
mock_context = self.create_mock_context_with_auth()
@@ -185,7 +185,7 @@ async def test_function_called_with_context_via_kwargs(self, auth_provider_confi
185185
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
186186
if access_ctx.has_error():
187187
error = access_ctx.get_error()
188-
return {"error": error["error"]}
188+
return {"error": error["message"]}
189189
return {"success": True, "user_id": user_id}
190190

191191
mock_context = self.create_mock_context_with_auth()
@@ -211,7 +211,7 @@ def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict
211211
assert isinstance(access_ctx, AccessContext)
212212
if access_ctx.has_error():
213213
error = access_ctx.get_error()
214-
return {"error": error["error"]}
214+
return {"error": error["message"]}
215215
return {"success": True, "has_access_ctx": True}
216216

217217
mock_context = self.create_mock_context_with_auth()
@@ -235,7 +235,7 @@ async def test_function_called_with_access_context_via_positional_args(self, aut
235235
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
236236
if access_ctx.has_error():
237237
error = access_ctx.get_error()
238-
return {"error": error["error"]}
238+
return {"error": error["message"]}
239239
return {"success": True, "user_id": user_id}
240240

241241
mock_context = self.create_mock_context_with_auth()
@@ -260,7 +260,7 @@ async def test_function_called_with_access_context_via_kwargs(self, auth_provide
260260
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
261261
if access_ctx.has_error():
262262
error = access_ctx.get_error()
263-
return {"error": error["error"]}
263+
return {"error": error["message"]}
264264
return {"success": True, "user_id": user_id}
265265

266266
mock_context = self.create_mock_context_with_auth()
@@ -306,7 +306,7 @@ async def test_context_extraction_from_fastmcp_context(self, auth_provider_confi
306306
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
307307
if access_ctx.has_error():
308308
error = access_ctx.get_error()
309-
return {"error": error["error"]}
309+
return {"error": error["message"]}
310310
return {"success": True}
311311

312312
# Create mock Context with request_context
@@ -330,7 +330,7 @@ async def test_context_extraction_from_request_context_directly(self, auth_provi
330330
def test_function(access_ctx: AccessContext, ctx: RequestContext, user_id: str) -> dict:
331331
if access_ctx.has_error():
332332
error = access_ctx.get_error()
333-
return {"error": error["error"]}
333+
return {"error": error["message"]}
334334
return {"success": True}
335335

336336
mock_request_context = self.create_mock_request_context_with_auth()
@@ -351,7 +351,7 @@ async def test_missing_auth_info_in_context(self, auth_provider_config, mock_cli
351351
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
352352
if access_ctx.has_error():
353353
error = access_ctx.get_error()
354-
return {"error": error["error"]}
354+
return {"error": error["message"]}
355355
return {"success": True}
356356

357357
# Create mock Context without auth info
@@ -435,7 +435,7 @@ async def test_parameter_order_with_positional_args(self, auth_provider_config,
435435
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str, extra_param: str = "default") -> dict:
436436
if access_ctx.has_error():
437437
error = access_ctx.get_error()
438-
return {"error": error["error"]}
438+
return {"error": error["message"]}
439439
return {
440440
"success": True,
441441
"user_id": user_id,
@@ -466,7 +466,7 @@ async def test_mixed_args_and_kwargs(self, auth_provider_config, mock_client_fac
466466
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str, extra_param: str = "default") -> dict:
467467
if access_ctx.has_error():
468468
error = access_ctx.get_error()
469-
return {"error": error["error"]}
469+
return {"error": error["message"]}
470470
return {
471471
"success": True,
472472
"user_id": user_id,
@@ -574,7 +574,7 @@ async def test_multiple_resources_token_exchange(self, auth_provider_config, moc
574574
def test_function(access_ctx: AccessContext, ctx: Context, user_id: str) -> dict:
575575
if access_ctx.has_error():
576576
error = access_ctx.get_error()
577-
return {"error": error["error"]}
577+
return {"error": error["message"]}
578578

579579
# Try to access both resources
580580
try:

packages/oauth/tests/keycardai/oauth/operations/test_token_exchange.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,28 @@ def test_parse_token_exchange_http_response_success(self):
9191
assert " ".join(result.scope) == "read write" # scope is parsed as a list
9292

9393
def test_parse_token_exchange_http_response_http_error(self):
94-
"""Test parsing HTTP error response."""
94+
"""Test parsing HTTP error response with structured OAuth error body."""
9595
http_response = HttpResponse(
9696
status=400,
9797
headers={"Content-Type": "application/json"},
9898
body=b'{"error": "invalid_request", "error_description": "Invalid subject_token"}'
9999
)
100100

101-
with pytest.raises(OAuthHttpError, match="HTTP 400"):
101+
with pytest.raises(OAuthProtocolError, match="invalid_request") as exc_info:
102+
parse_token_exchange_http_response(http_response)
103+
104+
assert exc_info.value.error == "invalid_request"
105+
assert exc_info.value.error_description == "Invalid subject_token"
106+
107+
def test_parse_token_exchange_http_response_http_error_non_json(self):
108+
"""Test parsing HTTP error response with non-JSON body."""
109+
http_response = HttpResponse(
110+
status=500,
111+
headers={"Content-Type": "text/plain"},
112+
body=b"Internal Server Error"
113+
)
114+
115+
with pytest.raises(OAuthHttpError, match="HTTP 500"):
102116
parse_token_exchange_http_response(http_response)
103117

104118
def test_parse_token_exchange_http_response_invalid_json(self):

0 commit comments

Comments
 (0)