Skip to content

Commit b8cabea

Browse files
committed
Added new class for entity metadata
1 parent 493c073 commit b8cabea

2 files changed

Lines changed: 60 additions & 12 deletions

File tree

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Optional
2+
3+
4+
class EntityMetadata(object):
5+
6+
def __init__(self, metadata: Optional[dict] = None):
7+
if metadata is None:
8+
self.metadata = {}
9+
else:
10+
self.metadata = metadata
11+
12+
def __getitem__(self, item):
13+
for _type, _dict in self.metadata.items():
14+
_val = _dict.get(item, None)
15+
if _val:
16+
return _val
17+
raise KeyError(item)
18+
19+
def __setitem__(self, key, value):
20+
# Assumes not multiple entity types
21+
self.metadata[key] = value
22+
23+
def entity_types(self):
24+
return list(self.metadata.keys())
25+
26+
def entity_type(self, etype):
27+
return self.metadata[etype]
28+
29+
def items(self):
30+
return self.metadata.items()
31+
32+
def get_entity_type_claim(self, entity_type, claim):
33+
return self.metadata[entity_type][claim]
34+
35+
def __contains__(self, item):
36+
return item in self.metadata

src/idpyoidc/client/service_context.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import hashlib
66
import logging
77
from typing import Callable
8+
from typing import List
89
from typing import Optional
910
from typing import Union
1011

11-
from cryptojwt.jwk.rsa import RSAKey
1212
from cryptojwt.jwk.rsa import import_private_rsa_key_from_file
13+
from cryptojwt.jwk.rsa import RSAKey
1314
from cryptojwt.key_bundle import KeyBundle
1415
from cryptojwt.key_jar import KeyJar
1516
from cryptojwt.utils import as_bytes
@@ -22,12 +23,12 @@
2223
from idpyoidc.client.claims.oidc import Claims as OIDC_Specs
2324
from idpyoidc.client.configure import Configuration
2425
from idpyoidc.util import rndstr
25-
26-
from ..impexp import ImpExp
2726
from .claims.transform import preferred_to_registered
2827
from .claims.transform import supported_to_preferred
2928
from .configure import get_configuration
3029
from .current import Current
30+
from .entity_metadata import EntityMetadata
31+
from ..impexp import ImpExp
3132

3233
logger = logging.getLogger(__name__)
3334

@@ -99,6 +100,7 @@ class ServiceContext(ImpExp):
99100
"httpc_params": None,
100101
"iss_hash": None,
101102
"issuer": None,
103+
"server_metadata": EntityMetadata,
102104
"keyjar": KeyJar,
103105
"claims": Claims,
104106
"provider_info": None,
@@ -116,14 +118,14 @@ class ServiceContext(ImpExp):
116118
init_args = ["upstream_get"]
117119

118120
def __init__(
119-
self,
120-
upstream_get: Optional[Callable] = None,
121-
base_url: Optional[str] = "",
122-
keyjar: Optional[KeyJar] = None,
123-
config: Optional[Union[dict, Configuration]] = None,
124-
cstate: Optional[Current] = None,
125-
client_type: Optional[str] = "oauth2",
126-
**kwargs,
121+
self,
122+
upstream_get: Optional[Callable] = None,
123+
base_url: Optional[str] = "",
124+
keyjar: Optional[KeyJar] = None,
125+
config: Optional[Union[dict, Configuration]] = None,
126+
cstate: Optional[Current] = None,
127+
client_type: Optional[str] = "oauth2",
128+
**kwargs,
127129
):
128130
ImpExp.__init__(self)
129131
config = get_configuration(config)
@@ -150,6 +152,7 @@ def __init__(
150152
self.allow = config.conf.get("allow", {})
151153
self.base_url = base_url or config.conf.get("base_url", self.entity_id)
152154
self.provider_info = config.conf.get("provider_info", {})
155+
self.server_metadata = config.conf.get("server_metadata", EntityMetadata())
153156

154157
# Below so my IDE won't complain
155158
self.args = {}
@@ -212,7 +215,7 @@ def filename_from_webname(self, webname):
212215
if not webname.startswith(self.base_url):
213216
raise ValueError("Webname doesn't match base_url")
214217

215-
_name = webname[len(self.base_url) :]
218+
_name = webname[len(self.base_url):]
216219
if _name.startswith("/"):
217220
return _name[1:]
218221

@@ -415,3 +418,12 @@ def map_preferred_to_registered(self, registration_response: Optional[dict] = No
415418
)
416419

417420
return self.claims.use
421+
422+
def get_metadata_claim(self, claim, entity_type: Optional[List[str]] = ""):
423+
if entity_type:
424+
for _type in entity_type:
425+
if _type in self.server_metadata:
426+
return self.server_metadata[_type][claim]
427+
return KeyError(f"{claim} not in {entity_type} metadata")
428+
else:
429+
return self.provider_info[claim]

0 commit comments

Comments
 (0)