1use std::any::Any;
76use std::cmp::Ordering;
77use std::fmt;
78use std::sync::Arc;
79use std::time::{Duration, SystemTime};
80
81use anyhow::{Context, bail};
82use futures::TryFutureExt;
83use mz_ore::cast::CastFrom;
84use mz_ore::netio::{Listener, Stream};
85use mz_ore::retry::Retry;
86use regex::Regex;
87use timely::communication::allocator::zero_copy::allocator::TcpBuilder;
88use timely::communication::allocator::zero_copy::bytes_slab::BytesRefill;
89use timely::communication::allocator::zero_copy::initialize::initialize_networking_from_sockets;
90use timely::communication::allocator::{GenericBuilder, PeerBuilder};
91use tokio::io::{AsyncReadExt, AsyncWriteExt};
92use tracing::{info, warn};
93
94pub async fn initialize_networking<P>(
96 workers: usize,
97 process: usize,
98 addresses: Vec<String>,
99 refill: BytesRefill,
100 builder_fn: impl Fn(TcpBuilder<P::Peer>) -> GenericBuilder,
101) -> Result<(Vec<GenericBuilder>, Box<dyn Any + Send>), anyhow::Error>
102where
103 P: PeerBuilder,
104{
105 info!(
106 process,
107 ?addresses,
108 "initializing network for timely instance",
109 );
110 let sockets = loop {
111 match create_sockets(process, &addresses).await {
112 Ok(sockets) => break sockets,
113 Err(error) if error.is_fatal() => bail!("failed to set up Timely sockets: {error}"),
114 Err(error) => info!("creating sockets failed: {error}; retrying"),
115 }
116 };
117
118 if sockets
119 .iter()
120 .filter_map(|s| s.as_ref())
121 .all(|s| s.is_tcp())
122 {
123 let sockets = sockets
124 .into_iter()
125 .map(|s| s.map(|s| s.unwrap_tcp().into_std()).transpose())
126 .collect::<Result<Vec<_>, _>>()
127 .map_err(anyhow::Error::from)
128 .context("failed to get standard sockets from tokio sockets")?;
129 initialize_networking_inner::<_, P, _>(sockets, process, workers, refill, builder_fn)
130 } else if sockets
131 .iter()
132 .filter_map(|s| s.as_ref())
133 .all(|s| s.is_unix())
134 {
135 let sockets = sockets
136 .into_iter()
137 .map(|s| s.map(|s| s.unwrap_unix().into_std()).transpose())
138 .collect::<Result<Vec<_>, _>>()
139 .map_err(anyhow::Error::from)
140 .context("failed to get standard sockets from tokio sockets")?;
141 initialize_networking_inner::<_, P, _>(sockets, process, workers, refill, builder_fn)
142 } else {
143 anyhow::bail!("cannot mix TCP and Unix streams");
144 }
145}
146
147fn initialize_networking_inner<S, P, PF>(
148 sockets: Vec<Option<S>>,
149 process: usize,
150 workers: usize,
151 refill: BytesRefill,
152 builder_fn: PF,
153) -> Result<(Vec<GenericBuilder>, Box<dyn Any + Send>), anyhow::Error>
154where
155 S: timely::communication::allocator::zero_copy::stream::Stream + 'static,
156 P: PeerBuilder,
157 PF: Fn(TcpBuilder<P::Peer>) -> GenericBuilder,
158{
159 for s in &sockets {
160 if let Some(s) = s {
161 s.set_nonblocking(false)
162 .context("failed to set socket to non-blocking")?;
163 }
164 }
165
166 match initialize_networking_from_sockets::<_, P>(
167 sockets,
168 process,
169 workers,
170 refill,
171 Arc::new(|_| None),
172 ) {
173 Ok((stuff, guard)) => {
174 info!(process = process, "successfully initialized network");
175 Ok((stuff.into_iter().map(builder_fn).collect(), Box::new(guard)))
176 }
177 Err(err) => {
178 warn!(process, "failed to initialize network: {err}");
179 Err(anyhow::Error::from(err).context("failed to initialize networking from sockets"))
180 }
181 }
182}
183
/// Errors returned by `create_sockets`.
///
/// Only `Bind` is treated as fatal (see `is_fatal`). The other variants are
/// handled by callers by re-running the socket-creation protocol.
#[derive(Debug)]
pub(crate) enum CreateSocketsError {
    /// Binding the listener at `address` failed, even after retrying.
    Bind {
        address: String,
        error: std::io::Error,
    },
    /// Peer `peer_index` presented an epoch greater than ours, i.e. one that
    /// compares later under `Epoch`'s ordering (a newer mint time, or the same
    /// time with a larger nonce).
    EpochMismatch {
        peer_index: usize,
        peer_epoch: Epoch,
        my_epoch: Epoch,
    },
    /// A second connection arrived from `peer_index` after a connection from
    /// that peer had already been accepted.
    Reconnect {
        peer_index: usize,
    },
}
200
201impl CreateSocketsError {
202 pub fn is_fatal(&self) -> bool {
204 matches!(self, Self::Bind { .. })
205 }
206}
207
208impl fmt::Display for CreateSocketsError {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 match self {
211 Self::Bind { address, error } => write!(f, "failed to bind at {address}: {error}"),
212 Self::EpochMismatch {
213 peer_index,
214 peer_epoch,
215 my_epoch,
216 } => write!(
217 f,
218 "peer {peer_index} has greater epoch: {peer_epoch} > {my_epoch}"
219 ),
220 Self::Reconnect { peer_index } => {
221 write!(f, "observed second instance of peer {peer_index}")
222 }
223 }
224 }
225}
226
227impl std::error::Error for CreateSocketsError {}
228
/// Identifies one run of the socket-creation protocol.
///
/// Process 0 mints a fresh epoch (`Epoch::mint`). Every other process adopts the
/// epoch it receives from process 0 in `connect_lower`, and afterwards only keeps
/// connections whose peer presents that same epoch.
///
/// Field order matters: the derived `Ord` compares `time` first and `nonce`
/// second, so an epoch minted later compares greater (assuming peers' clocks
/// roughly agree).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Epoch {
    // Milliseconds since the Unix epoch at mint time.
    time: u64,
    // Random tie-breaker for epochs minted in the same millisecond.
    nonce: u64,
}
241
242impl Epoch {
243 fn mint() -> Self {
244 let time = SystemTime::UNIX_EPOCH
245 .elapsed()
246 .expect("current time is after 1970")
247 .as_millis()
248 .try_into()
249 .expect("fits");
250 let nonce = rand::random();
251 Self { time, nonce }
252 }
253
254 async fn read(s: &mut Stream) -> std::io::Result<Self> {
255 let time = s.read_u64().await?;
256 let nonce = s.read_u64().await?;
257 Ok(Self { time, nonce })
258 }
259
260 async fn write(&self, s: &mut Stream) -> std::io::Result<()> {
261 s.write_u64(self.time).await?;
262 s.write_u64(self.nonce).await?;
263 Ok(())
264 }
265}
266
267impl fmt::Display for Epoch {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 write!(f, "({}, {})", self.time, self.nonce)
270 }
271}
272
273pub(crate) async fn create_sockets(
278 my_index: usize,
279 addresses: &[String],
280) -> Result<Vec<Option<Stream>>, CreateSocketsError> {
281 let my_address = &addresses[my_index];
282
283 let port_re = Regex::new(r"(?<proto>\w+:)?(?<host>.*):(?<port>\d{1,5})$").unwrap();
286 let listen_address = match port_re.captures(my_address) {
287 Some(cap) => match cap.name("proto") {
288 Some(proto) => format!("{}0.0.0.0:{}", proto.as_str(), &cap["port"]),
289 None => format!("0.0.0.0:{}", &cap["port"]),
290 },
291 None => my_address.to_string(),
292 };
293
294 let listener = Retry::default()
295 .initial_backoff(Duration::from_secs(1))
296 .clamp_backoff(Duration::from_secs(1))
297 .max_tries(10)
298 .retry_async(|_| {
299 Listener::bind(&listen_address)
300 .inspect_err(|error| warn!(%listen_address, "failed to listen: {error}"))
301 })
302 .await
303 .map_err(|error| CreateSocketsError::Bind {
304 address: listen_address,
305 error,
306 })?;
307
308 let (my_epoch, sockets_lower) = match my_index {
309 0 => {
310 let epoch = Epoch::mint();
311 info!(my_index, "minted epoch: {epoch}");
312 (epoch, Vec::new())
313 }
314 _ => connect_lower(my_index, addresses).await?,
315 };
316
317 let n_peers = addresses.len();
318 let sockets_higher = accept_higher(my_index, my_epoch, n_peers, &listener).await?;
319
320 let connections_lower = sockets_lower.into_iter().map(Some);
321 let connections_higher = sockets_higher.into_iter().map(Some);
322 let connections = connections_lower
323 .chain([None])
324 .chain(connections_higher)
325 .collect();
326
327 Ok(connections)
328}
329
330async fn connect_lower(
335 my_index: usize,
336 addresses: &[String],
337) -> Result<(Epoch, Vec<Stream>), CreateSocketsError> {
338 assert!(my_index > 0);
339 assert!(my_index <= addresses.len());
340
341 async fn handshake(
342 my_index: usize,
343 my_epoch: Option<Epoch>,
344 address: &str,
345 ) -> anyhow::Result<(Epoch, Stream)> {
346 let mut s = Stream::connect(address).await?;
347
348 timeout(s.write_u64(u64::cast_from(my_index))).await?;
352 let peer_epoch = timeout(Epoch::read(&mut s)).await?;
353 let my_epoch = my_epoch.unwrap_or(peer_epoch);
354 timeout(my_epoch.write(&mut s)).await?;
355
356 Ok((peer_epoch, s))
357 }
358
359 let mut my_epoch = None;
360 let mut sockets = Vec::new();
361
362 while sockets.len() < my_index {
363 let index = sockets.len();
364 let address = &addresses[index];
365
366 info!(my_index, "connecting to peer {index} at address: {address}");
367
368 let (peer_epoch, sock) = Retry::default()
369 .initial_backoff(Duration::from_secs(1))
370 .clamp_backoff(Duration::from_secs(1))
371 .retry_async(|_| {
372 handshake(my_index, my_epoch, address).inspect_err(|error| {
373 info!(my_index, "error connecting to peer {index}: {error}")
374 })
375 })
376 .await
377 .expect("retries forever");
378
379 if let Some(epoch) = my_epoch {
380 match peer_epoch.cmp(&epoch) {
381 Ordering::Less => {
382 info!(
383 my_index,
384 "refusing connection to peer {index} with smaller epoch: \
385 {peer_epoch} < {epoch}",
386 );
387 continue;
388 }
389 Ordering::Greater => {
390 return Err(CreateSocketsError::EpochMismatch {
391 peer_index: index,
392 peer_epoch,
393 my_epoch: epoch,
394 });
395 }
396 Ordering::Equal => info!(my_index, "connected to peer {index}"),
397 }
398 } else {
399 info!(my_index, "received epoch from peer {index}: {peer_epoch}");
400 my_epoch = Some(peer_epoch);
401 }
402
403 sockets.push(sock);
404 }
405
406 let my_epoch = my_epoch.expect("must exist");
407 Ok((my_epoch, sockets))
408}
409
410async fn accept_higher(
415 my_index: usize,
416 my_epoch: Epoch,
417 n_peers: usize,
418 listener: &Listener,
419) -> Result<Vec<Stream>, CreateSocketsError> {
420 assert!(my_index < n_peers);
421
422 async fn accept(listener: &Listener) -> anyhow::Result<(usize, Stream)> {
423 let (mut s, _) = listener.accept().await?;
424
425 let peer_index = timeout(s.read_u64()).await?;
428 let peer_index = usize::cast_from(peer_index);
429 Ok((peer_index, s))
430 }
431
432 async fn exchange_epochs(my_epoch: Epoch, s: &mut Stream) -> anyhow::Result<Epoch> {
433 timeout(my_epoch.write(s)).await?;
437 let peer_epoch = timeout(Epoch::read(s)).await?;
438 Ok(peer_epoch)
439 }
440
441 let offset = my_index + 1;
442 let mut sockets: Vec<_> = (offset..n_peers).map(|_| None).collect();
443
444 while sockets.iter().any(|s| s.is_none()) {
445 info!(my_index, "accepting connection from peer");
446
447 let (index, mut sock) = match accept(listener).await {
448 Ok(result) => result,
449 Err(error) => {
450 info!(my_index, "error accepting connection: {error}");
451 continue;
452 }
453 };
454
455 if sockets[index - offset].is_some() {
456 return Err(CreateSocketsError::Reconnect { peer_index: index });
457 }
458
459 let peer_epoch = match exchange_epochs(my_epoch, &mut sock).await {
460 Ok(result) => result,
461 Err(error) => {
462 info!(my_index, "error exchanging epochs: {error}");
463 continue;
464 }
465 };
466
467 match peer_epoch.cmp(&my_epoch) {
468 Ordering::Less => {
469 info!(
470 my_index,
471 "refusing connection from peer {index} with smaller epoch: \
472 {peer_epoch} < {my_epoch}",
473 );
474 continue;
475 }
476 Ordering::Greater => {
477 return Err(CreateSocketsError::EpochMismatch {
478 peer_index: index,
479 peer_epoch,
480 my_epoch,
481 });
482 }
483 Ordering::Equal => info!(my_index, "connected to peer {index}"),
484 }
485
486 sockets[index - offset] = Some(sock);
487 }
488
489 Ok(sockets.into_iter().map(|s| s.unwrap()).collect())
490}
491
492async fn timeout<F, R>(fut: F) -> anyhow::Result<R>
499where
500 F: Future<Output = std::io::Result<R>>,
501{
502 let timeout = Duration::from_secs(1);
503 let result = mz_ore::future::timeout(timeout, fut).await?;
504 Ok(result)
505}
506
/// Deterministic network-simulation tests for `create_sockets`, run under `turmoil`.
#[cfg(test)]
mod turmoil_tests {
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};
    use tokio::sync::{mpsc, watch};
    // Inside this module, `timeout` refers to tokio's timeout, not the parent's helper.
    use tokio::time::timeout;

    use super::*;

    /// Runs `NUM_PROCESSES` simulated hosts through `create_sockets` while
    /// bouncing (restarting) random hosts `NUM_CRASHES` times. Then checks that
    /// every host eventually ends up with working connections to all peers.
    ///
    /// Set the `SEED` environment variable to replay a specific run.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_create_sockets() {
        const NUM_PROCESSES: usize = 10;
        const NUM_CRASHES: usize = 3;

        configure_tracing_for_turmoil();

        let seed = std::env::var("SEED")
            .ok()
            .and_then(|x| x.parse().ok())
            .unwrap_or_else(rand::random);

        info!("initializing rng with seed {seed}");
        let mut rng = SmallRng::seed_from_u64(seed);

        // The same seeded RNG drives both turmoil's scheduling order and the
        // crash schedule below, so a seed reproduces a run.
        let mut sim = turmoil::Builder::new()
            .enable_random_order()
            .build_with_rng(Box::new(rng.clone()));

        let processes: Vec<_> = (0..NUM_PROCESSES).map(|i| format!("process-{i}")).collect();
        let addresses: Vec<_> = processes
            .iter()
            .map(|n| format!("turmoil:{n}:7777"))
            .collect();

        // Each host sends its index here once its connections are verified.
        let (ready_tx, mut ready_rx) = mpsc::unbounded_channel();

        // Set to `true` once the crash phase is over. Hosts wait for it before
        // checking their connections, so no later bounce can invalidate the check.
        let (stable_tx, stable_rx) = watch::channel(false);

        for (index, name) in processes.iter().enumerate() {
            let addresses = addresses.clone();
            let ready_tx = ready_tx.clone();
            let stable_rx = stable_rx.clone();
            // turmoil may invoke the host closure again after a bounce, so the
            // closure clones its captured state for every run.
            sim.host(&name[..], move || {
                let addresses = addresses.clone();
                let ready_tx = ready_tx.clone();
                let mut stable_rx = stable_rx.clone();
                async move {
                    'protocol: loop {
                        let mut sockets = match create_sockets(index, &addresses).await {
                            Ok(sockets) => sockets,
                            Err(error) if error.is_fatal() => Err(error)?,
                            Err(error) => {
                                info!("creating sockets failed: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        };

                        let _ = stable_rx.wait_for(|stable| *stable).await;

                        // A peer may have been bounced after our sockets were
                        // created. Verify every connection by writing one ping byte
                        // to each peer, then expecting one ping byte back from each.
                        // Any failure restarts the protocol.
                        info!("sockets created; checking connections");
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            if let Err(error) = sock.write_u8(111).await {
                                info!("error pinging socket: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        }
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            info!("waiting for ping from {sock:?}");
                            match timeout(Duration::from_secs(2), sock.read_u8()).await {
                                Ok(Ok(ping)) => assert_eq!(ping, 111),
                                Ok(Err(error)) => {
                                    info!("error waiting for ping: {error}; retrying protocol");
                                    continue 'protocol;
                                }
                                Err(_) => {
                                    info!("timed out waiting for ping; retrying protocol");
                                    continue 'protocol;
                                }
                            }
                        }

                        let _ = ready_tx.send(index);

                        // NOTE(review): leaking the sockets presumably keeps these
                        // connections open, so peers still verifying their pings do
                        // not observe a closed stream — confirm.
                        std::mem::forget(sockets);
                        return Ok(());
                    }
                }
            });
        }

        // Crash phase: run a random number of steps, then bounce a random host.
        for _ in 0..NUM_CRASHES {
            let steps = rng.gen_range(1..100);
            for _ in 0..steps {
                sim.step().unwrap();
            }

            let i = rng.gen_range(0..NUM_PROCESSES);
            info!("bouncing process {i}");
            sim.bounce(format!("process-{i}"));
        }

        stable_tx.send(true).unwrap();

        // Step the simulation until every host has reported ready. Fail if the
        // simulation's elapsed time exceeds 60s.
        let mut num_ready = 0;
        loop {
            while let Ok(index) = ready_rx.try_recv() {
                info!("process {index} is ready");
                num_ready += 1;
            }
            if num_ready == NUM_PROCESSES {
                break;
            }

            sim.step().unwrap();
            if sim.elapsed() > Duration::from_secs(60) {
                panic!("simulation not finished after 60s");
            }
        }
    }

    /// Runs `test_create_sockets` in an endless loop. Each run uses a fresh
    /// random seed unless `SEED` is set.
    #[test]
    #[ignore = "runs forever"]
    fn fuzz_create_sockets() {
        loop {
            test_create_sockets();
        }
    }

    /// Installs a global tracing subscriber, at most once per test binary.
    ///
    /// The subscriber writes through the test writer. Its log level defaults to
    /// INFO and can be overridden by the environment filter. Inside a turmoil
    /// simulation, each timestamp is followed by the elapsed simulated time.
    fn configure_tracing_for_turmoil() {
        use std::sync::Once;
        use tracing::level_filters::LevelFilter;
        use tracing_subscriber::fmt::time::FormatTime;

        /// Timer that appends turmoil's simulated elapsed time, when available.
        #[derive(Clone)]
        struct SimElapsedTime;

        impl FormatTime for SimElapsedTime {
            fn format_time(
                &self,
                w: &mut tracing_subscriber::fmt::format::Writer<'_>,
            ) -> std::fmt::Result {
                tracing_subscriber::fmt::time().format_time(w)?;
                if let Some(sim_elapsed) = turmoil::sim_elapsed() {
                    write!(w, " [{:?}]", sim_elapsed)?;
                }
                Ok(())
            }
        }

        // `set_global_default` fails if called twice; `Once` guards against that
        // when several tests (or fuzz iterations) call this function.
        static INIT_TRACING: Once = Once::new();
        INIT_TRACING.call_once(|| {
            let env_filter = tracing_subscriber::EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env_lossy();
            let subscriber = tracing_subscriber::fmt()
                .with_test_writer()
                .with_env_filter(env_filter)
                .with_timer(SimElapsedTime)
                .finish();

            tracing::subscriber::set_global_default(subscriber).unwrap();
        });
    }
}