-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathflight_client.rs
More file actions
182 lines (149 loc) · 5.19 KB
/
flight_client.rs
File metadata and controls
182 lines (149 loc) · 5.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use std::{sync::Arc, time::Duration};
use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
use arrow_cast::{cast_with_options, CastOptions};
use arrow_flight::{sql::client::FlightSqlServiceClient, FlightInfo};
use arrow_schema::{ArrowError, Schema};
use futures::TryStreamExt;
use napi_derive::napi;
use snafu::ResultExt;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tracing_log::log::{debug, info};
use crate::error::{ArrowSnafu, FlightSnafu, Result};
/// A single extra header attached to the Flight SQL client, given as a
/// separate key and value (not a ':'-joined string).
///
/// Exposed to JavaScript as a plain object via `#[napi(object)]`.
#[derive(Debug, Clone)]
#[napi(object)]
pub struct KeyValue {
    /// Header name, passed as the key to `FlightSqlServiceClient::set_header`.
    pub key: String,
    /// Header value, passed as the value to `FlightSqlServiceClient::set_header`.
    pub value: String,
}
/// Connection settings consumed by `setup_client`.
///
/// Exposed to JavaScript as a plain object via `#[napi(object)]`. All
/// `*_timeout`/`*_interval` values are whole seconds; when unset,
/// `setup_client` substitutes the defaults noted on each field.
#[derive(Debug)]
#[napi(object)]
pub struct ClientOptions {
    /// Additional headers attached to the client, one `KeyValue` per header.
    pub headers: Vec<KeyValue>,
    /// Username for a handshake. Must be set together with `password`;
    /// setting only one of the two is rejected.
    pub username: Option<String>,
    /// Password for a handshake. Must be set together with `username`.
    pub password: Option<String>,
    /// Auth token, installed on the client with `set_token` when present.
    pub token: Option<String>,
    /// Use TLS. Also selects the `https` scheme (plain `http` otherwise) and
    /// the default port.
    pub tls: bool,
    /// Server host.
    pub host: String,
    /// Server port; `setup_client` uses 443 with TLS and 80 without when unset.
    pub port: Option<u16>,
    /// Connection timeout in seconds (`setup_client` uses 20 when unset).
    pub connect_timeout: Option<u32>,
    /// Request timeout in seconds (`setup_client` uses 20 when unset).
    pub timeout: Option<u32>,
    /// HTTP/2 keep-alive interval in seconds (`setup_client` uses 300 when unset).
    pub keep_alive_interval: Option<u32>,
    /// Keep-alive timeout in seconds (`setup_client` uses 20 when unset).
    pub keep_alive_timeout: Option<u32>,
}
pub(crate) async fn execute_flight(
client: &mut FlightSqlServiceClient<Channel>,
info: FlightInfo,
) -> Result<Vec<RecordBatch>> {
let schema = Arc::new(Schema::try_from(info.clone()).context(ArrowSnafu {
message: "creating schema from flight info",
})?);
let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
batches.push(RecordBatch::new_empty(schema));
debug!("decoded schema");
for endpoint in info.endpoint {
let Some(ticket) = &endpoint.ticket else {
panic!("did not get ticket");
};
let flight_data = client.do_get(ticket.clone()).await.context(ArrowSnafu {
message: "do_get_request",
})?;
let mut flight_data: Vec<_> = flight_data
.try_collect()
.await
.context(FlightSnafu {
message: "collect data stream",
})
.expect("collect data stream");
batches.append(&mut flight_data);
}
debug!("received data");
Ok(batches)
}
/// Builds a [`RecordBatch`] of parameter values, one column per `(name, value)`.
///
/// Each string `value` is wrapped as a one-element string array and cast to the
/// data type of the field called `name` in `parameter_schema`; the resulting
/// column is named `name`. Columns keep the order of `params`.
///
/// # Errors
///
/// Returns the first [`ArrowError`] hit: a `name` missing from
/// `parameter_schema`, a failed cast, or `RecordBatch::try_from_iter`
/// rejecting the assembled columns.
// NOTE(review): `CastOptions::default()` is arrow-cast's "safe" mode, which per
// its docs yields null instead of an error for unparsable values — confirm
// that is intended for user-supplied parameters.
#[allow(unused)] // presumably kept for future prepared-statement parameter binding
fn construct_record_batch_from_params(
    params: &[(String, String)],
    parameter_schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
    let columns = params
        .iter()
        .map(|(name, value)| {
            let target = parameter_schema.field_with_name(name)?.data_type();
            let scalar = StringArray::new_scalar(value);
            let (single_value, _) = scalar.get();
            let column = cast_with_options(single_value, target, &CastOptions::default())?;
            Ok::<_, ArrowError>((name, column))
        })
        .collect::<Result<Vec<(&String, ArrayRef)>, ArrowError>>()?;
    RecordBatch::try_from_iter(columns)
}
/// Installs a global `tracing` fmt subscriber and routes `log` records into it.
///
/// # Panics
///
/// Panics (via `tracing_subscriber::fmt::init`) if a global `tracing`
/// subscriber has already been installed.
#[allow(unused)] // presumably only called while debugging locally
fn setup_logging() {
    // Install the subscriber first. With tracing-subscriber's `tracing-log`
    // feature (enabled by default) `fmt::init` installs its own `LogTracer`;
    // calling `LogTracer::init` beforehand made that step fail and
    // `fmt::init` panic.
    tracing_subscriber::fmt::init();
    // Only needed when that feature is disabled. Otherwise the `log` logger is
    // already set and the returned error is expected, so it is ignored.
    let _ = tracing_log::LogTracer::init();
}
/// Connects to a Flight SQL server and returns a configured client.
///
/// The target URL is `https://host:port` when `args.tls` is set (default port
/// 443) and `http://host:port` otherwise (default port 80). Unset timeouts fall
/// back to 20 s (connect, request, keep-alive timeout) and 300 s (HTTP/2
/// keep-alive interval). Every entry of `args.headers` is attached to the
/// client, `args.token` is installed with `set_token`, and when both
/// `args.username` and `args.password` are given a handshake is performed.
///
/// # Errors
///
/// Returns an [`ArrowError`] if exactly one of `username`/`password` is set
/// (checked before any connection is attempted), if the endpoint URL or TLS
/// configuration is invalid, if connecting fails, or if the handshake fails.
pub(crate) async fn setup_client(
    args: ClientOptions,
) -> Result<FlightSqlServiceClient<Channel>, ArrowError> {
    // Validate credentials up front so a misconfiguration is reported as an
    // error (rather than a panic) without opening a connection first.
    let credentials = match (args.username, args.password) {
        (None, None) => None,
        (Some(username), Some(password)) => Some((username, password)),
        (Some(_), None) => {
            return Err(ArrowError::InvalidArgumentError(
                "when username is set, you also need to set a password".to_string(),
            ))
        }
        (None, Some(_)) => {
            return Err(ArrowError::InvalidArgumentError(
                "when password is set, you also need to set a username".to_string(),
            ))
        }
    };

    let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
    let protocol = if args.tls { "https" } else { "http" };
    // Seconds-valued option with a fallback; widening u32 -> u64 is lossless.
    let secs = |value: Option<u32>, default: u32| {
        Duration::from_secs(u64::from(value.unwrap_or(default)))
    };

    let mut endpoint = Endpoint::new(format!("{protocol}://{}:{port}", args.host))
        .map_err(|err| ArrowError::ExternalError(Box::new(err)))?
        .connect_timeout(secs(args.connect_timeout, 20))
        .timeout(secs(args.timeout, 20))
        .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait
        .tcp_keepalive(Some(Duration::from_secs(3600)))
        .http2_keep_alive_interval(secs(args.keep_alive_interval, 300))
        .keep_alive_timeout(secs(args.keep_alive_timeout, 20))
        .keep_alive_while_idle(true);
    if args.tls {
        endpoint = endpoint
            .tls_config(ClientTlsConfig::new())
            .map_err(|err| ArrowError::ExternalError(Box::new(err)))?;
    }

    let channel = endpoint
        .connect()
        .await
        .map_err(|err| ArrowError::ExternalError(Box::new(err)))?;
    let mut client = FlightSqlServiceClient::new(channel);
    info!("connected");

    for kv in args.headers {
        client.set_header(kv.key, kv.value);
    }
    if let Some(token) = args.token {
        client.set_token(token);
        info!("token set");
    }
    if let Some((username, password)) = credentials {
        // A failed handshake is a runtime/network condition: propagate it.
        client.handshake(&username, &password).await?;
        info!("performed handshake");
    }
    Ok(client)
}