1use std::any::Any;
76use std::cmp::Ordering;
77use std::fmt;
78use std::sync::Arc;
79use std::time::{Duration, SystemTime};
80
81use anyhow::{Context, bail};
82use futures::TryFutureExt;
83use mz_ore::cast::CastFrom;
84use mz_ore::netio::{Listener, Stream};
85use mz_ore::retry::Retry;
86use regex::Regex;
87use timely::communication::allocator::ProcessBuilder;
88use timely::communication::allocator::generic::AllocatorBuilder;
89use timely::communication::allocator::zero_copy::bytes_slab::BytesRefill;
90use timely::communication::allocator::zero_copy::initialize::initialize_networking_from_sockets;
91use tokio::io::{AsyncReadExt, AsyncWriteExt};
92use tracing::{info, warn};
93
94pub async fn initialize_networking(
100 workers: usize,
101 process: usize,
102 addresses: Vec<String>,
103 refill: BytesRefill,
104 enable_zero_copy_binary: bool,
105) -> Result<(Vec<AllocatorBuilder>, Box<dyn Any + Send>), anyhow::Error> {
106 info!(
107 process,
108 ?addresses,
109 "initializing network for timely instance",
110 );
111 let sockets = loop {
112 match create_sockets(process, &addresses).await {
113 Ok(sockets) => break sockets,
114 Err(error) if error.is_fatal() => bail!("failed to set up Timely sockets: {error}"),
115 Err(error) => info!("creating sockets failed: {error}; retrying"),
116 }
117 };
118
119 if sockets
120 .iter()
121 .filter_map(|s| s.as_ref())
122 .all(|s| s.is_tcp())
123 {
124 let sockets = sockets
125 .into_iter()
126 .map(|s| s.map(|s| s.unwrap_tcp().into_std()).transpose())
127 .collect::<Result<Vec<_>, _>>()
128 .map_err(anyhow::Error::from)
129 .context("failed to get standard sockets from tokio sockets")?;
130 initialize_networking_inner(sockets, process, workers, refill, enable_zero_copy_binary)
131 } else if sockets
132 .iter()
133 .filter_map(|s| s.as_ref())
134 .all(|s| s.is_unix())
135 {
136 let sockets = sockets
137 .into_iter()
138 .map(|s| s.map(|s| s.unwrap_unix().into_std()).transpose())
139 .collect::<Result<Vec<_>, _>>()
140 .map_err(anyhow::Error::from)
141 .context("failed to get standard sockets from tokio sockets")?;
142 initialize_networking_inner(sockets, process, workers, refill, enable_zero_copy_binary)
143 } else {
144 anyhow::bail!("cannot mix TCP and Unix streams");
145 }
146}
147
148fn initialize_networking_inner<S>(
149 sockets: Vec<Option<S>>,
150 process: usize,
151 workers: usize,
152 refill: BytesRefill,
153 enable_zero_copy_binary: bool,
154) -> Result<(Vec<AllocatorBuilder>, Box<dyn Any + Send>), anyhow::Error>
155where
156 S: timely::communication::allocator::zero_copy::stream::Stream + 'static,
157{
158 for s in &sockets {
159 if let Some(s) = s {
160 s.set_nonblocking(false)
161 .context("failed to set socket to non-blocking")?;
162 }
163 }
164
165 let process_allocators = if enable_zero_copy_binary {
167 ProcessBuilder::new_bytes_vector(workers, refill.clone())
168 } else {
169 ProcessBuilder::new_typed_vector(workers, refill.clone())
170 };
171
172 match initialize_networking_from_sockets(
173 process_allocators,
174 sockets,
175 process,
176 workers,
177 refill,
178 Arc::new(|_| None),
179 ) {
180 Ok((tcp_builders, guard)) => {
181 info!(process = process, "successfully initialized network");
182 let builders = tcp_builders
183 .into_iter()
184 .map(AllocatorBuilder::Tcp)
185 .collect();
186 Ok((builders, Box::new(guard)))
187 }
188 Err(err) => {
189 warn!(process, "failed to initialize network: {err}");
190 Err(anyhow::Error::from(err).context("failed to initialize networking from sockets"))
191 }
192 }
193}
194
/// Errors returned by [`create_sockets`].
#[derive(Debug)]
pub(crate) enum CreateSocketsError {
    /// Binding the local listener failed, even after the bind retries were exhausted.
    Bind {
        /// The address the listener tried to bind to.
        address: String,
        /// The I/O error reported by the last bind attempt.
        error: std::io::Error,
    },
    /// A peer presented an epoch greater than the local one.
    EpochMismatch {
        /// Index of the peer with the greater epoch.
        peer_index: usize,
        /// The epoch reported by the peer.
        peer_epoch: Epoch,
        /// The epoch held by this process.
        my_epoch: Epoch,
    },
    /// A connection was accepted from a peer for which a socket was already held.
    Reconnect {
        /// Index of the peer that connected a second time.
        peer_index: usize,
    },
}
211
212impl CreateSocketsError {
213 pub fn is_fatal(&self) -> bool {
215 matches!(self, Self::Bind { .. })
216 }
217}
218
219impl fmt::Display for CreateSocketsError {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 match self {
222 Self::Bind { address, error } => write!(f, "failed to bind at {address}: {error}"),
223 Self::EpochMismatch {
224 peer_index,
225 peer_epoch,
226 my_epoch,
227 } => write!(
228 f,
229 "peer {peer_index} has greater epoch: {peer_epoch} > {my_epoch}"
230 ),
231 Self::Reconnect { peer_index } => {
232 write!(f, "observed second instance of peer {peer_index}")
233 }
234 }
235 }
236}
237
// Marker impl relying on the trait's default methods; `Debug` and `Display` are
// provided by the derive and the impl above.
impl std::error::Error for CreateSocketsError {}
239
/// Epoch value exchanged between processes in the [`create_sockets`] protocol.
///
/// The derived ordering compares `time` first and falls back to `nonce` only when
/// the times are equal, because that is the declaration order of the fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Epoch {
    /// Milliseconds since the Unix epoch at the moment the epoch was minted.
    time: u64,
    /// Random value drawn when the epoch was minted.
    nonce: u64,
}
252
253impl Epoch {
254 fn mint() -> Self {
255 let time = SystemTime::UNIX_EPOCH
256 .elapsed()
257 .expect("current time is after 1970")
258 .as_millis()
259 .try_into()
260 .expect("fits");
261 let nonce = rand::random();
262 Self { time, nonce }
263 }
264
265 async fn read(s: &mut Stream) -> std::io::Result<Self> {
266 let time = s.read_u64().await?;
267 let nonce = s.read_u64().await?;
268 Ok(Self { time, nonce })
269 }
270
271 async fn write(&self, s: &mut Stream) -> std::io::Result<()> {
272 s.write_u64(self.time).await?;
273 s.write_u64(self.nonce).await?;
274 Ok(())
275 }
276}
277
278impl fmt::Display for Epoch {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 write!(f, "({}, {})", self.time, self.nonce)
281 }
282}
283
284pub(crate) async fn create_sockets(
289 my_index: usize,
290 addresses: &[String],
291) -> Result<Vec<Option<Stream>>, CreateSocketsError> {
292 let my_address = &addresses[my_index];
293
294 let port_re = Regex::new(r"(?<proto>\w+:)?(?<host>.*):(?<port>\d{1,5})$").unwrap();
297 let listen_address = match port_re.captures(my_address) {
298 Some(cap) => match cap.name("proto") {
299 Some(proto) => format!("{}0.0.0.0:{}", proto.as_str(), &cap["port"]),
300 None => format!("0.0.0.0:{}", &cap["port"]),
301 },
302 None => my_address.to_string(),
303 };
304
305 let listener = Retry::default()
306 .initial_backoff(Duration::from_secs(1))
307 .clamp_backoff(Duration::from_secs(1))
308 .max_tries(10)
309 .retry_async(|_| {
310 Listener::bind(&listen_address)
311 .inspect_err(|error| warn!(%listen_address, "failed to listen: {error}"))
312 })
313 .await
314 .map_err(|error| CreateSocketsError::Bind {
315 address: listen_address,
316 error,
317 })?;
318
319 let (my_epoch, sockets_lower) = match my_index {
320 0 => {
321 let epoch = Epoch::mint();
322 info!(my_index, "minted epoch: {epoch}");
323 (epoch, Vec::new())
324 }
325 _ => connect_lower(my_index, addresses).await?,
326 };
327
328 let n_peers = addresses.len();
329 let sockets_higher = accept_higher(my_index, my_epoch, n_peers, &listener).await?;
330
331 let connections_lower = sockets_lower.into_iter().map(Some);
332 let connections_higher = sockets_higher.into_iter().map(Some);
333 let connections = connections_lower
334 .chain([None])
335 .chain(connections_higher)
336 .collect();
337
338 Ok(connections)
339}
340
341async fn connect_lower(
346 my_index: usize,
347 addresses: &[String],
348) -> Result<(Epoch, Vec<Stream>), CreateSocketsError> {
349 assert!(my_index > 0);
350 assert!(my_index <= addresses.len());
351
352 async fn handshake(
353 my_index: usize,
354 my_epoch: Option<Epoch>,
355 address: &str,
356 ) -> anyhow::Result<(Epoch, Stream)> {
357 let mut s = Stream::connect(address).await?;
358
359 timeout(s.write_u64(u64::cast_from(my_index))).await?;
363 let peer_epoch = timeout(Epoch::read(&mut s)).await?;
364 let my_epoch = my_epoch.unwrap_or(peer_epoch);
365 timeout(my_epoch.write(&mut s)).await?;
366
367 Ok((peer_epoch, s))
368 }
369
370 let mut my_epoch = None;
371 let mut sockets = Vec::new();
372
373 while sockets.len() < my_index {
374 let index = sockets.len();
375 let address = &addresses[index];
376
377 info!(my_index, "connecting to peer {index} at address: {address}");
378
379 let (peer_epoch, sock) = Retry::default()
380 .initial_backoff(Duration::from_secs(1))
381 .clamp_backoff(Duration::from_secs(1))
382 .retry_async(|_| {
383 handshake(my_index, my_epoch, address).inspect_err(|error| {
384 info!(my_index, "error connecting to peer {index}: {error}")
385 })
386 })
387 .await
388 .expect("retries forever");
389
390 if let Some(epoch) = my_epoch {
391 match peer_epoch.cmp(&epoch) {
392 Ordering::Less => {
393 info!(
394 my_index,
395 "refusing connection to peer {index} with smaller epoch: \
396 {peer_epoch} < {epoch}",
397 );
398 continue;
399 }
400 Ordering::Greater => {
401 return Err(CreateSocketsError::EpochMismatch {
402 peer_index: index,
403 peer_epoch,
404 my_epoch: epoch,
405 });
406 }
407 Ordering::Equal => info!(my_index, "connected to peer {index}"),
408 }
409 } else {
410 info!(my_index, "received epoch from peer {index}: {peer_epoch}");
411 my_epoch = Some(peer_epoch);
412 }
413
414 sockets.push(sock);
415 }
416
417 let my_epoch = my_epoch.expect("must exist");
418 Ok((my_epoch, sockets))
419}
420
421async fn accept_higher(
426 my_index: usize,
427 my_epoch: Epoch,
428 n_peers: usize,
429 listener: &Listener,
430) -> Result<Vec<Stream>, CreateSocketsError> {
431 assert!(my_index < n_peers);
432
433 async fn accept(listener: &Listener) -> anyhow::Result<(usize, Stream)> {
434 let (mut s, _) = listener.accept().await?;
435
436 let peer_index = timeout(s.read_u64()).await?;
439 let peer_index = usize::cast_from(peer_index);
440 Ok((peer_index, s))
441 }
442
443 async fn exchange_epochs(my_epoch: Epoch, s: &mut Stream) -> anyhow::Result<Epoch> {
444 timeout(my_epoch.write(s)).await?;
448 let peer_epoch = timeout(Epoch::read(s)).await?;
449 Ok(peer_epoch)
450 }
451
452 let offset = my_index + 1;
453 let mut sockets: Vec<_> = (offset..n_peers).map(|_| None).collect();
454
455 while sockets.iter().any(|s| s.is_none()) {
456 info!(my_index, "accepting connection from peer");
457
458 let (index, mut sock) = match accept(listener).await {
459 Ok(result) => result,
460 Err(error) => {
461 info!(my_index, "error accepting connection: {error}");
462 continue;
463 }
464 };
465
466 if sockets[index - offset].is_some() {
467 return Err(CreateSocketsError::Reconnect { peer_index: index });
468 }
469
470 let peer_epoch = match exchange_epochs(my_epoch, &mut sock).await {
471 Ok(result) => result,
472 Err(error) => {
473 info!(my_index, "error exchanging epochs: {error}");
474 continue;
475 }
476 };
477
478 match peer_epoch.cmp(&my_epoch) {
479 Ordering::Less => {
480 info!(
481 my_index,
482 "refusing connection from peer {index} with smaller epoch: \
483 {peer_epoch} < {my_epoch}",
484 );
485 continue;
486 }
487 Ordering::Greater => {
488 return Err(CreateSocketsError::EpochMismatch {
489 peer_index: index,
490 peer_epoch,
491 my_epoch,
492 });
493 }
494 Ordering::Equal => info!(my_index, "connected to peer {index}"),
495 }
496
497 sockets[index - offset] = Some(sock);
498 }
499
500 Ok(sockets.into_iter().map(|s| s.unwrap()).collect())
501}
502
503async fn timeout<F, R>(fut: F) -> anyhow::Result<R>
510where
511 F: Future<Output = std::io::Result<R>>,
512{
513 let timeout = Duration::from_secs(1);
514 let result = mz_ore::future::timeout(timeout, fut).await?;
515 Ok(result)
516}
517
/// Simulation tests for [`create_sockets`], run on the turmoil network simulator.
#[cfg(test)]
mod turmoil_tests {
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};
    use tokio::sync::{mpsc, watch};
    // Within this module `timeout` refers to tokio's two-argument version; the
    // explicit import takes precedence over the one brought in by `super::*`.
    use tokio::time::timeout;

    use super::*;

    /// Runs `create_sockets` on `NUM_PROCESSES` simulated hosts while bouncing
    /// randomly chosen hosts `NUM_CRASHES` times, then checks that every host ends
    /// up with working connections to all of its peers.
    ///
    /// The random seed is taken from the `SEED` environment variable if it is set
    /// and parses; otherwise a random seed is used. The seed is logged.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_create_sockets() {
        const NUM_PROCESSES: usize = 10;
        const NUM_CRASHES: usize = 3;

        configure_tracing_for_turmoil();

        let seed = std::env::var("SEED")
            .ok()
            .and_then(|x| x.parse().ok())
            .unwrap_or_else(rand::random);

        info!("initializing rng with seed {seed}");
        let mut rng = SmallRng::seed_from_u64(seed);

        let mut sim = turmoil::Builder::new()
            .enable_random_order()
            .rng_seed(rng.random())
            .build();

        let processes: Vec<_> = (0..NUM_PROCESSES).map(|i| format!("process-{i}")).collect();
        let addresses: Vec<_> = processes
            .iter()
            .map(|n| format!("turmoil:{n}:7777"))
            .collect();

        // Each host sends its index on this channel once all its connections work.
        let (ready_tx, mut ready_rx) = mpsc::unbounded_channel();

        // Flipped to `true` once no more hosts will be bounced; hosts wait for it
        // before checking their connections.
        let (stable_tx, stable_rx) = watch::channel(false);

        for (index, name) in processes.iter().enumerate() {
            let addresses = addresses.clone();
            let ready_tx = ready_tx.clone();
            let stable_rx = stable_rx.clone();
            // The closure clones its captures again before entering the async block.
            // NOTE(review): presumably because turmoil re-invokes it when the host
            // is bounced — confirm.
            sim.host(&name[..], move || {
                let addresses = addresses.clone();
                let ready_tx = ready_tx.clone();
                let mut stable_rx = stable_rx.clone();
                async move {
                    'protocol: loop {
                        let mut sockets = match create_sockets(index, &addresses).await {
                            Ok(sockets) => sockets,
                            Err(error) if error.is_fatal() => Err(error)?,
                            Err(error) => {
                                info!("creating sockets failed: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        };

                        // Don't check the connections before the stable phase starts.
                        // The result of the wait is deliberately ignored.
                        let _ = stable_rx.wait_for(|stable| *stable).await;

                        // Ping every peer, then expect a ping from every peer. Any
                        // failure restarts the whole protocol for this host.
                        info!("sockets created; checking connections");
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            if let Err(error) = sock.write_u8(111).await {
                                info!("error pinging socket: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        }
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            info!("waiting for ping from {sock:?}");
                            match timeout(Duration::from_secs(2), sock.read_u8()).await {
                                Ok(Ok(ping)) => assert_eq!(ping, 111),
                                Ok(Err(error)) => {
                                    info!("error waiting for ping: {error}; retrying protocol");
                                    continue 'protocol;
                                }
                                Err(_) => {
                                    info!("timed out waiting for ping; retrying protocol");
                                    continue 'protocol;
                                }
                            }
                        }

                        let _ = ready_tx.send(index);

                        // Leak the sockets rather than dropping them when this task
                        // returns. NOTE(review): presumably so that peers still
                        // waiting for their pings don't see the connection close —
                        // confirm.
                        std::mem::forget(sockets);
                        return Ok(());
                    }
                }
            });
        }

        // Chaos phase: advance the simulation a random number of steps, then bounce
        // a randomly chosen process.
        for _ in 0..NUM_CRASHES {
            let steps = rng.random_range(1..100);
            for _ in 0..steps {
                sim.step().unwrap();
            }

            let i = rng.random_range(0..NUM_PROCESSES);
            info!("bouncing process {i}");
            sim.bounce(format!("process-{i}"));
        }

        // Stable phase: no more bounces; let the hosts verify their connections.
        stable_tx.send(true).unwrap();

        // Step the simulation until every process has reported ready, failing if
        // that takes more than 60 seconds of simulated time.
        let mut num_ready = 0;
        loop {
            while let Ok(index) = ready_rx.try_recv() {
                info!("process {index} is ready");
                num_ready += 1;
            }
            if num_ready == NUM_PROCESSES {
                break;
            }

            sim.step().unwrap();
            if sim.elapsed() > Duration::from_secs(60) {
                panic!("simulation not finished after 60s");
            }
        }
    }

    /// Runs [`test_create_sockets`] in an endless loop.
    ///
    /// Ignored by default because it never terminates on its own.
    #[test]
    #[ignore = "runs forever"]
    fn fuzz_create_sockets() {
        loop {
            test_create_sockets();
        }
    }

    /// Installs a global tracing subscriber whose timestamps also show the elapsed
    /// simulation time.
    ///
    /// Safe to call repeatedly: the subscriber is installed only on the first call.
    fn configure_tracing_for_turmoil() {
        use std::sync::Once;
        use tracing::level_filters::LevelFilter;
        use tracing_subscriber::fmt::time::FormatTime;

        /// Timer that prints the default timestamp followed by the elapsed
        /// simulation time, when turmoil reports one.
        #[derive(Clone)]
        struct SimElapsedTime;

        impl FormatTime for SimElapsedTime {
            fn format_time(
                &self,
                w: &mut tracing_subscriber::fmt::format::Writer<'_>,
            ) -> std::fmt::Result {
                tracing_subscriber::fmt::time().format_time(w)?;
                if let Some(sim_elapsed) = turmoil::sim_elapsed() {
                    write!(w, " [{:?}]", sim_elapsed)?;
                }
                Ok(())
            }
        }

        // Setting the global default subscriber a second time would fail, so guard
        // the installation with a `Once`.
        static INIT_TRACING: Once = Once::new();
        INIT_TRACING.call_once(|| {
            // Log at INFO unless the environment filter says otherwise.
            let env_filter = tracing_subscriber::EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env_lossy();
            let subscriber = tracing_subscriber::fmt()
                .with_test_writer()
                .with_env_filter(env_filter)
                .with_timer(SimElapsedTime)
                .finish();

            tracing::subscriber::set_global_default(subscriber).unwrap();
        });
    }
}