Skip to content

Commit 4046191

Browse files
DeepanshuAcgillum
andauthored
Add gRPC metadata option (#16)
Signed-off-by: Deepanshu Agarwal <deepanshu.agarwal1984@gmail.com> Co-authored-by: Chris Gillum <cgillum@gmail.com>
1 parent c2bca71 commit 4046191

5 files changed

Lines changed: 100 additions & 10 deletions

File tree

CHANGELOG.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,21 @@
1-
## v0.1.0
1+
## Unreleased
2+
3+
### New
4+
5+
- Add gRPC metadata option ([#16](https://github.com/microsoft/durabletask-python/pull/16)) - contributed by [@DeepanshuA](https://github.com/DeepanshuA)
6+
7+
### Changes
8+
9+
- Removed Python 3.7 support due to EOL ([#14](https://github.com/microsoft/durabletask-python/pull/14)) - contributed by [@berndverst](https://github.com/berndverst)
10+
11+
## v0.1.0b
12+
13+
### New
14+
15+
- Continue-as-new ([#9](https://github.com/microsoft/durabletask-python/pull/9))
16+
- Support for Python 3.7+ ([#10](https://github.com/microsoft/durabletask-python/pull/10)) - contributed by [@DeepanshuA](https://github.com/DeepanshuA)
17+
18+
## v0.1.0a
219

320
Initial release, which includes the following features:
421

durabletask/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from datetime import datetime
88
from enum import Enum
9-
from typing import Any, TypeVar, Union
9+
from typing import Any, List, Tuple, TypeVar, Union
1010

1111
import grpc
1212
from google.protobuf import wrappers_pb2
@@ -93,9 +93,10 @@ class TaskHubGrpcClient:
9393

9494
def __init__(self, *,
9595
host_address: Union[str, None] = None,
96-
log_handler=None,
96+
metadata: Union[List[Tuple[str, str]], None] = None,
97+
log_handler = None,
9798
log_formatter: Union[logging.Formatter, None] = None):
98-
channel = shared.get_grpc_channel(host_address)
99+
channel = shared.get_grpc_channel(host_address, metadata)
99100
self._stub = stubs.TaskHubSidecarServiceStub(channel)
100101
self._logger = shared.get_logger("client", log_handler, log_formatter)
101102

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from collections import namedtuple
5+
from typing import List, Tuple
6+
7+
import grpc
8+
9+
10+
class _ClientCallDetails(
11+
namedtuple(
12+
'_ClientCallDetails',
13+
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
14+
grpc.ClientCallDetails):
15+
"""This is an implementation of the ClientCallDetails interface needed for interceptors.
16+
This class takes six named values and inherits the ClientCallDetails from grpc package.
17+
This class encloses the values that describe a RPC to be invoked.
18+
"""
19+
pass
20+
21+
22+
class DefaultClientInterceptorImpl (
23+
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
24+
grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
25+
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
26+
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
27+
interceptor to add additional headers to all calls as needed."""
28+
29+
def __init__(self, metadata: List[Tuple[str, str]]):
30+
super().__init__()
31+
self._metadata = metadata
32+
33+
def _intercept_call(
34+
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
35+
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
36+
call details."""
37+
if self._metadata is None:
38+
return client_call_details
39+
40+
if client_call_details.metadata is not None:
41+
metadata = list(client_call_details.metadata)
42+
else:
43+
metadata = []
44+
45+
metadata.extend(self._metadata)
46+
client_call_details = _ClientCallDetails(
47+
client_call_details.method, client_call_details.timeout, metadata,
48+
client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression)
49+
50+
return client_call_details
51+
52+
def intercept_unary_unary(self, continuation, client_call_details, request):
53+
new_client_call_details = self._intercept_call(client_call_details)
54+
return continuation(new_client_call_details, request)
55+
56+
def intercept_unary_stream(self, continuation, client_call_details, request):
57+
new_client_call_details = self._intercept_call(client_call_details)
58+
return continuation(new_client_call_details, request)
59+
60+
def intercept_stream_unary(self, continuation, client_call_details, request):
61+
new_client_call_details = self._intercept_call(client_call_details)
62+
return continuation(new_client_call_details, request)
63+
64+
def intercept_stream_stream(self, continuation, client_call_details, request):
65+
new_client_call_details = self._intercept_call(client_call_details)
66+
return continuation(new_client_call_details, request)

durabletask/internal/shared.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import json
66
import logging
77
from types import SimpleNamespace
8-
from typing import Any, Dict, Union
8+
from typing import Any, Dict, List, Tuple, Union
99

1010
import grpc
1111

12+
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
13+
1214
# Field name used to indicate that an object was automatically serialized
1315
# and should be deserialized as a SimpleNamespace
1416
AUTO_SERIALIZED = "__durabletask_autoobject__"
@@ -18,13 +20,15 @@ def get_default_host_address() -> str:
1820
return "localhost:4001"
1921

2022

21-
def get_grpc_channel(host_address: Union[str, None]) -> grpc.Channel:
23+
def get_grpc_channel(host_address: Union[str, None], metadata: Union[List[Tuple[str, str]], None]) -> grpc.Channel:
2224
if host_address is None:
2325
host_address = get_default_host_address()
2426
channel = grpc.insecure_channel(host_address)
27+
if metadata is not None and len(metadata) > 0:
28+
interceptors = [DefaultClientInterceptorImpl(metadata)]
29+
channel = grpc.intercept_channel(channel, *interceptors)
2530
return channel
2631

27-
2832
def get_logger(
2933
name_suffix: str,
3034
log_handler: Union[logging.Handler, None] = None,

durabletask/worker.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import datetime, timedelta
77
from threading import Event, Thread
88
from types import GeneratorType
9-
from typing import Any, Dict, Generator, List, Sequence, TypeVar, Union
9+
from typing import Any, Dict, Generator, List, Sequence, Tuple, TypeVar, Union
1010

1111
import grpc
1212
from google.protobuf import empty_pb2
@@ -85,10 +85,12 @@ class TaskHubGrpcWorker:
8585

8686
def __init__(self, *,
8787
host_address: Union[str, None] = None,
88-
log_handler=None,
88+
metadata: Union[List[Tuple[str, str]], None] = None,
89+
log_handler = None,
8990
log_formatter: Union[logging.Formatter, None] = None):
9091
self._registry = _Registry()
9192
self._host_address = host_address if host_address else shared.get_default_host_address()
93+
self._metadata = metadata
9294
self._logger = shared.get_logger("worker", log_handler, log_formatter)
9395
self._shutdown = Event()
9496
self._response_stream = None
@@ -114,7 +116,7 @@ def add_activity(self, fn: task.Activity) -> str:
114116

115117
def start(self):
116118
"""Starts the worker on a background thread and begins listening for work items."""
117-
channel = shared.get_grpc_channel(self._host_address)
119+
channel = shared.get_grpc_channel(self._host_address, self._metadata)
118120
stub = stubs.TaskHubSidecarServiceStub(channel)
119121

120122
if self._is_running:

0 commit comments

Comments
 (0)