1use std::cmp::Ordering;
76use std::fmt;
77use std::time::{Duration, SystemTime};
78
79use futures::TryFutureExt;
80use mz_ore::cast::CastFrom;
81use mz_ore::netio::{Listener, Stream};
82use mz_ore::retry::Retry;
83use regex::Regex;
84use tokio::io::{AsyncReadExt, AsyncWriteExt};
85use tracing::{info, warn};
86
/// Errors that can occur while establishing connections to all peers in
/// `create_sockets`.
#[derive(Debug)]
pub(crate) enum CreateSocketsError {
    /// Binding the listening socket at `address` failed, even after retrying.
    Bind {
        address: String,
        error: std::io::Error,
    },
    /// Peer `peer_index` reported an epoch greater than ours (`my_epoch`).
    EpochMismatch {
        peer_index: usize,
        peer_epoch: Epoch,
        my_epoch: Epoch,
    },
    /// Peer `peer_index` connected again after a connection from it had
    /// already been accepted.
    Reconnect {
        peer_index: usize,
    },
}
103
104impl CreateSocketsError {
105 pub fn is_fatal(&self) -> bool {
107 matches!(self, Self::Bind { .. })
108 }
109}
110
111impl fmt::Display for CreateSocketsError {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 match self {
114 Self::Bind { address, error } => write!(f, "failed to bind at {address}: {error}"),
115 Self::EpochMismatch {
116 peer_index,
117 peer_epoch,
118 my_epoch,
119 } => write!(
120 f,
121 "peer {peer_index} has greater epoch: {peer_epoch} > {my_epoch}"
122 ),
123 Self::Reconnect { peer_index } => {
124 write!(f, "observed second instance of peer {peer_index}")
125 }
126 }
127 }
128}
129
// Marker impl using the default methods, so `source()` returns `None`; the
// `io::Error` wrapped by `Bind` is surfaced only through the `Display` output.
impl std::error::Error for CreateSocketsError {}
131
/// An identifier for one attempt at creating sockets.
///
/// Process 0 mints an epoch for each attempt, and every other process adopts
/// the epoch it receives from process 0. Epochs are totally ordered: the
/// derived `Ord` compares `time` first and breaks ties with `nonce`, so a
/// later-minted epoch normally compares greater.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Epoch {
    /// Wall-clock milliseconds since the Unix epoch at minting time.
    time: u64,
    /// Random value distinguishing epochs minted in the same millisecond.
    nonce: u64,
}
144
145impl Epoch {
146 fn mint() -> Self {
147 let time = SystemTime::UNIX_EPOCH
148 .elapsed()
149 .expect("current time is after 1970")
150 .as_millis()
151 .try_into()
152 .expect("fits");
153 let nonce = rand::random();
154 Self { time, nonce }
155 }
156
157 async fn read(s: &mut Stream) -> std::io::Result<Self> {
158 let time = s.read_u64().await?;
159 let nonce = s.read_u64().await?;
160 Ok(Self { time, nonce })
161 }
162
163 async fn write(&self, s: &mut Stream) -> std::io::Result<()> {
164 s.write_u64(self.time).await?;
165 s.write_u64(self.nonce).await?;
166 Ok(())
167 }
168}
169
170impl fmt::Display for Epoch {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 write!(f, "({}, {})", self.time, self.nonce)
173 }
174}
175
176pub(crate) async fn create_sockets(
181 my_index: usize,
182 addresses: &[String],
183) -> Result<Vec<Option<Stream>>, CreateSocketsError> {
184 let my_address = &addresses[my_index];
185
186 let port_re = Regex::new(r"(?<proto>\w+:)?(?<host>.*):(?<port>\d{1,5})$").unwrap();
189 let listen_address = match port_re.captures(my_address) {
190 Some(cap) => match cap.name("proto") {
191 Some(proto) => format!("{}0.0.0.0:{}", proto.as_str(), &cap["port"]),
192 None => format!("0.0.0.0:{}", &cap["port"]),
193 },
194 None => my_address.to_string(),
195 };
196
197 let listener = Retry::default()
198 .initial_backoff(Duration::from_secs(1))
199 .clamp_backoff(Duration::from_secs(1))
200 .max_tries(10)
201 .retry_async(|_| {
202 Listener::bind(&listen_address)
203 .inspect_err(|error| warn!(%listen_address, "failed to listen: {error}"))
204 })
205 .await
206 .map_err(|error| CreateSocketsError::Bind {
207 address: listen_address,
208 error,
209 })?;
210
211 let (my_epoch, sockets_lower) = match my_index {
212 0 => {
213 let epoch = Epoch::mint();
214 info!(my_index, "minted epoch: {epoch}");
215 (epoch, Vec::new())
216 }
217 _ => connect_lower(my_index, addresses).await?,
218 };
219
220 let n_peers = addresses.len();
221 let sockets_higher = accept_higher(my_index, my_epoch, n_peers, &listener).await?;
222
223 let connections_lower = sockets_lower.into_iter().map(Some);
224 let connections_higher = sockets_higher.into_iter().map(Some);
225 let connections = connections_lower
226 .chain([None])
227 .chain(connections_higher)
228 .collect();
229
230 Ok(connections)
231}
232
233async fn connect_lower(
238 my_index: usize,
239 addresses: &[String],
240) -> Result<(Epoch, Vec<Stream>), CreateSocketsError> {
241 assert!(my_index > 0);
242 assert!(my_index <= addresses.len());
243
244 async fn handshake(
245 my_index: usize,
246 my_epoch: Option<Epoch>,
247 address: &str,
248 ) -> anyhow::Result<(Epoch, Stream)> {
249 let mut s = Stream::connect(address).await?;
250 if let Stream::Tcp(tcp) = &s {
251 tcp.set_nodelay(true)?;
252 }
253
254 s.write_u64(u64::cast_from(my_index)).await?;
255 let peer_epoch = Epoch::read(&mut s).await?;
256 let my_epoch = my_epoch.unwrap_or(peer_epoch);
257 my_epoch.write(&mut s).await?;
258
259 Ok((peer_epoch, s))
260 }
261
262 let mut my_epoch = None;
263 let mut sockets = Vec::new();
264
265 while sockets.len() < my_index {
266 let index = sockets.len();
267 let address = &addresses[index];
268
269 info!(my_index, "connecting to peer {index} at address: {address}");
270
271 let (peer_epoch, sock) = Retry::default()
272 .initial_backoff(Duration::from_secs(1))
273 .clamp_backoff(Duration::from_secs(1))
274 .retry_async(|_| {
275 handshake(my_index, my_epoch, address).inspect_err(|error| {
276 info!(my_index, "error connecting to peer {index}: {error}")
277 })
278 })
279 .await
280 .expect("retries forever");
281
282 if let Some(epoch) = my_epoch {
283 match peer_epoch.cmp(&epoch) {
284 Ordering::Less => {
285 info!(
286 my_index,
287 "refusing connection to peer {index} with smaller epoch: \
288 {peer_epoch} < {epoch}",
289 );
290 continue;
291 }
292 Ordering::Greater => {
293 return Err(CreateSocketsError::EpochMismatch {
294 peer_index: index,
295 peer_epoch,
296 my_epoch: epoch,
297 });
298 }
299 Ordering::Equal => info!(my_index, "connected to peer {index}"),
300 }
301 } else {
302 info!(my_index, "received epoch from peer {index}: {peer_epoch}");
303 my_epoch = Some(peer_epoch);
304 }
305
306 sockets.push(sock);
307 }
308
309 let my_epoch = my_epoch.expect("must exist");
310 Ok((my_epoch, sockets))
311}
312
313async fn accept_higher(
318 my_index: usize,
319 my_epoch: Epoch,
320 n_peers: usize,
321 listener: &Listener,
322) -> Result<Vec<Stream>, CreateSocketsError> {
323 assert!(my_index < n_peers);
324
325 async fn accept(listener: &Listener) -> anyhow::Result<(usize, Stream)> {
326 let (mut s, _) = listener.accept().await?;
327 if let Stream::Tcp(tcp) = &s {
328 tcp.set_nodelay(true)?;
329 }
330
331 let peer_index = s.read_u64().await?;
332 let peer_index = usize::cast_from(peer_index);
333 Ok((peer_index, s))
334 }
335
336 async fn exchange_epochs(my_epoch: Epoch, s: &mut Stream) -> anyhow::Result<Epoch> {
337 my_epoch.write(s).await?;
338 let peer_epoch = Epoch::read(s).await?;
339 Ok(peer_epoch)
340 }
341
342 let offset = my_index + 1;
343 let mut sockets: Vec<_> = (offset..n_peers).map(|_| None).collect();
344
345 while sockets.iter().any(|s| s.is_none()) {
346 info!(my_index, "accepting connection from peer");
347
348 let (index, mut sock) = match accept(listener).await {
349 Ok(result) => result,
350 Err(error) => {
351 info!(my_index, "error accepting connection: {error}");
352 continue;
353 }
354 };
355
356 if sockets[index - offset].is_some() {
357 return Err(CreateSocketsError::Reconnect { peer_index: index });
358 }
359
360 let peer_epoch = match exchange_epochs(my_epoch, &mut sock).await {
361 Ok(result) => result,
362 Err(error) => {
363 info!(my_index, "error exchanging epochs: {error}");
364 continue;
365 }
366 };
367
368 match peer_epoch.cmp(&my_epoch) {
369 Ordering::Less => {
370 info!(
371 my_index,
372 "refusing connection from peer {index} with smaller epoch: \
373 {peer_epoch} < {my_epoch}",
374 );
375 continue;
376 }
377 Ordering::Greater => {
378 return Err(CreateSocketsError::EpochMismatch {
379 peer_index: index,
380 peer_epoch,
381 my_epoch,
382 });
383 }
384 Ordering::Equal => info!(my_index, "connected to peer {index}"),
385 }
386
387 sockets[index - offset] = Some(sock);
388 }
389
390 Ok(sockets.into_iter().map(|s| s.unwrap()).collect())
391}
392
/// Simulation tests for `create_sockets`, run under the `turmoil` network
/// simulator.
#[cfg(test)]
mod turmoil_tests {
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};
    use tokio::sync::{mpsc, watch};
    use tokio::time::timeout;

    use super::*;

    /// Runs `create_sockets` on `NUM_PROCESSES` simulated hosts and bounces
    /// random hosts `NUM_CRASHES` times. It then checks that every host
    /// eventually ends up with working connections to all of its peers.
    ///
    /// The RNG is seeded from the `SEED` environment variable when it is set
    /// and parses, and randomly otherwise. The seed is logged so a failing
    /// run can be reproduced.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_create_sockets() {
        const NUM_PROCESSES: usize = 10;
        const NUM_CRASHES: usize = 3;

        configure_tracing_for_turmoil();

        let seed = std::env::var("SEED")
            .ok()
            .and_then(|x| x.parse().ok())
            .unwrap_or_else(rand::random);

        info!("initializing rng with seed {seed}");
        let mut rng = SmallRng::seed_from_u64(seed);

        // The simulation gets its own copy of the seeded RNG.
        // NOTE(review): `enable_random_order` presumably randomizes host
        // scheduling order using that RNG; confirm against turmoil docs.
        let mut sim = turmoil::Builder::new()
            .enable_random_order()
            .build_with_rng(Box::new(rng.clone()));

        let processes: Vec<_> = (0..NUM_PROCESSES).map(|i| format!("process-{i}")).collect();
        let addresses: Vec<_> = processes
            .iter()
            .map(|n| format!("turmoil:{n}:7777"))
            .collect();

        // Each host sends its index here once it has verified all of its
        // connections.
        let (ready_tx, mut ready_rx) = mpsc::unbounded_channel();

        // Set to `true` once the crash phase below is over. Hosts wait on it
        // before checking their connections.
        let (stable_tx, stable_rx) = watch::channel(false);

        for (index, name) in processes.iter().enumerate() {
            let addresses = addresses.clone();
            let ready_tx = ready_tx.clone();
            let stable_rx = stable_rx.clone();
            // The host closure clones its captures on every call, so it can
            // run more than once (presumably once per bounce of this host).
            sim.host(&name[..], move || {
                let addresses = addresses.clone();
                let ready_tx = ready_tx.clone();
                let mut stable_rx = stable_rx.clone();
                async move {
                    'protocol: loop {
                        // Fatal errors end the host with an error; all other
                        // failures restart the protocol.
                        let mut sockets = match create_sockets(index, &addresses).await {
                            Ok(sockets) => sockets,
                            Err(error) if error.is_fatal() => Err(error)?,
                            Err(error) => {
                                info!("creating sockets failed: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        };

                        // Don't check connections until the crash phase is over.
                        let _ = stable_rx.wait_for(|stable| *stable).await;

                        // Send a ping (111) to every peer, then expect a ping
                        // from every peer within 1s. Any failure restarts the
                        // protocol from scratch.
                        info!("sockets created; checking connections");
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            if let Err(error) = sock.write_u8(111).await {
                                info!("error pinging socket: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        }
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            info!("waiting for ping from {sock:?}");
                            match timeout(Duration::from_secs(1), sock.read_u8()).await {
                                Ok(Ok(ping)) => assert_eq!(ping, 111),
                                Ok(Err(error)) => {
                                    info!("error waiting for ping: {error}; retrying protocol");
                                    continue 'protocol;
                                }
                                Err(_) => {
                                    info!("timed out waiting for ping; retrying protocol");
                                    continue 'protocol;
                                }
                            }
                        }

                        let _ = ready_tx.send(index);

                        // Leak the sockets so they are never closed.
                        // NOTE(review): presumably this keeps peers that are
                        // still checking from seeing a closed connection.
                        std::mem::forget(sockets);
                        return Ok(());
                    }
                }
            });
        }

        // Crash phase: run a random number of simulation steps, then bounce a
        // random host; repeat NUM_CRASHES times.
        for _ in 0..NUM_CRASHES {
            let steps = rng.gen_range(1..100);
            for _ in 0..steps {
                sim.step().unwrap();
            }

            let i = rng.gen_range(0..NUM_PROCESSES);
            info!("bouncing process {i}");
            sim.bounce(format!("process-{i}"));
        }

        stable_tx.send(true).unwrap();

        // Step until every host has reported ready. Fail if `sim.elapsed()`
        // exceeds 120s first.
        let mut num_ready = 0;
        loop {
            while let Ok(index) = ready_rx.try_recv() {
                info!("process {index} is ready");
                num_ready += 1;
            }
            if num_ready == NUM_PROCESSES {
                break;
            }

            sim.step().unwrap();
            if sim.elapsed() > Duration::from_secs(120) {
                panic!("simulation not finished after 120s");
            }
        }
    }

    /// Runs `test_create_sockets` in an endless loop to search for failing
    /// seeds. Each run draws a new random seed unless `SEED` is set.
    #[test]
    #[ignore = "runs forever"]
    fn fuzz_create_sockets() {
        loop {
            test_create_sockets();
        }
    }

    /// Installs a global tracing subscriber, at most once per process. The
    /// subscriber:
    /// - writes to the test output;
    /// - filters via `EnvFilter`, defaulting to `INFO`;
    /// - appends the simulation's elapsed time to each timestamp when running
    ///   inside a turmoil simulation.
    fn configure_tracing_for_turmoil() {
        use std::sync::Once;
        use tracing::level_filters::LevelFilter;
        use tracing_subscriber::fmt::time::FormatTime;

        /// Timer that prints the normal timestamp followed by
        /// `turmoil::sim_elapsed()`, when available.
        #[derive(Clone)]
        struct SimElapsedTime;

        impl FormatTime for SimElapsedTime {
            fn format_time(
                &self,
                w: &mut tracing_subscriber::fmt::format::Writer<'_>,
            ) -> std::fmt::Result {
                tracing_subscriber::fmt::time().format_time(w)?;
                if let Some(sim_elapsed) = turmoil::sim_elapsed() {
                    write!(w, " [{:?}]", sim_elapsed)?;
                }
                Ok(())
            }
        }

        // `set_global_default` may only succeed once per process, and several
        // tests call this function.
        static INIT_TRACING: Once = Once::new();
        INIT_TRACING.call_once(|| {
            let env_filter = tracing_subscriber::EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env_lossy();
            let subscriber = tracing_subscriber::fmt()
                .with_test_writer()
                .with_env_filter(env_filter)
                .with_timer(SimElapsedTime)
                .finish();

            tracing::subscriber::set_global_default(subscriber).unwrap();
        });
    }
}