1use std::any::Any;
76use std::cmp::Ordering;
77use std::fmt;
78use std::sync::Arc;
79use std::time::{Duration, SystemTime};
80
81use anyhow::{Context, bail};
82use futures::TryFutureExt;
83use mz_ore::cast::CastFrom;
84use mz_ore::netio::{Listener, Stream};
85use mz_ore::retry::Retry;
86use regex::Regex;
87use timely::communication::allocator::zero_copy::allocator::TcpBuilder;
88use timely::communication::allocator::zero_copy::bytes_slab::BytesRefill;
89use timely::communication::allocator::zero_copy::initialize::initialize_networking_from_sockets;
90use timely::communication::allocator::{GenericBuilder, PeerBuilder};
91use tokio::io::{AsyncReadExt, AsyncWriteExt};
92use tracing::{info, warn};
93
94pub async fn initialize_networking<P>(
96    workers: usize,
97    process: usize,
98    addresses: Vec<String>,
99    refill: BytesRefill,
100    builder_fn: impl Fn(TcpBuilder<P::Peer>) -> GenericBuilder,
101) -> Result<(Vec<GenericBuilder>, Box<dyn Any + Send>), anyhow::Error>
102where
103    P: PeerBuilder,
104{
105    info!(
106        process,
107        ?addresses,
108        "initializing network for timely instance",
109    );
110    let sockets = loop {
111        match create_sockets(process, &addresses).await {
112            Ok(sockets) => break sockets,
113            Err(error) if error.is_fatal() => bail!("failed to set up Timely sockets: {error}"),
114            Err(error) => info!("creating sockets failed: {error}; retrying"),
115        }
116    };
117
118    if sockets
119        .iter()
120        .filter_map(|s| s.as_ref())
121        .all(|s| s.is_tcp())
122    {
123        let sockets = sockets
124            .into_iter()
125            .map(|s| s.map(|s| s.unwrap_tcp().into_std()).transpose())
126            .collect::<Result<Vec<_>, _>>()
127            .map_err(anyhow::Error::from)
128            .context("failed to get standard sockets from tokio sockets")?;
129        initialize_networking_inner::<_, P, _>(sockets, process, workers, refill, builder_fn)
130    } else if sockets
131        .iter()
132        .filter_map(|s| s.as_ref())
133        .all(|s| s.is_unix())
134    {
135        let sockets = sockets
136            .into_iter()
137            .map(|s| s.map(|s| s.unwrap_unix().into_std()).transpose())
138            .collect::<Result<Vec<_>, _>>()
139            .map_err(anyhow::Error::from)
140            .context("failed to get standard sockets from tokio sockets")?;
141        initialize_networking_inner::<_, P, _>(sockets, process, workers, refill, builder_fn)
142    } else {
143        anyhow::bail!("cannot mix TCP and Unix streams");
144    }
145}
146
147fn initialize_networking_inner<S, P, PF>(
148    sockets: Vec<Option<S>>,
149    process: usize,
150    workers: usize,
151    refill: BytesRefill,
152    builder_fn: PF,
153) -> Result<(Vec<GenericBuilder>, Box<dyn Any + Send>), anyhow::Error>
154where
155    S: timely::communication::allocator::zero_copy::stream::Stream + 'static,
156    P: PeerBuilder,
157    PF: Fn(TcpBuilder<P::Peer>) -> GenericBuilder,
158{
159    for s in &sockets {
160        if let Some(s) = s {
161            s.set_nonblocking(false)
162                .context("failed to set socket to non-blocking")?;
163        }
164    }
165
166    match initialize_networking_from_sockets::<_, P>(
167        sockets,
168        process,
169        workers,
170        refill,
171        Arc::new(|_| None),
172    ) {
173        Ok((stuff, guard)) => {
174            info!(process = process, "successfully initialized network");
175            Ok((stuff.into_iter().map(builder_fn).collect(), Box::new(guard)))
176        }
177        Err(err) => {
178            warn!(process, "failed to initialize network: {err}");
179            Err(anyhow::Error::from(err).context("failed to initialize networking from sockets"))
180        }
181    }
182}
183
/// Errors that `create_sockets` and its helpers can return.
///
/// Only [`CreateSocketsError::Bind`] is reported as fatal by `is_fatal`; callers in this file
/// retry the whole protocol for the other variants.
#[derive(Debug)]
pub(crate) enum CreateSocketsError {
    /// Binding the local listener failed, even after the bounded retries in `create_sockets`.
    Bind {
        /// The address the listener tried to bind.
        address: String,
        /// The I/O error returned by the failed bind.
        error: std::io::Error,
    },
    /// A peer presented an epoch greater than the one this process is using.
    EpochMismatch {
        /// Index of the peer that presented the greater epoch.
        peer_index: usize,
        /// The epoch received from the peer.
        peer_epoch: Epoch,
        /// The epoch this process was using at the time.
        my_epoch: Epoch,
    },
    /// A connection arrived claiming a peer index for which a connection was already established.
    Reconnect {
        /// Index of the peer that connected a second time.
        peer_index: usize,
    },
}
200
201impl CreateSocketsError {
202    pub fn is_fatal(&self) -> bool {
204        matches!(self, Self::Bind { .. })
205    }
206}
207
208impl fmt::Display for CreateSocketsError {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        match self {
211            Self::Bind { address, error } => write!(f, "failed to bind at {address}: {error}"),
212            Self::EpochMismatch {
213                peer_index,
214                peer_epoch,
215                my_epoch,
216            } => write!(
217                f,
218                "peer {peer_index} has greater epoch: {peer_epoch} > {my_epoch}"
219            ),
220            Self::Reconnect { peer_index } => {
221                write!(f, "observed second instance of peer {peer_index}")
222            }
223        }
224    }
225}
226
// Marks `CreateSocketsError` as a standard error type; all trait methods keep their defaults,
// so no `source` is exposed.
impl std::error::Error for CreateSocketsError {}
228
/// Identifier exchanged between processes while connecting, used to tell apart connection
/// attempts that belong to different rounds of the protocol.
///
/// Epochs are totally ordered: the derived ordering compares `time` first and falls back to
/// `nonce` on ties, so the field order below is significant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Epoch {
    /// Milliseconds since the Unix epoch at the moment the epoch was minted.
    time: u64,
    /// Random value chosen when the epoch was minted.
    nonce: u64,
}
241
242impl Epoch {
243    fn mint() -> Self {
244        let time = SystemTime::UNIX_EPOCH
245            .elapsed()
246            .expect("current time is after 1970")
247            .as_millis()
248            .try_into()
249            .expect("fits");
250        let nonce = rand::random();
251        Self { time, nonce }
252    }
253
254    async fn read(s: &mut Stream) -> std::io::Result<Self> {
255        let time = s.read_u64().await?;
256        let nonce = s.read_u64().await?;
257        Ok(Self { time, nonce })
258    }
259
260    async fn write(&self, s: &mut Stream) -> std::io::Result<()> {
261        s.write_u64(self.time).await?;
262        s.write_u64(self.nonce).await?;
263        Ok(())
264    }
265}
266
267impl fmt::Display for Epoch {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        write!(f, "({}, {})", self.time, self.nonce)
270    }
271}
272
273pub(crate) async fn create_sockets(
278    my_index: usize,
279    addresses: &[String],
280) -> Result<Vec<Option<Stream>>, CreateSocketsError> {
281    let my_address = &addresses[my_index];
282
283    let port_re = Regex::new(r"(?<proto>\w+:)?(?<host>.*):(?<port>\d{1,5})$").unwrap();
286    let listen_address = match port_re.captures(my_address) {
287        Some(cap) => match cap.name("proto") {
288            Some(proto) => format!("{}0.0.0.0:{}", proto.as_str(), &cap["port"]),
289            None => format!("0.0.0.0:{}", &cap["port"]),
290        },
291        None => my_address.to_string(),
292    };
293
294    let listener = Retry::default()
295        .initial_backoff(Duration::from_secs(1))
296        .clamp_backoff(Duration::from_secs(1))
297        .max_tries(10)
298        .retry_async(|_| {
299            Listener::bind(&listen_address)
300                .inspect_err(|error| warn!(%listen_address, "failed to listen: {error}"))
301        })
302        .await
303        .map_err(|error| CreateSocketsError::Bind {
304            address: listen_address,
305            error,
306        })?;
307
308    let (my_epoch, sockets_lower) = match my_index {
309        0 => {
310            let epoch = Epoch::mint();
311            info!(my_index, "minted epoch: {epoch}");
312            (epoch, Vec::new())
313        }
314        _ => connect_lower(my_index, addresses).await?,
315    };
316
317    let n_peers = addresses.len();
318    let sockets_higher = accept_higher(my_index, my_epoch, n_peers, &listener).await?;
319
320    let connections_lower = sockets_lower.into_iter().map(Some);
321    let connections_higher = sockets_higher.into_iter().map(Some);
322    let connections = connections_lower
323        .chain([None])
324        .chain(connections_higher)
325        .collect();
326
327    Ok(connections)
328}
329
330async fn connect_lower(
335    my_index: usize,
336    addresses: &[String],
337) -> Result<(Epoch, Vec<Stream>), CreateSocketsError> {
338    assert!(my_index > 0);
339    assert!(my_index <= addresses.len());
340
341    async fn handshake(
342        my_index: usize,
343        my_epoch: Option<Epoch>,
344        address: &str,
345    ) -> anyhow::Result<(Epoch, Stream)> {
346        let mut s = Stream::connect(address).await?;
347
348        timeout(s.write_u64(u64::cast_from(my_index))).await?;
352        let peer_epoch = timeout(Epoch::read(&mut s)).await?;
353        let my_epoch = my_epoch.unwrap_or(peer_epoch);
354        timeout(my_epoch.write(&mut s)).await?;
355
356        Ok((peer_epoch, s))
357    }
358
359    let mut my_epoch = None;
360    let mut sockets = Vec::new();
361
362    while sockets.len() < my_index {
363        let index = sockets.len();
364        let address = &addresses[index];
365
366        info!(my_index, "connecting to peer {index} at address: {address}");
367
368        let (peer_epoch, sock) = Retry::default()
369            .initial_backoff(Duration::from_secs(1))
370            .clamp_backoff(Duration::from_secs(1))
371            .retry_async(|_| {
372                handshake(my_index, my_epoch, address).inspect_err(|error| {
373                    info!(my_index, "error connecting to peer {index}: {error}")
374                })
375            })
376            .await
377            .expect("retries forever");
378
379        if let Some(epoch) = my_epoch {
380            match peer_epoch.cmp(&epoch) {
381                Ordering::Less => {
382                    info!(
383                        my_index,
384                        "refusing connection to peer {index} with smaller epoch: \
385                         {peer_epoch} < {epoch}",
386                    );
387                    continue;
388                }
389                Ordering::Greater => {
390                    return Err(CreateSocketsError::EpochMismatch {
391                        peer_index: index,
392                        peer_epoch,
393                        my_epoch: epoch,
394                    });
395                }
396                Ordering::Equal => info!(my_index, "connected to peer {index}"),
397            }
398        } else {
399            info!(my_index, "received epoch from peer {index}: {peer_epoch}");
400            my_epoch = Some(peer_epoch);
401        }
402
403        sockets.push(sock);
404    }
405
406    let my_epoch = my_epoch.expect("must exist");
407    Ok((my_epoch, sockets))
408}
409
410async fn accept_higher(
415    my_index: usize,
416    my_epoch: Epoch,
417    n_peers: usize,
418    listener: &Listener,
419) -> Result<Vec<Stream>, CreateSocketsError> {
420    assert!(my_index < n_peers);
421
422    async fn accept(listener: &Listener) -> anyhow::Result<(usize, Stream)> {
423        let (mut s, _) = listener.accept().await?;
424
425        let peer_index = timeout(s.read_u64()).await?;
428        let peer_index = usize::cast_from(peer_index);
429        Ok((peer_index, s))
430    }
431
432    async fn exchange_epochs(my_epoch: Epoch, s: &mut Stream) -> anyhow::Result<Epoch> {
433        timeout(my_epoch.write(s)).await?;
437        let peer_epoch = timeout(Epoch::read(s)).await?;
438        Ok(peer_epoch)
439    }
440
441    let offset = my_index + 1;
442    let mut sockets: Vec<_> = (offset..n_peers).map(|_| None).collect();
443
444    while sockets.iter().any(|s| s.is_none()) {
445        info!(my_index, "accepting connection from peer");
446
447        let (index, mut sock) = match accept(listener).await {
448            Ok(result) => result,
449            Err(error) => {
450                info!(my_index, "error accepting connection: {error}");
451                continue;
452            }
453        };
454
455        if sockets[index - offset].is_some() {
456            return Err(CreateSocketsError::Reconnect { peer_index: index });
457        }
458
459        let peer_epoch = match exchange_epochs(my_epoch, &mut sock).await {
460            Ok(result) => result,
461            Err(error) => {
462                info!(my_index, "error exchanging epochs: {error}");
463                continue;
464            }
465        };
466
467        match peer_epoch.cmp(&my_epoch) {
468            Ordering::Less => {
469                info!(
470                    my_index,
471                    "refusing connection from peer {index} with smaller epoch: \
472                     {peer_epoch} < {my_epoch}",
473                );
474                continue;
475            }
476            Ordering::Greater => {
477                return Err(CreateSocketsError::EpochMismatch {
478                    peer_index: index,
479                    peer_epoch,
480                    my_epoch,
481                });
482            }
483            Ordering::Equal => info!(my_index, "connected to peer {index}"),
484        }
485
486        sockets[index - offset] = Some(sock);
487    }
488
489    Ok(sockets.into_iter().map(|s| s.unwrap()).collect())
490}
491
492async fn timeout<F, R>(fut: F) -> anyhow::Result<R>
499where
500    F: Future<Output = std::io::Result<R>>,
501{
502    let timeout = Duration::from_secs(1);
503    let result = mz_ore::future::timeout(timeout, fut).await?;
504    Ok(result)
505}
506
507#[cfg(test)]
508mod turmoil_tests {
509    use rand::rngs::SmallRng;
510    use rand::{Rng, SeedableRng};
511    use tokio::sync::{mpsc, watch};
512    use tokio::time::timeout;
513
514    use super::*;
515
    /// Runs the `create_sockets` protocol for `NUM_PROCESSES` simulated hosts under turmoil,
    /// bouncing random hosts `NUM_CRASHES` times, and asserts that every host eventually ends
    /// up with working connections to all of its peers.
    ///
    /// The RNG seed can be pinned through the `SEED` environment variable; the seed in use is
    /// logged so a run can be reproduced.
    #[test] #[cfg_attr(miri, ignore)] fn test_create_sockets() {
        const NUM_PROCESSES: usize = 10;
        const NUM_CRASHES: usize = 3;

        configure_tracing_for_turmoil();

        // Use the seed from `SEED` if it is set and parses; otherwise pick a random one.
        let seed = std::env::var("SEED")
            .ok()
            .and_then(|x| x.parse().ok())
            .unwrap_or_else(rand::random);

        info!("initializing rng with seed {seed}");
        let mut rng = SmallRng::seed_from_u64(seed);

        // The simulation gets its own clone of the seeded RNG; `rng` itself keeps driving the
        // crash schedule below.
        let mut sim = turmoil::Builder::new()
            .enable_random_order()
            .build_with_rng(Box::new(rng.clone()));

        let processes: Vec<_> = (0..NUM_PROCESSES).map(|i| format!("process-{i}")).collect();
        let addresses: Vec<_> = processes
            .iter()
            .map(|n| format!("turmoil:{n}:7777"))
            .collect();

        // Hosts report their index on this channel once their connections have been verified.
        let (ready_tx, mut ready_rx) = mpsc::unbounded_channel();

        // Flipped to `true` once no more hosts will be bounced; hosts wait for it before
        // checking their connections.
        let (stable_tx, stable_rx) = watch::channel(false);

        for (index, name) in processes.iter().enumerate() {
            let addresses = addresses.clone();
            let ready_tx = ready_tx.clone();
            let stable_rx = stable_rx.clone();
            // The closure clones its captures on every invocation.
            // NOTE(review): presumably turmoil invokes it again when the host is bounced —
            // confirm against the turmoil docs.
            sim.host(&name[..], move || {
                let addresses = addresses.clone();
                let ready_tx = ready_tx.clone();
                let mut stable_rx = stable_rx.clone();
                async move {
                    'protocol: loop {
                        // Run the protocol, restarting it on any non-fatal error. A fatal error
                        // ends the host's future with that error.
                        let mut sockets = match create_sockets(index, &addresses).await {
                            Ok(sockets) => sockets,
                            Err(error) if error.is_fatal() => Err(error)?,
                            Err(error) => {
                                info!("creating sockets failed: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        };

                        // Wait until the test has stopped bouncing hosts before verifying the
                        // connections. The result of the wait is deliberately ignored.
                        let _ = stable_rx.wait_for(|stable| *stable).await;

                        info!("sockets created; checking connections");
                        // Send a ping byte to every peer...
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            if let Err(error) = sock.write_u8(111).await {
                                info!("error pinging socket: {error}; retrying protocol");
                                continue 'protocol;
                            }
                        }
                        // ...and expect one back from every peer within two seconds. Any error
                        // or timeout restarts the whole protocol.
                        for sock in sockets.iter_mut().filter_map(|s| s.as_mut()) {
                            info!("waiting for ping from {sock:?}");
                            match timeout(Duration::from_secs(2), sock.read_u8()).await {
                                Ok(Ok(ping)) => assert_eq!(ping, 111),
                                Ok(Err(error)) => {
                                    info!("error waiting for ping: {error}; retrying protocol");
                                    continue 'protocol;
                                }
                                Err(_) => {
                                    info!("timed out waiting for ping; retrying protocol");
                                    continue 'protocol;
                                }
                            }
                        }

                        let _ = ready_tx.send(index);

                        // Leak the sockets rather than dropping them when this future returns.
                        // NOTE(review): presumably so peers that are still checking their
                        // connections do not see them close — confirm.
                        std::mem::forget(sockets);
                        return Ok(());
                    }
                }
            });
        }

        // Crash phase: advance the simulation by a random number of steps, then bounce a
        // randomly chosen host, `NUM_CRASHES` times.
        for _ in 0..NUM_CRASHES {
            let steps = rng.gen_range(1..100);
            for _ in 0..steps {
                sim.step().unwrap();
            }

            let i = rng.gen_range(0..NUM_PROCESSES);
            info!("bouncing process {i}");
            sim.bounce(format!("process-{i}"));
        }

        // No more bounces from here on; let the hosts verify their connections.
        stable_tx.send(true).unwrap();

        // Step the simulation until every host has reported ready, failing the test if that
        // takes more than 60 seconds of simulated time.
        let mut num_ready = 0;
        loop {
            while let Ok(index) = ready_rx.try_recv() {
                info!("process {index} is ready");
                num_ready += 1;
            }
            if num_ready == NUM_PROCESSES {
                break;
            }

            sim.step().unwrap();
            if sim.elapsed() > Duration::from_secs(60) {
                panic!("simulation not finished after 60s");
            }
        }
    }
649
    /// Runs `test_create_sockets` in an endless loop, for manual fuzzing.
    ///
    /// Ignored by default because it never returns; it only stops when an iteration panics.
    #[test] #[ignore = "runs forever"]
    fn fuzz_create_sockets() {
        loop {
            test_create_sockets();
        }
    }
658
659    fn configure_tracing_for_turmoil() {
663        use std::sync::Once;
664        use tracing::level_filters::LevelFilter;
665        use tracing_subscriber::fmt::time::FormatTime;
666
667        #[derive(Clone)]
668        struct SimElapsedTime;
669
670        impl FormatTime for SimElapsedTime {
671            fn format_time(
672                &self,
673                w: &mut tracing_subscriber::fmt::format::Writer<'_>,
674            ) -> std::fmt::Result {
675                tracing_subscriber::fmt::time().format_time(w)?;
676                if let Some(sim_elapsed) = turmoil::sim_elapsed() {
677                    write!(w, " [{:?}]", sim_elapsed)?;
678                }
679                Ok(())
680            }
681        }
682
683        static INIT_TRACING: Once = Once::new();
684        INIT_TRACING.call_once(|| {
685            let env_filter = tracing_subscriber::EnvFilter::builder()
686                .with_default_directive(LevelFilter::INFO.into())
687                .from_env_lossy();
688            let subscriber = tracing_subscriber::fmt()
689                .with_test_writer()
690                .with_env_filter(env_filter)
691                .with_timer(SimElapsedTime)
692                .finish();
693
694            tracing::subscriber::set_global_default(subscriber).unwrap();
695        });
696    }
697}