|
| 1 | +from unittest.mock import patch |
| 2 | + |
| 3 | +from durabletask.internal.shared import (DefaultClientInterceptorImpl, |
| 4 | + get_default_host_address, |
| 5 | + get_grpc_channel) |
| 6 | + |
| 7 | +HOST_ADDRESS = 'localhost:50051' |
| 8 | +METADATA = [('key1', 'value1'), ('key2', 'value2')] |
| 9 | + |
| 10 | + |
| 11 | +def test_get_grpc_channel_insecure(): |
| 12 | + with patch('grpc.insecure_channel') as mock_channel: |
| 13 | + get_grpc_channel(HOST_ADDRESS, METADATA, False) |
| 14 | + mock_channel.assert_called_once_with(HOST_ADDRESS) |
| 15 | + |
| 16 | + |
| 17 | +def test_get_grpc_channel_secure(): |
| 18 | + with patch('grpc.secure_channel') as mock_channel, patch( |
| 19 | + 'grpc.ssl_channel_credentials') as mock_credentials: |
| 20 | + get_grpc_channel(HOST_ADDRESS, METADATA, True) |
| 21 | + mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) |
| 22 | + |
| 23 | + |
| 24 | +def test_get_grpc_channel_default_host_address(): |
| 25 | + with patch('grpc.insecure_channel') as mock_channel: |
| 26 | + get_grpc_channel(None, METADATA, False) |
| 27 | + mock_channel.assert_called_once_with(get_default_host_address()) |
| 28 | + |
| 29 | + |
| 30 | +def test_get_grpc_channel_with_metadata(): |
| 31 | + with patch('grpc.insecure_channel') as mock_channel, patch( |
| 32 | + 'grpc.intercept_channel') as mock_intercept_channel: |
| 33 | + get_grpc_channel(HOST_ADDRESS, METADATA, False) |
| 34 | + mock_channel.assert_called_once_with(HOST_ADDRESS) |
| 35 | + mock_intercept_channel.assert_called_once() |
| 36 | + |
| 37 | + # Capture and check the arguments passed to intercept_channel() |
| 38 | + args, kwargs = mock_intercept_channel.call_args |
| 39 | + assert args[0] == mock_channel.return_value |
| 40 | + assert isinstance(args[1], DefaultClientInterceptorImpl) |
| 41 | + assert args[1]._metadata == METADATA |
0 commit comments