1use std::any::Any;
76use std::cmp::Ordering;
77use std::fmt;
78use std::sync::Arc;
79use std::time::{Duration, SystemTime};
80
81use anyhow::{Context, bail};
82use futures::TryFutureExt;
83use mz_ore::cast::CastFrom;
84use mz_ore::netio::{Listener, Stream};
85use mz_ore::retry::Retry;
86use regex::Regex;
87use timely::communication::Hooks;
88use timely::communication::allocator::ProcessBuilder;
89use timely::communication::allocator::generic::AllocatorBuilder;
90use timely::communication::allocator::zero_copy::bytes_slab::BytesRefill;
91use timely::communication::allocator::zero_copy::initialize::initialize_networking_from_sockets;
92use tokio::io::{AsyncReadExt, AsyncWriteExt};
93use tracing::{info, warn};
94
95pub async fn initialize_networking(
101 workers: usize,
102 process: usize,
103 addresses: Vec<String>,
104 refill: BytesRefill,
105 enable_zero_copy_binary: bool,
106) -> Result<(Vec<AllocatorBuilder>, Box<dyn Any + Send>), anyhow::Error> {
107 info!(
108 process,
109 ?addresses,
110 "initializing network for timely instance",
111 );
112 let sockets = loop {
113 match create_sockets(process, &addresses).await {
114 Ok(sockets) => break sockets,
115 Err(error) if error.is_fatal() => bail!("failed to set up Timely sockets: {error}"),
116 Err(error) => info!("creating sockets failed: {error}; retrying"),
117 }
118 };
119
120 if sockets
121 .iter()
122 .filter_map(|s| s.as_ref())
123 .all(|s| s.is_tcp())
124 {
125 let sockets = sockets
126 .into_iter()
127 .map(|s| s.map(|s| s.unwrap_tcp().into_std()).transpose())
128 .collect::<Result<Vec<_>, _>>()
129 .map_err(anyhow::Error::from)
130 .context("failed to get standard sockets from tokio sockets")?;
131 initialize_networking_inner(sockets, process, workers, refill, enable_zero_copy_binary)
132 } else if sockets
133 .iter()
134 .filter_map(|s| s.as_ref())
135 .all(|s| s.is_unix())
136 {
137 let sockets = sockets
138 .into_iter()
139 .map(|s| s.map(|s| s.unwrap_unix().into_std()).transpose())
140 .collect::<Result<Vec<_>, _>>()
141 .map_err(anyhow::Error::from)
142 .context("failed to get standard sockets from tokio sockets")?;
143 initialize_networking_inner(sockets, process, workers, refill, enable_zero_copy_binary)
144 } else {
145 anyhow::bail!("cannot mix TCP and Unix streams");
146 }
147}
148
149fn initialize_networking_inner<S>(
150 sockets: Vec<Option<S>>,
151 process: usize,
152 workers: usize,
153 refill: BytesRefill,
154 enable_zero_copy_binary: bool,
155) -> Result<(Vec<AllocatorBuilder>, Box<dyn Any + Send>), anyhow::Error>
156where
157 S: timely::communication::allocator::zero_copy::stream::Stream + 'static,
158{
159 for s in &sockets {
160 if let Some(s) = s {
161 s.set_nonblocking(false)
162 .context("failed to set socket to non-blocking")?;
163 }
164 }
165
166 let process_allocators = if enable_zero_copy_binary {
170 ProcessBuilder::new_bytes_vector(workers, refill.clone(), None)
171 } else {
172 ProcessBuilder::new_typed_vector(workers, refill.clone(), None)
173 };
174
175 let hooks = Hooks {
179 log_fn: Arc::new(|_| None),
180 refill,
181 spill: None,
182 };
183
184 match initialize_networking_from_sockets(process_allocators, sockets, process, workers, hooks) {
185 Ok((tcp_builders, guard)) => {
186 info!(process = process, "successfully initialized network");
187 let builders = tcp_builders
188 .into_iter()
189 .map(AllocatorBuilder::Tcp)
190 .collect();
191 Ok((builders, Box::new(guard)))
192 }
193 Err(err) => {
194 warn!(process, "failed to initialize network: {err}");
195 Err(anyhow::Error::from(err).context("failed to initialize networking from sockets"))
196 }
197 }
198}
199
200#[derive(Debug)]
202pub(crate) enum CreateSocketsError {
203 Bind {
204 address: String,
205 error: std::io::Error,
206 },
207 EpochMismatch {
208 peer_index: usize,
209 peer_epoch: Epoch,
210 my_epoch: Epoch,
211 },
212 Reconnect {
213 peer_index: usize,
214 },
215}
216
217impl CreateSocketsError {
218 pub fn is_fatal(&self) -> bool {
220 matches!(self, Self::Bind { .. })
221 }
222}
223
224impl fmt::Display for CreateSocketsError {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 match self {
227 Self::Bind { address, error } => write!(f, "failed to bind at {address}: {error}"),
228 Self::EpochMismatch {
229 peer_index,
230 peer_epoch,
231 my_epoch,
232 } => write!(
233 f,
234 "peer {peer_index} has greater epoch: {peer_epoch} > {my_epoch}"
235 ),
236 Self::Reconnect { peer_index } => {
237 write!(f, "observed second instance of peer {peer_index}")
238 }
239 }
240 }
241}
242
243impl std::error::Error for CreateSocketsError {}
244
/// Identifies one run of the socket-creation protocol.
///
/// Process 0 mints an epoch (see `create_sockets`), and every other process
/// adopts the epoch it receives from process 0 during the handshake. Peers then
/// compare epochs to reject connections that belong to a different run.
///
/// The derived `Ord` compares `time` first and `nonce` second (field
/// declaration order). An epoch minted later therefore normally compares
/// greater, and `nonce` breaks ties between epochs minted in the same
/// millisecond. Do not reorder the fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Epoch {
    // Milliseconds since the Unix epoch at minting time (see `Epoch::mint`).
    time: u64,
    // Random value drawn at minting time.
    nonce: u64,
}
257
258impl Epoch {
259 fn mint() -> Self {
260 let time = SystemTime::UNIX_EPOCH
261 .elapsed()
262 .expect("current time is after 1970")
263 .as_millis()
264 .try_into()
265 .expect("fits");
266 let nonce = rand::random();
267 Self { time, nonce }
268 }
269
270 async fn read(s: &mut Stream) -> std::io::Result<Self> {
271 let time = s.read_u64().await?;
272 let nonce = s.read_u64().await?;
273 Ok(Self { time, nonce })
274 }
275
276 async fn write(&self, s: &mut Stream) -> std::io::Result<()> {
277 s.write_u64(self.time).await?;
278 s.write_u64(self.nonce).await?;
279 Ok(())
280 }
281}
282
283impl fmt::Display for Epoch {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 write!(f, "({}, {})", self.time, self.nonce)
286 }
287}
288
289pub(crate) async fn create_sockets(
294 my_index: usize,
295 addresses: &[String],
296) -> Result<Vec<Option<Stream>>, CreateSocketsError> {
297 let my_address = &addresses[my_index];
298
299 let port_re = Regex::new(r"(?<proto>\w+:)?(?<host>.*):(?<port>\d{1,5})$").unwrap();
302 let listen_address = match port_re.captures(my_address) {
303 Some(cap) => match cap.name("proto") {
304 Some(proto) => format!("{}0.0.0.0:{}", proto.as_str(), &cap["port"]),
305 None => format!("0.0.0.0:{}", &cap["port"]),
306 },
307 None => my_address.to_string(),
308 };
309
310 let listener = Retry::default()
311 .initial_backoff(Duration::from_secs(1))
312 .clamp_backoff(Duration::from_secs(1))
313 .max_tries(10)
314 .retry_async(|_| {
315 Listener::bind(&listen_address)
316 .inspect_err(|error| warn!(%listen_address, "failed to listen: {error}"))
317 })
318 .await
319 .map_err(|error| CreateSocketsError::Bind {
320 address: listen_address,
321 error,
322 })?;
323
324 let (my_epoch, sockets_lower) = match my_index {
325 0 => {
326 let epoch = Epoch::mint();
327 info!(my_index, "minted epoch: {epoch}");
328 (epoch, Vec::new())
329 }
330 _ => connect_lower(my_index, addresses).await?,
331 };
332
333 let n_peers = addresses.len();
334 let sockets_higher = accept_higher(my_index, my_epoch, n_peers, &listener).await?;
335
336 let connections_lower = sockets_lower.into_iter().map(Some);
337 let connections_higher = sockets_higher.into_iter().map(Some);
338 let connections = connections_lower
339 .chain([None])
340 .chain(connections_higher)
341 .collect();
342
343 Ok(connections)
344}
345
346async fn connect_lower(
351 my_index: usize,
352 addresses: &[String],
353) -> Result<(Epoch, Vec<Stream>), CreateSocketsError> {
354 assert!(my_index > 0);
355 assert!(my_index <= addresses.len());
356
357 async fn handshake(
358 my_index: usize,
359 my_epoch: Option<Epoch>,
360 address: &str,
361 ) -> anyhow::Result<(Epoch, Stream)> {
362 let mut s = Stream::connect(address).await?;
363
364 timeout(s.write_u64(u64::cast_from(my_index))).await?;
368 let peer_epoch = timeout(Epoch::read(&mut s)).await?;
369 let my_epoch = my_epoch.unwrap_or(peer_epoch);
370 timeout(my_epoch.write(&mut s)).await?;
371
372 Ok((peer_epoch, s))
373 }
374
375 let mut my_epoch = None;
376 let mut sockets = Vec::new();
377
378 while sockets.len() < my_index {
379 let index = sockets.len();
380 let address = &addresses[index];
381
382 info!(my_index, "connecting to peer {index} at address: {address}");
383
384 let (peer_epoch, sock) = Retry::default()
385 .initial_backoff(Duration::from_secs(1))
386 .clamp_backoff(Duration::from_secs(1))
387 .retry_async(|_| {
388 handshake(my_index, my_epoch, address).inspect_err(|error| {
389 info!(my_index, "error connecting to peer {index}: {error}")
390 })
391 })
392 .await
393 .expect("retries forever");
394
395 if let Some(epoch) = my_epoch {
396 match peer_epoch.cmp(&epoch) {
397 Ordering::Less => {
398 info!(
399 my_index,
400 "refusing connection to peer {index} with smaller epoch: \
401 {peer_epoch} < {epoch}",
402 );
403 continue;
404 }
405 Ordering::Greater => {
406 return Err(CreateSocketsError::EpochMismatch {
407 peer_index: index,
408 peer_epoch,
409 my_epoch: epoch,
410 });
411 }
412 Ordering::Equal => info!(my_index, "connected to peer {index}"),
413 }
414 } else {
415 info!(my_index, "received epoch from peer {index}: {peer_epoch}");
416 my_epoch = Some(peer_epoch);
417 }
418
419 sockets.push(sock);
420 }
421
422 let my_epoch = my_epoch.expect("must exist");
423 Ok((my_epoch, sockets))
424}
425
426async fn accept_higher(
431 my_index: usize,
432 my_epoch: Epoch,
433 n_peers: usize,
434 listener: &Listener,
435) -> Result<Vec<Stream>, CreateSocketsError> {
436 assert!(my_index < n_peers);
437
438 async fn accept(listener: &Listener) -> anyhow::Result<(usize, Stream)> {
439 let (mut s, _) = listener.accept().await?;
440
441 let peer_index = timeout(s.read_u64()).await?;
444 let peer_index = usize::cast_from(peer_index);
445 Ok((peer_index, s))
446 }
447
448 async fn exchange_epochs(my_epoch: Epoch, s: &mut Stream) -> anyhow::Result<Epoch> {
449 timeout(my_epoch.write(s)).await?;
453 let peer_epoch = timeout(Epoch::read(s)).await?;
454 Ok(peer_epoch)
455 }
456
457 let offset = my_index + 1;
458 let mut sockets: Vec<_> = (offset..n_peers).map(|_| None).collect();
459
460 while sockets.iter().any(|s| s.is_none()) {
461 info!(my_index, "accepting connection from peer");
462
463 let (index, mut sock) = match accept(listener).await {
464 Ok(result) => result,
465 Err(error) => {
466 info!(my_index, "error accepting connection: {error}");
467 continue;
468 }
469 };
470
471 if sockets[index - offset].is_some() {
472 return Err(CreateSocketsError::Reconnect { peer_index: index });
473 }
474
475 let peer_epoch = match exchange_epochs(my_epoch, &mut sock).await {
476 Ok(result) => result,
477 Err(error) => {
478 info!(my_index, "error exchanging epochs: {error}");
479 continue;
480 }
481 };
482
483 match peer_epoch.cmp(&my_epoch) {
484 Ordering::Less => {
485 info!(
486 my_index,
487 "refusing connection from peer {index} with smaller epoch: \
488 {peer_epoch} < {my_epoch}",
489 );
490 continue;
491 }
492 Ordering::Greater => {
493 return Err(CreateSocketsError::EpochMismatch {
494 peer_index: index,
495 peer_epoch,
496 my_epoch,
497 });
498 }
499 Ordering::Equal => info!(my_index, "connected to peer {index}"),
500 }
501
502 sockets[index - offset] = Some(sock);
503 }
504
505 Ok(sockets.into_iter().map(|s| s.unwrap()).collect())
506}
507
508async fn timeout<F, R>(fut: F) -> anyhow::Result<R>
515where
516 F: Future<Output = std::io::Result<R>>,
517{
518 let timeout = Duration::from_secs(1);
519 let result = mz_ore::future::timeout(timeout, fut).await?;
520 Ok(result)
521}
522
// Deterministic-simulation tests for the socket-creation protocol, driven by
// `turmoil`.
#[cfg(test)]
mod turmoil_tests {
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};
    use tokio::sync::{mpsc, watch};
    // This explicit import takes precedence over the parent module's own
    // `timeout` helper, which `use super::*` brings in only as a glob import.
    use tokio::time::timeout;

    use super::*;

    /// Runs `create_sockets` on `NUM_PROCESSES` simulated hosts while bouncing
    /// (restarting) a random host `NUM_CRASHES` times. It then requires every
    /// process to end up with connections that successfully exchange a
    /// one-byte ping with all peers.
    ///
    /// The RNG seed is taken from the `SEED` environment variable when that
    /// parses as a number, and is random otherwise. The seed is logged so a
    /// failing run can be reproduced.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_create_sockets() {
        const NUM_PROCESSES: usize = 10;
        const NUM_CRASHES: usize = 3;

        configure_tracing_for_turmoil();

        let seed = std::env::var("SEED")
            .ok()
            .and_then(|x| x.parse().ok())
            .unwrap_or_else(rand::random);

        info!("initializing rng with seed {seed}");
        let mut rng = SmallRng::seed_from_u64(seed);

        let mut sim = turmoil::Builder::new()
            .enable_random_order()
            .rng_seed(rng.random())
            .build();

        // Addresses carry a `turmoil:` prefix, which `create_sockets` preserves
        // when deriving its listen address (presumably selecting the
        // turmoil-backed transport in `mz_ore::netio` — confirm).
        let processes: Vec<_> = (0..NUM_PROCESSES).map(|i| format!("process-{i}")).collect();
        let addresses: Vec<_> = processes
            .iter()
            .map(|n| format!("turmoil:{n}:7777"))
            .collect();

        // Each process sends its index here once its connections are verified.
        let (ready_tx, mut ready_rx) = mpsc::unbounded_channel();

        // Set to `true` after the last bounce; processes wait for it before
        // checking their connections (presumably so no later bounce can break
        // an already-verified set of connections).
        let (stable_tx, stable_rx) = watch::channel(false);

        for (index, name) in processes.iter().enumerate() {
            let addresses = addresses.clone();
            let ready_tx = ready_tx.clone();
            let stable_rx = stable_rx.clone();
            // The host closure clones its captures on every call, presumably
            // because turmoil calls it again when a bounced host restarts.
            sim.host(&name[..], move || {
                let addresses = addresses.clone();
                let ready_tx = ready_tx.clone();
                let mut stable_rx = stable_rx.clone();
                async move {
                    'protocol: loop {
                        let mut sockets = match create_sockets(index, &addresses).await {
                            Ok(sockets) => sockets,
                            Err(error) if error.is_fatal() => Err(error)?,
                            Err(error) => {
                                info!("creating sockets failed: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        };

                        let _ = stable_rx.wait_for(|stable| *stable).await;

                        // Ping every peer with the byte 111. All pings are written
                        // before any is read, so no process waits on a peer's ping
                        // before having sent its own. Any failure or 2s timeout
                        // restarts the whole protocol.
                        info!("sockets created; checking connections");
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            if let Err(error) = sock.write_u8(111).await {
                                info!("error pinging socket: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        }
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            info!("waiting for ping from {sock:?}");
                            match timeout(Duration::from_secs(2), sock.read_u8()).await {
                                Ok(Ok(ping)) => assert_eq!(ping, 111),
                                Ok(Err(error)) => {
                                    info!("error waiting for ping: {error}; retrying protocol");
                                    continue 'protocol;
                                }
                                Err(_) => {
                                    info!("timed out waiting for ping; retrying protocol");
                                    continue 'protocol;
                                }
                            }
                        }

                        let _ = ready_tx.send(index);

                        // Leak the sockets rather than dropping them (presumably so
                        // the connections stay open for peers that are still
                        // checking theirs — confirm).
                        std::mem::forget(sockets);
                        return Ok(());
                    }
                }
            });
        }

        // Crash phase: run the simulation for a random 1..100 steps, then
        // bounce a randomly chosen process; repeat `NUM_CRASHES` times.
        for _ in 0..NUM_CRASHES {
            let steps = rng.random_range(1..100);
            for _ in 0..steps {
                sim.step().unwrap();
            }

            let i = rng.random_range(0..NUM_PROCESSES);
            info!("bouncing process {i}");
            sim.bounce(format!("process-{i}"));
        }

        stable_tx.send(true).unwrap();

        // Step until every process has reported ready, failing once
        // `sim.elapsed()` exceeds 60s.
        let mut num_ready = 0;
        loop {
            while let Ok(index) = ready_rx.try_recv() {
                info!("process {index} is ready");
                num_ready += 1;
            }
            if num_ready == NUM_PROCESSES {
                break;
            }

            sim.step().unwrap();
            if sim.elapsed() > Duration::from_secs(60) {
                panic!("simulation not finished after 60s");
            }
        }
    }

    /// Runs `test_create_sockets` in an endless loop. Each iteration uses a
    /// fresh random seed unless `SEED` is set. Ignored by default because it
    /// never terminates.
    #[test]
    #[ignore = "runs forever"]
    fn fuzz_create_sockets() {
        loop {
            test_create_sockets();
        }
    }

    /// Installs a global tracing subscriber, at most once per process (guarded
    /// by a `Once`). It logs at INFO by default, which can be overridden via
    /// the environment filter, and writes through the test writer. When running
    /// inside a turmoil simulation, each timestamp is followed by the simulated
    /// elapsed time in brackets.
    fn configure_tracing_for_turmoil() {
        use std::sync::Once;
        use tracing::level_filters::LevelFilter;
        use tracing_subscriber::fmt::time::FormatTime;

        // Timer that prints the regular timestamp plus `turmoil::sim_elapsed()`.
        #[derive(Clone)]
        struct SimElapsedTime;

        impl FormatTime for SimElapsedTime {
            fn format_time(
                &self,
                w: &mut tracing_subscriber::fmt::format::Writer<'_>,
            ) -> std::fmt::Result {
                tracing_subscriber::fmt::time().format_time(w)?;
                // `sim_elapsed` is `None` outside a simulation; then only the
                // regular timestamp is printed.
                if let Some(sim_elapsed) = turmoil::sim_elapsed() {
                    write!(w, " [{:?}]", sim_elapsed)?;
                }
                Ok(())
            }
        }

        static INIT_TRACING: Once = Once::new();
        INIT_TRACING.call_once(|| {
            let env_filter = tracing_subscriber::EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env_lossy();
            let subscriber = tracing_subscriber::fmt()
                .with_test_writer()
                .with_env_filter(env_filter)
                .with_timer(SimElapsedTime)
                .finish();

            tracing::subscriber::set_global_default(subscriber).unwrap();
        });
    }
}