Skip to content

Commit f6d42a6

Browse files
feat: add middleware support for Flight calls (#196)
1 parent 56e2b13 commit f6d42a6

5 files changed

Lines changed: 120 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## 0.18.0 [unreleased]
44

5+
### Features
6+
7+
1. [#196](https://github.com/InfluxCommunity/influxdb3-python/pull/196): Support passing middleware functions to the Flight client.
8+
59
### Bug Fixes
610

711
1. [#194](https://github.com/InfluxCommunity/influxdb3-python/pull/194): Fix `InfluxDBClient3.write_file()` and `InfluxDBClient3.write_dataframe()` fail with batching mode.

Examples/query_with_middleware.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pyarrow import flight
2+
3+
from config import Config
4+
from influxdb_client_3 import InfluxDBClient3, flight_client_options
5+
6+
7+
# This middleware will add an additional attribute `some-attribute` to the header
8+
class ModifyHeaderClientMiddleware(flight.ClientMiddleware):
9+
def sending_headers(self):
10+
return {
11+
"some-attribute": "some-value",
12+
}
13+
14+
def received_headers(self, headers):
15+
pass
16+
17+
18+
class ModifyHeaderClientMiddlewareFactory(flight.ClientMiddlewareFactory):
19+
def start_call(self, info):
20+
return ModifyHeaderClientMiddleware()
21+
22+
23+
config = Config()
24+
middleware = [ModifyHeaderClientMiddlewareFactory()]
25+
client = InfluxDBClient3(
26+
host=config.host,
27+
token=config.token,
28+
database=config.database,
29+
flight_client_options=flight_client_options(middleware=middleware)
30+
)
31+
32+
df = client.query(query="select * from cpu11 limit 10", mode="pandas")
33+
print(len(df))

influxdb_client_3/query/query_api.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class QueryApiOptions(object):
2020
flight_client_options (dict): base set of flight client options passed to internal pyarrow.flight.FlightClient
2121
timeout(float): timeout in seconds to wait for a response
2222
disable_grpc_compression (bool): disable gRPC compression for query responses
23+
middleware (list): list of middleware functions to be applied to Flight calls
2324
"""
2425
_DEFAULT_TIMEOUT = 300.0
2526
tls_root_certs: bytes = None
@@ -28,13 +29,15 @@ class QueryApiOptions(object):
2829
flight_client_options: dict = None
2930
timeout: float = None
3031
disable_grpc_compression: bool = False
32+
middleware: list = None
3133

3234
def __init__(self, root_certs_path: str,
3335
verify: bool,
3436
proxy: str,
3537
flight_client_options: dict,
3638
timeout: float = _DEFAULT_TIMEOUT,
37-
disable_grpc_compression: bool = False):
39+
disable_grpc_compression: bool = False,
40+
middleware: list = None):
3841
"""
3942
Initialize a set of QueryApiOptions
4043
@@ -45,6 +48,7 @@ def __init__(self, root_certs_path: str,
4548
to be passed to internal pyarrow.flight.FlightClient.
4649
:param timeout: timeout in seconds to wait for a response.
4750
:param disable_grpc_compression: disable gRPC compression for query responses.
51+
:param middleware: list of middleware functions to be applied to Flight calls.
4852
"""
4953
if root_certs_path:
5054
self.tls_root_certs = self._read_certs(root_certs_path)
@@ -53,6 +57,7 @@ def __init__(self, root_certs_path: str,
5357
self.flight_client_options = flight_client_options
5458
self.timeout = timeout
5559
self.disable_grpc_compression = disable_grpc_compression
60+
self.middleware = middleware
5661

5762
def _read_certs(self, path: str) -> bytes:
5863
with open(path, "rb") as certs_file:
@@ -81,6 +86,7 @@ class QueryApiOptionsBuilder(object):
8186
_flight_client_options: dict = None
8287
_timeout: float = None
8388
_disable_grpc_compression: bool = False
89+
_middleware: list = None
8490

8591
def root_certs(self, path: str):
8692
self._root_certs_path = path
@@ -107,6 +113,10 @@ def disable_grpc_compression(self, disable: bool):
107113
self._disable_grpc_compression = disable
108114
return self
109115

116+
def middleware(self, middleware: list):
117+
self._middleware = middleware
118+
return self
119+
110120
def build(self) -> QueryApiOptions:
111121
"""Build a QueryApiOptions object with previously set values"""
112122
return QueryApiOptions(
@@ -116,6 +126,7 @@ def build(self) -> QueryApiOptions:
116126
flight_client_options=self._flight_client_options,
117127
timeout=self._timeout,
118128
disable_grpc_compression=self._disable_grpc_compression,
129+
middleware=self._middleware
119130
)
120131

121132

@@ -181,6 +192,8 @@ def __init__(self,
181192
self._flight_client_options["generic_options"].append(
182193
("grpc.compression_enabled_algorithms_bitset", 1)
183194
)
195+
if options.middleware:
196+
self._flight_client_options["middleware"] = options.middleware
184197
if self._proxy:
185198
self._flight_client_options["generic_options"].append(("grpc.http_proxy", self._proxy))
186199
self._flight_client = FlightClient(connection_string, **self._flight_client_options)

tests/test_query.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Ticket
1313
)
1414

15-
from influxdb_client_3 import InfluxDBClient3
15+
from influxdb_client_3 import InfluxDBClient3, flight_client_options
1616
from influxdb_client_3.query.query_api import QueryApiOptionsBuilder, QueryApi
1717
from influxdb_client_3.version import USER_AGENT
1818
from tests.util import asyncio_run
@@ -25,7 +25,8 @@
2525
HeaderCheckServerMiddlewareFactory,
2626
NoopAuthHandler,
2727
get_req_headers,
28-
set_req_headers
28+
set_req_headers, ModifyHeaderClientMiddlewareFactory,
29+
HeaderCheckServerMiddlewareFactory1
2930
)
3031

3132

@@ -175,11 +176,13 @@ def test_query_client_with_options(self):
175176
cert_chain = 'mTLS_explicit_chain'
176177
self.create_cert_file(cert_file)
177178
test_flight_client_options = {'private_key': private_key, 'cert_chain': cert_chain}
179+
middleware = [ModifyHeaderClientMiddlewareFactory()]
178180
options = QueryApiOptionsBuilder()\
179181
.proxy(proxy_name) \
180182
.root_certs(cert_file) \
181183
.tls_verify(False) \
182184
.flight_client_options(test_flight_client_options) \
185+
.middleware(middleware) \
183186
.build()
184187

185188
client = QueryApi(connection,
@@ -195,6 +198,7 @@ def test_query_client_with_options(self):
195198
assert client._flight_client_options['private_key'] == private_key
196199
assert client._flight_client_options['cert_chain'] == cert_chain
197200
assert client._proxy == proxy_name
201+
assert client._flight_client_options['middleware'] == middleware
198202
fc_opts = client._flight_client_options
199203
assert dict(fc_opts['generic_options'])['grpc.secondary_user_agent'].startswith('influxdb3-python/')
200204
assert dict(fc_opts['generic_options'])['grpc.http_proxy'] == proxy_name
@@ -311,6 +315,41 @@ def test_prepare_query(self):
311315
assert _req_headers['authorization'] == [f"Bearer {token}"]
312316
set_req_headers({})
313317

318+
def test_query_with_middleware_success(self):
319+
with HeaderCheckFlightServer(
320+
auth_handler=NoopAuthHandler(),
321+
middleware={"check": HeaderCheckServerMiddlewareFactory1()}) as server:
322+
323+
middleware = [ModifyHeaderClientMiddlewareFactory()]
324+
client = InfluxDBClient3(
325+
host=f'http://localhost:{server.port}',
326+
org='test_org',
327+
databse='test_db',
328+
token='TEST_TOKEN',
329+
flight_client_options=flight_client_options(middleware=middleware)
330+
)
331+
332+
df = client.query(query='SELECT * FROM test', mode="pandas")
333+
self.assertIsNotNone(df)
334+
335+
def test_query_with_missing_middleware(self):
336+
with HeaderCheckFlightServer(
337+
auth_handler=NoopAuthHandler(),
338+
middleware={"check": HeaderCheckServerMiddlewareFactory1()}) as server:
339+
340+
client = InfluxDBClient3(
341+
host=f'http://localhost:{server.port}',
342+
org='test_org',
343+
databse='test_db',
344+
token='TEST_TOKEN'
345+
)
346+
347+
try:
348+
client.query(query='SELECT * FROM test', mode="pandas")
349+
self.fail("Should have failed due to missing middleware")
350+
except Exception as e:
351+
assert "Invalid header value from middleware" in str(e)
352+
314353
@asyncio_run
315354
async def test_query_async_pandas(self):
316355
with ConstantFlightServer() as server:

tests/util/mocks.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from pyarrow import (
66
array,
77
Table,
8-
concat_tables, ArrowException
8+
concat_tables, ArrowException, flight
99
)
10+
from pyarrow._flight import FlightInternalError
1011
from pyarrow.flight import (
1112
FlightServerBase,
1213
RecordBatchStream,
@@ -159,6 +160,32 @@ def number_batches(table):
159160
yield batch, buf
160161

161162

163+
class ModifyHeaderClientMiddleware(flight.ClientMiddleware):
164+
def sending_headers(self):
165+
return {
166+
"header-from-middleware": "some-value",
167+
}
168+
169+
def received_headers(self, headers):
170+
pass
171+
172+
173+
class ModifyHeaderClientMiddlewareFactory(flight.ClientMiddlewareFactory):
174+
def start_call(self, info):
175+
return ModifyHeaderClientMiddleware()
176+
177+
178+
class HeaderCheckServerMiddlewareFactory1(ServerMiddlewareFactory):
179+
"""Factory to create HeaderCheckServerMiddleware and check header values"""
180+
def start_call(self, info, headers):
181+
values = case_insensitive_header_lookup(headers, "header-from-middleware")
182+
if values is None or values[0] != 'some-value':
183+
raise FlightInternalError("Invalid header value from middleware")
184+
global req_headers
185+
req_headers = headers
186+
return HeaderCheckServerMiddleware('')
187+
188+
162189
class ErrorFlightServer(FlightServerBase):
163190
def do_get(self, context, ticket):
164191
raise ArrowException

0 commit comments

Comments
 (0)