Skip to content

Commit 0457ad4

Browse files
committed
Add support for multiple upstream servers
1 parent be9c75b commit 0457ad4

5 files changed

Lines changed: 183 additions & 64 deletions

File tree

example-encrypted-dns.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,14 @@ listen_addrs = [
2727
]
2828

2929

30-
## Upstream DNS server and port
30+
## Upstream DNS server(s) and port(s)
31+
## The server tries each address in order and fails over on timeout/error.
3132

32-
upstream_addr = "9.9.9.9:53"
33+
# Single upstream:
34+
# upstream_addrs = ["9.9.9.9:53"]
35+
36+
# Multiple upstreams with failover:
37+
upstream_addrs = ["9.9.9.9:53", "149.112.112.112:53"]
3338

3439

3540
## File name to save the state to

src/config.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,34 @@ pub struct FilteringConfig {
5959
pub ignore_unqualified_hostnames: Option<bool>,
6060
}
6161

62+
fn deserialize_upstream_addrs<'de, D>(deserializer: D) -> Result<Vec<SocketAddr>, D::Error>
63+
where
64+
D: serde::Deserializer<'de>,
65+
{
66+
use serde::Deserialize;
67+
68+
#[derive(Deserialize)]
69+
#[serde(untagged)]
70+
enum SingleOrVec {
71+
Single(SocketAddr),
72+
Vec(Vec<SocketAddr>),
73+
}
74+
75+
match SingleOrVec::deserialize(deserializer)? {
76+
SingleOrVec::Single(addr) => Ok(vec![addr]),
77+
SingleOrVec::Vec(addrs) => Ok(addrs),
78+
}
79+
}
80+
6281
#[derive(Serialize, Deserialize, Debug, Clone)]
6382
pub struct Config {
6483
pub listen_addrs: Vec<ListenAddrConfig>,
6584
pub external_addr: Option<IpAddr>,
66-
pub upstream_addr: SocketAddr,
85+
#[serde(
86+
alias = "upstream_addr",
87+
deserialize_with = "deserialize_upstream_addrs"
88+
)]
89+
pub upstream_addrs: Vec<SocketAddr>,
6790
pub state_file: PathBuf,
6891
pub udp_timeout: u32,
6992
pub tcp_timeout: u32,
@@ -91,17 +114,20 @@ pub struct Config {
91114
}
92115

93116
impl Config {
94-
pub fn from_string(toml: &str) -> Result<Config, Error> {
95-
let config: Config = match toml::from_str(toml) {
117+
pub fn from_string(toml_str: &str) -> Result<Config, Error> {
118+
let config: Config = match toml::from_str(toml_str) {
96119
Ok(config) => config,
97120
Err(e) => bail!("Parse error in the configuration file: {}", e),
98121
};
122+
if config.upstream_addrs.is_empty() {
123+
bail!("At least one upstream address must be specified");
124+
}
99125
Ok(config)
100126
}
101127

102128
pub fn from_path(path: impl AsRef<Path>) -> Result<Config, Error> {
103-
let toml = fs::read_to_string(path)?;
104-
Config::from_string(&toml)
129+
let toml_str = fs::read_to_string(path)?;
130+
Config::from_string(&toml_str)
105131
}
106132
}
107133

src/globals.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub struct Globals {
2727
pub provider_kp: SignKeyPair,
2828
pub listen_addrs: Vec<SocketAddr>,
2929
pub external_addr: Option<SocketAddr>,
30-
pub upstream_addr: SocketAddr,
30+
pub upstream_addrs: Vec<SocketAddr>,
3131
pub tls_upstream_addr: Option<SocketAddr>,
3232
pub udp_timeout: Duration,
3333
pub tcp_timeout: Duration,

src/main.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async fn encrypt_and_respond_to_query(
153153
(Some(_), None) => {
154154
warn!("Shared key provided without nonce");
155155
bail!("Internal error: shared key without nonce");
156-
},
156+
}
157157
(Some(shared_key), Some(nonce)) => dnscrypt::encrypt(
158158
maybe_truncate_response(&client_ctx, packet, response, original_packet_size)?,
159159
shared_key,
@@ -309,11 +309,11 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
309309
// Rate limit repeated errors to avoid spinning
310310
let is_resource_error = matches!(
311311
e.kind(),
312-
std::io::ErrorKind::ConnectionRefused |
313-
std::io::ErrorKind::ConnectionReset |
314-
std::io::ErrorKind::ConnectionAborted |
315-
std::io::ErrorKind::AddrInUse |
316-
std::io::ErrorKind::AddrNotAvailable
312+
std::io::ErrorKind::ConnectionRefused
313+
| std::io::ErrorKind::ConnectionReset
314+
| std::io::ErrorKind::ConnectionAborted
315+
| std::io::ErrorKind::AddrInUse
316+
| std::io::ErrorKind::AddrNotAvailable
317317
);
318318

319319
if is_resource_error {
@@ -845,7 +845,7 @@ fn main() -> Result<(), Error> {
845845
provider_name,
846846
provider_kp,
847847
listen_addrs,
848-
upstream_addr: config.upstream_addr,
848+
upstream_addrs: config.upstream_addrs,
849849
tls_upstream_addr: config.tls.upstream_addr,
850850
external_addr,
851851
tcp_timeout: Duration::from_secs(u64::from(config.tcp_timeout)),

src/resolver.rs

Lines changed: 137 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use std::cmp;
22
use std::hash::Hasher;
33
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
4+
use std::time::Duration;
45

56
use byteorder::{BigEndian, ByteOrder};
6-
use rand::{random, Rng, rng};
7+
use rand::{random, rng, Rng};
78
use siphasher::sip128::Hasher128;
89
use tokio::io::{AsyncReadExt, AsyncWriteExt};
910
use tokio::net::{TcpSocket, UdpSocket};
@@ -14,16 +15,17 @@ use crate::errors::*;
1415
use crate::globals::*;
1516
use crate::ClientCtx;
1617

17-
pub async fn resolve_udp(
18-
globals: &Globals,
19-
packet: &mut Vec<u8>,
18+
async fn resolve_udp_single(
19+
upstream_addr: SocketAddr,
20+
external_addr: Option<SocketAddr>,
21+
packet: &[u8],
2022
packet_qname: &[u8],
2123
tid: u16,
22-
has_cached_response: bool,
24+
timeout: Duration,
2325
) -> Result<Vec<u8>, Error> {
24-
let ext_socket = match globals.external_addr {
26+
let ext_socket = match external_addr {
2527
Some(x) => UdpSocket::bind(x).await?,
26-
None => match globals.upstream_addr {
28+
None => match upstream_addr {
2729
SocketAddr::V4(_) => {
2830
UdpSocket::bind(&SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
2931
.await?
@@ -39,49 +41,80 @@ pub async fn resolve_udp(
3941
}
4042
},
4143
};
42-
ext_socket.connect(globals.upstream_addr).await?;
44+
ext_socket.connect(upstream_addr).await?;
45+
ext_socket.send(packet).await?;
46+
let mut response = vec![0u8; DNS_MAX_PACKET_SIZE];
47+
let fut = tokio::time::timeout(timeout, ext_socket.recv_from(&mut response[..]));
48+
match fut.await {
49+
Ok(Ok((response_len, response_addr))) => {
50+
response.truncate(response_len);
51+
if response_addr == upstream_addr
52+
&& response_len >= DNS_HEADER_SIZE
53+
&& dns::tid(&response) == tid
54+
&& packet_qname.eq_ignore_ascii_case(dns::qname(&response)?.as_slice())
55+
{
56+
return Ok(response);
57+
}
58+
bail!("Invalid response from upstream");
59+
}
60+
Ok(Err(e)) => bail!("UDP receive error: {}", e),
61+
Err(_) => bail!("UDP timeout"),
62+
}
63+
}
64+
65+
pub async fn resolve_udp(
66+
globals: &Globals,
67+
packet: &mut Vec<u8>,
68+
packet_qname: &[u8],
69+
tid: u16,
70+
has_cached_response: bool,
71+
) -> Result<Vec<u8>, Error> {
4372
dns::set_edns_max_payload_size(packet, DNS_MAX_PACKET_SIZE as u16)?;
44-
let mut response;
4573
let timeout = if has_cached_response {
4674
globals.udp_timeout / 2
4775
} else {
4876
globals.udp_timeout
4977
};
50-
loop {
51-
ext_socket.send(packet).await?;
52-
response = vec![0u8; DNS_MAX_PACKET_SIZE];
53-
dns::set_rcode_servfail(&mut response);
54-
let fut = tokio::time::timeout(timeout, ext_socket.recv_from(&mut response[..]));
55-
match fut.await {
56-
Ok(Ok((response_len, response_addr))) => {
57-
response.truncate(response_len);
58-
if response_addr == globals.upstream_addr
59-
&& response_len >= DNS_HEADER_SIZE
60-
&& dns::tid(&response) == tid
61-
&& packet_qname.eq_ignore_ascii_case(dns::qname(&response)?.as_slice())
62-
{
63-
break;
64-
}
65-
}
66-
_ => {
67-
if has_cached_response {
68-
trace!("Timeout, but cached response is present");
69-
break;
70-
}
71-
trace!("Timeout, no cached response");
78+
79+
let mut last_error = None;
80+
for upstream_addr in &globals.upstream_addrs {
81+
match resolve_udp_single(
82+
*upstream_addr,
83+
globals.external_addr,
84+
packet,
85+
packet_qname,
86+
tid,
87+
timeout,
88+
)
89+
.await
90+
{
91+
Ok(response) => return Ok(response),
92+
Err(e) => {
93+
trace!("Upstream {} failed: {}", upstream_addr, e);
94+
last_error = Some(e);
7295
}
7396
}
7497
}
75-
Ok(response)
98+
99+
if has_cached_response {
100+
trace!("All upstreams failed, but cached response is present");
101+
let mut response = vec![0u8; DNS_MAX_PACKET_SIZE];
102+
dns::set_rcode_servfail(&mut response);
103+
return Ok(response);
104+
}
105+
106+
Err(last_error.unwrap_or_else(|| anyhow!("No upstream servers configured")))
76107
}
77108

78-
pub async fn resolve_tcp(
79-
globals: &Globals,
80-
packet: &mut [u8],
109+
async fn resolve_tcp_single(
110+
upstream_addr: SocketAddr,
111+
external_addr: Option<SocketAddr>,
112+
packet: &[u8],
81113
packet_qname: &[u8],
82114
tid: u16,
115+
timeout: Duration,
83116
) -> Result<Vec<u8>, Error> {
84-
let socket = match globals.external_addr {
117+
let socket = match external_addr {
85118
Some(x @ SocketAddr::V4(_)) => {
86119
let socket = TcpSocket::new_v4()?;
87120
socket.set_reuseaddr(true).ok();
@@ -94,26 +127,53 @@ pub async fn resolve_tcp(
94127
socket.bind(x)?;
95128
socket
96129
}
97-
None => match globals.upstream_addr {
130+
None => match upstream_addr {
98131
SocketAddr::V4(_) => TcpSocket::new_v4()?,
99132
SocketAddr::V6(_) => TcpSocket::new_v6()?,
100133
},
101134
};
102-
let mut ext_socket = socket.connect(globals.upstream_addr).await?;
135+
136+
let connect_fut = tokio::time::timeout(timeout, socket.connect(upstream_addr));
137+
let mut ext_socket = match connect_fut.await {
138+
Ok(Ok(s)) => s,
139+
Ok(Err(e)) => bail!("TCP connect error: {}", e),
140+
Err(_) => bail!("TCP connect timeout"),
141+
};
142+
103143
ext_socket.set_nodelay(true)?;
104144
let mut binlen = [0u8, 0];
105145
BigEndian::write_u16(&mut binlen[..], packet.len() as u16);
106-
ext_socket.write_all(&binlen).await?;
107-
ext_socket.write_all(packet).await?;
108-
ext_socket.flush().await?;
109-
ext_socket.read_exact(&mut binlen).await?;
110-
let response_len = BigEndian::read_u16(&binlen) as usize;
111-
ensure!(
112-
(DNS_HEADER_SIZE..=DNS_MAX_PACKET_SIZE).contains(&response_len),
113-
"Unexpected response size"
114-
);
115-
let mut response = vec![0u8; response_len];
116-
ext_socket.read_exact(&mut response).await?;
146+
147+
let write_fut = async {
148+
ext_socket.write_all(&binlen).await?;
149+
ext_socket.write_all(packet).await?;
150+
ext_socket.flush().await?;
151+
Ok::<_, std::io::Error>(())
152+
};
153+
tokio::time::timeout(timeout, write_fut)
154+
.await
155+
.map_err(|_| anyhow!("TCP write timeout"))?
156+
.map_err(|e| anyhow!("TCP write error: {}", e))?;
157+
158+
let read_fut = async {
159+
ext_socket.read_exact(&mut binlen).await?;
160+
let response_len = BigEndian::read_u16(&binlen) as usize;
161+
if !(DNS_HEADER_SIZE..=DNS_MAX_PACKET_SIZE).contains(&response_len) {
162+
return Err(std::io::Error::new(
163+
std::io::ErrorKind::InvalidData,
164+
"Unexpected response size",
165+
));
166+
}
167+
let mut response = vec![0u8; response_len];
168+
ext_socket.read_exact(&mut response).await?;
169+
Ok::<_, std::io::Error>(response)
170+
};
171+
172+
let response = tokio::time::timeout(timeout, read_fut)
173+
.await
174+
.map_err(|_| anyhow!("TCP read timeout"))?
175+
.map_err(|e| anyhow!("TCP read error: {}", e))?;
176+
117177
ensure!(dns::tid(&response) == tid, "Unexpected transaction ID");
118178
ensure!(
119179
packet_qname.eq_ignore_ascii_case(dns::qname(&response)?.as_slice()),
@@ -122,6 +182,34 @@ pub async fn resolve_tcp(
122182
Ok(response)
123183
}
124184

185+
pub async fn resolve_tcp(
186+
globals: &Globals,
187+
packet: &mut [u8],
188+
packet_qname: &[u8],
189+
tid: u16,
190+
) -> Result<Vec<u8>, Error> {
191+
let mut last_error = None;
192+
for upstream_addr in &globals.upstream_addrs {
193+
match resolve_tcp_single(
194+
*upstream_addr,
195+
globals.external_addr,
196+
packet,
197+
packet_qname,
198+
tid,
199+
globals.tcp_timeout,
200+
)
201+
.await
202+
{
203+
Ok(response) => return Ok(response),
204+
Err(e) => {
205+
trace!("Upstream {} TCP failed: {}", upstream_addr, e);
206+
last_error = Some(e);
207+
}
208+
}
209+
}
210+
Err(last_error.unwrap_or_else(|| anyhow!("No upstream servers configured")))
211+
}
212+
125213
pub async fn resolve(
126214
globals: &Globals,
127215
packet: &mut Vec<u8>,

0 commit comments

Comments
 (0)