Skip to content

Commit f143ea0

Browse files
cursoragentlovasoa
andcommitted
feat: Add SSRP support for MSSQL named instances
Co-authored-by: contact <contact@ophir.dev>
1 parent 6ed3daa commit f143ea0

7 files changed

Lines changed: 182 additions & 8 deletions

File tree

sqlx-core/src/mssql/connection/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use std::sync::Arc;
1414
mod establish;
1515
mod executor;
1616
mod prepare;
17+
mod ssrp;
1718
mod stream;
1819
mod tls_prelogin_stream_wrapper;
1920

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
use crate::error::Error;
2+
use sqlx_rt::{timeout, UdpSocket};
3+
use std::collections::HashMap;
4+
use std::time::Duration;
5+
6+
const SSRP_PORT: u16 = 1434;
7+
const CLNT_UCAST_INST: u8 = 0x04;
8+
const SVR_RESP: u8 = 0x05;
9+
const SSRP_TIMEOUT: Duration = Duration::from_secs(1);
10+
11+
pub(crate) async fn resolve_instance_port(server: &str, instance: &str) -> Result<u16, Error> {
12+
let mut request = Vec::with_capacity(1 + instance.len() + 1);
13+
request.push(CLNT_UCAST_INST);
14+
request.extend_from_slice(instance.as_bytes());
15+
request.push(0);
16+
17+
let socket = UdpSocket::bind("0.0.0.0:0").await.map_err(|e| {
18+
err_protocol!("failed to bind UDP socket for SSRP: {}", e)
19+
})?;
20+
21+
socket
22+
.send_to(&request, (server, SSRP_PORT))
23+
.await
24+
.map_err(|e| {
25+
err_protocol!("failed to send SSRP request to {}:{}: {}", server, SSRP_PORT, e)
26+
})?;
27+
28+
let mut buffer = [0u8; 1024];
29+
let bytes_read = timeout(SSRP_TIMEOUT, socket.recv(&mut buffer))
30+
.await
31+
.map_err(|_| {
32+
err_protocol!(
33+
"SSRP request to {} for instance {} timed out after {:?}",
34+
server,
35+
instance,
36+
SSRP_TIMEOUT
37+
)
38+
})?
39+
.map_err(|e| {
40+
err_protocol!(
41+
"failed to receive SSRP response from {} for instance {}: {}",
42+
server,
43+
instance,
44+
e
45+
)
46+
})?;
47+
48+
if bytes_read < 3 {
49+
return Err(err_protocol!(
50+
"SSRP response too short: {} bytes",
51+
bytes_read
52+
));
53+
}
54+
55+
if buffer[0] != SVR_RESP {
56+
return Err(err_protocol!(
57+
"invalid SSRP response type: expected 0x05, got 0x{:02x}",
58+
buffer[0]
59+
));
60+
}
61+
62+
let response_size = u16::from_le_bytes([buffer[1], buffer[2]]) as usize;
63+
if response_size + 3 > bytes_read {
64+
return Err(err_protocol!(
65+
"SSRP response size mismatch: expected {} bytes, got {}",
66+
response_size + 3,
67+
bytes_read
68+
));
69+
}
70+
71+
let response_data = String::from_utf8(buffer[3..(3 + response_size)].to_vec())
72+
.map_err(|e| err_protocol!("SSRP response is not valid UTF-8: {}", e))?;
73+
74+
parse_ssrp_response(&response_data, instance)
75+
}
76+
77+
fn parse_ssrp_response(data: &str, instance_name: &str) -> Result<u16, Error> {
78+
let instances: Vec<&str> = data.split(";;").collect();
79+
80+
for instance_data in instances {
81+
if instance_data.is_empty() {
82+
continue;
83+
}
84+
85+
let tokens: Vec<&str> = instance_data.split(';').collect();
86+
let mut properties: HashMap<&str, &str> = HashMap::new();
87+
88+
let mut i = 0;
89+
while i + 1 < tokens.len() {
90+
let key = tokens[i];
91+
let value = tokens[i + 1];
92+
properties.insert(key, value);
93+
i += 2;
94+
}
95+
96+
if let Some(name) = properties.get("InstanceName") {
97+
if name.eq_ignore_ascii_case(instance_name) {
98+
if let Some(tcp_port_str) = properties.get("tcp") {
99+
return tcp_port_str.parse::<u16>().map_err(|e| {
100+
err_protocol!(
101+
"invalid TCP port '{}' in SSRP response: {}",
102+
tcp_port_str,
103+
e
104+
)
105+
});
106+
} else {
107+
return Err(err_protocol!(
108+
"instance '{}' found but no TCP port available",
109+
instance_name
110+
));
111+
}
112+
}
113+
}
114+
}
115+
116+
Err(err_protocol!(
117+
"instance '{}' not found in SSRP response",
118+
instance_name
119+
))
120+
}
121+
122+
#[cfg(test)]
123+
mod tests {
124+
use super::*;
125+
126+
#[test]
127+
fn test_parse_ssrp_response_single_instance() {
128+
let data = "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";
129+
let port = parse_ssrp_response(data, "SQLEXPRESS").unwrap();
130+
assert_eq!(port, 1433);
131+
}
132+
133+
#[test]
134+
fn test_parse_ssrp_response_multiple_instances() {
135+
let data = "ServerName;SRV1;InstanceName;INST1;IsClustered;No;Version;15.0.2000.5;tcp;1433;;ServerName;SRV1;InstanceName;INST2;IsClustered;No;Version;16.0.1000.6;tcp;1434;np;\\\\SRV1\\pipe\\MSSQL$INST2\\sql\\query;;";
136+
let port = parse_ssrp_response(data, "INST2").unwrap();
137+
assert_eq!(port, 1434);
138+
}
139+
140+
#[test]
141+
fn test_parse_ssrp_response_case_insensitive() {
142+
let data = "ServerName;MYSERVER;InstanceName;SQLExpress;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";
143+
let port = parse_ssrp_response(data, "sqlexpress").unwrap();
144+
assert_eq!(port, 1433);
145+
}
146+
147+
#[test]
148+
fn test_parse_ssrp_response_instance_not_found() {
149+
let data = "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";
150+
let result = parse_ssrp_response(data, "NOTFOUND");
151+
assert!(result.is_err());
152+
}
153+
154+
#[test]
155+
fn test_parse_ssrp_response_no_tcp_port() {
156+
let data = "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;;";
157+
let result = parse_ssrp_response(data, "SQLEXPRESS");
158+
assert!(result.is_err());
159+
}
160+
}

sqlx-core/src/mssql/connection/stream.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,13 @@ pub(crate) struct MssqlStream {
5151

5252
impl MssqlStream {
5353
pub(super) async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
54-
let tcp_stream = TcpStream::connect((&*options.host, options.port)).await?;
54+
let port = if let Some(ref instance) = options.instance {
55+
super::ssrp::resolve_instance_port(&options.host, instance).await?
56+
} else {
57+
options.port
58+
};
59+
60+
let tcp_stream = TcpStream::connect((&*options.host, port)).await?;
5561
let wrapped_stream = TlsPreloginWrapper::new(tcp_stream);
5662
let inner = BufStream::new(MaybeTlsStream::Raw(wrapped_stream));
5763

sqlx-core/src/mssql/options/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@ mod parse;
1010
///
1111
/// Connection strings should be in the form:
1212
/// ```text
13-
/// mssql://[username[:password]@]host/database[?instance=instance_name&packet_size=packet_size&client_program_version=client_program_version&client_pid=client_pid&hostname=hostname&app_name=app_name&server_name=server_name&client_interface_name=client_interface_name&language=language]
13+
/// mssql://[username[:password]@]host[:port]/database[?param1=value1&param2=value2...]
1414
/// ```
15+
///
16+
/// When connecting to a named instance, use the `instance` parameter:
17+
/// ```text
18+
/// mssql://user:pass@localhost/mydb?instance=SQLEXPRESS
19+
/// ```
20+
/// The port will be automatically discovered using the SQL Server Resolution Protocol (SSRP).
1521
#[derive(Debug, Clone)]
1622
pub struct MssqlConnectOptions {
1723
pub(crate) host: String,

sqlx-core/src/mssql/options/parse.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl FromStr for MssqlConnectOptions {
2323
/// - `database`: The name of the database to connect to.
2424
///
2525
/// Supported query parameters:
26-
/// - `instance`: SQL Server named instance.
26+
/// - `instance`: SQL Server named instance. When specified, the port is automatically discovered using the SQL Server Resolution Protocol (SSRP).
2727
/// - `encrypt`: Controls connection encryption:
2828
/// - `strict`: Requires encryption and validates the server certificate.
2929
/// - `mandatory` or `true` or `yes`: Requires encryption but doesn't validate the server certificate.
@@ -41,9 +41,10 @@ impl FromStr for MssqlConnectOptions {
4141
/// - `client_interface_name`: Name of the client interface, sent to the server for logging purposes.
4242
/// - `language`: Sets the language for server messages. Affects date formats and system messages.
4343
///
44-
/// Example:
44+
/// Examples:
4545
/// ```text
4646
/// mssql://user:pass@localhost:1433/mydb?encrypt=strict&app_name=MyApp&packet_size=4096
47+
/// mssql://user:pass@localhost/mydb?instance=SQLEXPRESS
4748
/// ```
4849
fn from_str(s: &str) -> Result<Self, Self::Err> {
4950
let url: Url = s.parse().map_err(Error::config)?;

sqlx-rt/src/rt_async_std.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
pub use async_std::{
22
self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt,
33
io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite,
4-
net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::spawn_blocking,
5-
task::yield_now,
4+
net::TcpStream, net::UdpSocket, sync::Mutex as AsyncMutex, task::sleep, task::spawn,
5+
task::spawn_blocking, task::yield_now,
66
};
77

88
#[cfg(unix)]

sqlx-rt/src/rt_tokio.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
pub use tokio::{
22
self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf,
3-
net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now,
4-
time::sleep, time::timeout,
3+
net::TcpStream, net::UdpSocket, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn,
4+
task::yield_now, time::sleep, time::timeout,
55
};
66

77
#[cfg(unix)]

0 commit comments

Comments
 (0)