criterion/
connection.rs

1use crate::report::BenchmarkId as InternalBenchmarkId;
2use crate::Throughput;
3use std::cell::RefCell;
4use std::convert::TryFrom;
5use std::io::{Read, Write};
6use std::mem::size_of;
7use std::net::TcpStream;
8
9#[derive(Debug)]
10pub enum MessageError {
11    Deserialization(ciborium::de::Error<std::io::Error>),
12    Serialization(ciborium::ser::Error<std::io::Error>),
13    Io(std::io::Error),
14}
15impl From<ciborium::de::Error<std::io::Error>> for MessageError {
16    fn from(other: ciborium::de::Error<std::io::Error>) -> Self {
17        MessageError::Deserialization(other)
18    }
19}
20impl From<ciborium::ser::Error<std::io::Error>> for MessageError {
21    fn from(other: ciborium::ser::Error<std::io::Error>) -> Self {
22        MessageError::Serialization(other)
23    }
24}
25impl From<std::io::Error> for MessageError {
26    fn from(other: std::io::Error) -> Self {
27        MessageError::Io(other)
28    }
29}
30impl std::fmt::Display for MessageError {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            MessageError::Deserialization(error) => write!(
34                f,
35                "Failed to deserialize message to Criterion.rs benchmark:\n{}",
36                error
37            ),
38            MessageError::Serialization(error) => write!(
39                f,
40                "Failed to serialize message to Criterion.rs benchmark:\n{}",
41                error
42            ),
43            MessageError::Io(error) => write!(
44                f,
45                "Failed to read or write message to Criterion.rs benchmark:\n{}",
46                error
47            ),
48        }
49    }
50}
51impl std::error::Error for MessageError {
52    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
53        match self {
54            MessageError::Deserialization(err) => Some(err),
55            MessageError::Serialization(err) => Some(err),
56            MessageError::Io(err) => Some(err),
57        }
58    }
59}
60
// Use str::len as a const fn once we bump MSRV over 1.39.
// Until then the magic-number lengths below are written out as literals and must be kept
// in sync with the strings they measure.

/// Magic bytes at the start of the hello message read from cargo-criterion.
const RUNNER_MAGIC_NUMBER: &str = "cargo-criterion";
/// Runner hello: the 15-byte magic number followed by a three-byte version number.
const RUNNER_HELLO_SIZE: usize = 15 + 3 * size_of::<u8>();

/// Magic bytes at the start of the hello message this side sends back.
const BENCHMARK_MAGIC_NUMBER: &str = "Criterion";
/// Benchmark hello: the 9-byte magic number, a three-byte version number, then the
/// `u16` protocol version and the `u16` protocol format.
const BENCHMARK_HELLO_SIZE: usize = 9 + 3 * size_of::<u8>() + 2 * size_of::<u16>();
/// Protocol version advertised in the benchmark hello.
const PROTOCOL_VERSION: u16 = 1;
/// Protocol format advertised in the benchmark hello.
const PROTOCOL_FORMAT: u16 = 1;
73
74#[derive(Debug)]
75struct InnerConnection {
76    socket: TcpStream,
77    receive_buffer: Vec<u8>,
78    send_buffer: Vec<u8>,
79    // runner_version: [u8; 3],
80}
81impl InnerConnection {
82    pub fn new(mut socket: TcpStream) -> Result<Self, std::io::Error> {
83        // read the runner-hello
84        let mut hello_buf = [0u8; RUNNER_HELLO_SIZE];
85        socket.read_exact(&mut hello_buf)?;
86        assert_eq!(
87            &hello_buf[0..RUNNER_MAGIC_NUMBER.len()],
88            RUNNER_MAGIC_NUMBER.as_bytes(),
89            "Not connected to cargo-criterion."
90        );
91
92        let i = RUNNER_MAGIC_NUMBER.len();
93        let runner_version = [hello_buf[i], hello_buf[i + 1], hello_buf[i + 2]];
94
95        info!("Runner version: {:?}", runner_version);
96
97        // now send the benchmark-hello
98        let mut hello_buf = [0u8; BENCHMARK_HELLO_SIZE];
99        hello_buf[0..BENCHMARK_MAGIC_NUMBER.len()]
100            .copy_from_slice(BENCHMARK_MAGIC_NUMBER.as_bytes());
101        let mut i = BENCHMARK_MAGIC_NUMBER.len();
102        hello_buf[i] = env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap();
103        hello_buf[i + 1] = env!("CARGO_PKG_VERSION_MINOR").parse().unwrap();
104        hello_buf[i + 2] = env!("CARGO_PKG_VERSION_PATCH").parse().unwrap();
105        i += 3;
106        hello_buf[i..i + 2].clone_from_slice(&PROTOCOL_VERSION.to_be_bytes());
107        i += 2;
108        hello_buf[i..i + 2].clone_from_slice(&PROTOCOL_FORMAT.to_be_bytes());
109
110        socket.write_all(&hello_buf)?;
111
112        Ok(InnerConnection {
113            socket,
114            receive_buffer: vec![],
115            send_buffer: vec![],
116            // runner_version,
117        })
118    }
119
120    #[allow(dead_code)]
121    pub fn recv(&mut self) -> Result<IncomingMessage, MessageError> {
122        let mut length_buf = [0u8; 4];
123        self.socket.read_exact(&mut length_buf)?;
124        let length = u32::from_be_bytes(length_buf);
125        self.receive_buffer.resize(length as usize, 0u8);
126        self.socket.read_exact(&mut self.receive_buffer)?;
127        let value = ciborium::de::from_reader(&self.receive_buffer[..])?;
128        Ok(value)
129    }
130
131    pub fn send(&mut self, message: &OutgoingMessage) -> Result<(), MessageError> {
132        self.send_buffer.truncate(0);
133        ciborium::ser::into_writer(message, &mut self.send_buffer)?;
134        let size = u32::try_from(self.send_buffer.len()).unwrap();
135        let length_buf = size.to_be_bytes();
136        self.socket.write_all(&length_buf)?;
137        self.socket.write_all(&self.send_buffer)?;
138        Ok(())
139    }
140}
141
142/// This is really just a holder to allow us to send messages through a shared reference to the
143/// connection.
144#[derive(Debug)]
145pub struct Connection {
146    inner: RefCell<InnerConnection>,
147}
148impl Connection {
149    pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
150        Ok(Connection {
151            inner: RefCell::new(InnerConnection::new(socket)?),
152        })
153    }
154
155    #[allow(dead_code)]
156    pub fn recv(&self) -> Result<IncomingMessage, MessageError> {
157        self.inner.borrow_mut().recv()
158    }
159
160    pub fn send(&self, message: &OutgoingMessage) -> Result<(), MessageError> {
161        self.inner.borrow_mut().send(message)
162    }
163
164    pub fn serve_value_formatter(
165        &self,
166        formatter: &dyn crate::measurement::ValueFormatter,
167    ) -> Result<(), MessageError> {
168        loop {
169            let response = match self.recv()? {
170                IncomingMessage::FormatValue { value } => OutgoingMessage::FormattedValue {
171                    value: formatter.format_value(value),
172                },
173                IncomingMessage::FormatThroughput { value, throughput } => {
174                    OutgoingMessage::FormattedValue {
175                        value: formatter.format_throughput(&throughput, value),
176                    }
177                }
178                IncomingMessage::ScaleValues {
179                    typical_value,
180                    mut values,
181                } => {
182                    let unit = formatter.scale_values(typical_value, &mut values);
183                    OutgoingMessage::ScaledValues {
184                        unit,
185                        scaled_values: values,
186                    }
187                }
188                IncomingMessage::ScaleThroughputs {
189                    typical_value,
190                    throughput,
191                    mut values,
192                } => {
193                    let unit = formatter.scale_throughputs(typical_value, &throughput, &mut values);
194                    OutgoingMessage::ScaledValues {
195                        unit,
196                        scaled_values: values,
197                    }
198                }
199                IncomingMessage::ScaleForMachines { mut values } => {
200                    let unit = formatter.scale_for_machines(&mut values);
201                    OutgoingMessage::ScaledValues {
202                        unit,
203                        scaled_values: values,
204                    }
205                }
206                IncomingMessage::Continue => break,
207                _ => panic!(),
208            };
209            self.send(&response)?;
210        }
211        Ok(())
212    }
213}
214
/// Enum defining the messages we can receive
///
/// Values are decoded from the CBOR payloads read by `InnerConnection::recv`. The formatting
/// requests are answered by `Connection::serve_value_formatter`.
#[derive(Debug, Deserialize)]
pub enum IncomingMessage {
    // Value formatter requests
    /// Format a single value; answered with `OutgoingMessage::FormattedValue`.
    FormatValue {
        value: f64,
    },
    /// Format a single value as a throughput; answered with `OutgoingMessage::FormattedValue`.
    FormatThroughput {
        value: f64,
        throughput: Throughput,
    },
    /// Have the formatter scale `values` relative to `typical_value`; answered with
    /// `OutgoingMessage::ScaledValues`.
    ScaleValues {
        typical_value: f64,
        values: Vec<f64>,
    },
    /// Like `ScaleValues`, but scaled as throughputs for `throughput`.
    ScaleThroughputs {
        typical_value: f64,
        values: Vec<f64>,
        throughput: Throughput,
    },
    /// Have the formatter apply its machine-oriented scaling to `values`; answered with
    /// `OutgoingMessage::ScaledValues`.
    ScaleForMachines {
        values: Vec<f64>,
    },
    /// Ends the request loop in `Connection::serve_value_formatter`.
    Continue,

    // NOTE(review): nothing in this file constructs or handles this variant, and
    // `serve_value_formatter` panics if it is received. It is presumably a hidden placeholder
    // that keeps matches on this enum non-exhaustive; confirm.
    __Other,
}
242
/// Enum defining the messages we can send
///
/// Each message is serialized to CBOR and written with a 4-byte length prefix by
/// `InnerConnection::send`.
#[derive(Debug, Serialize)]
pub enum OutgoingMessage<'a> {
    // Benchmark progress notifications. NOTE(review): when each one is sent is decided by
    // callers outside this file; confirm there.
    BeginningBenchmarkGroup {
        group: &'a str,
    },
    FinishedBenchmarkGroup {
        group: &'a str,
    },
    BeginningBenchmark {
        id: RawBenchmarkId,
    },
    SkippingBenchmark {
        id: RawBenchmarkId,
    },
    Warmup {
        id: RawBenchmarkId,
        // Presumably the warm-up duration in nanoseconds, going by the field name; confirm.
        nanos: f64,
    },
    MeasurementStart {
        id: RawBenchmarkId,
        sample_count: u64,
        estimate_ns: f64,
        iter_count: u64,
    },
    MeasurementComplete {
        id: RawBenchmarkId,
        // NOTE(review): presumably parallel per-sample iteration counts and measured times;
        // confirm.
        iters: &'a [f64],
        times: &'a [f64],
        plot_config: PlotConfiguration,
        sampling_method: SamplingMethod,
        benchmark_config: BenchmarkConfig,
    },
    // value formatter responses
    /// Reply to `FormatValue` / `FormatThroughput`: the string produced by the formatter.
    FormattedValue {
        value: String,
    },
    /// Reply to the `Scale*` requests: the scaled values and the unit returned by the formatter.
    ScaledValues {
        scaled_values: Vec<f64>,
        unit: &'a str,
    },
}
285
286// Also define serializable variants of certain things, either to avoid leaking
287// serializability into the public interface or because the serialized form
288// is a bit different from the regular one.
289
/// Serializable copy of `crate::report::BenchmarkId` (imported here as `InternalBenchmarkId`),
/// sent as the `id` of the per-benchmark `OutgoingMessage` variants.
#[derive(Debug, Serialize)]
pub struct RawBenchmarkId {
    group_id: String,
    function_id: Option<String>,
    value_str: Option<String>,
    // Collected from the internal id's `throughput`. NOTE(review): presumably holds at most one
    // entry; confirm against `report::BenchmarkId`.
    throughput: Vec<Throughput>,
}
297impl From<&InternalBenchmarkId> for RawBenchmarkId {
298    fn from(other: &InternalBenchmarkId) -> RawBenchmarkId {
299        RawBenchmarkId {
300            group_id: other.group_id.clone(),
301            function_id: other.function_id.clone(),
302            value_str: other.value_str.clone(),
303            throughput: other.throughput.iter().cloned().collect(),
304        }
305    }
306}
307
/// Serializable mirror of `crate::AxisScale`.
#[derive(Debug, Serialize)]
pub enum AxisScale {
    Linear,
    Logarithmic,
}
313impl From<crate::AxisScale> for AxisScale {
314    fn from(other: crate::AxisScale) -> Self {
315        match other {
316            crate::AxisScale::Linear => AxisScale::Linear,
317            crate::AxisScale::Logarithmic => AxisScale::Logarithmic,
318        }
319    }
320}
321
/// Serializable subset of `crate::PlotConfiguration`; only `summary_scale` is carried over.
#[derive(Debug, Serialize)]
pub struct PlotConfiguration {
    summary_scale: AxisScale,
}
326impl From<&crate::PlotConfiguration> for PlotConfiguration {
327    fn from(other: &crate::PlotConfiguration) -> Self {
328        PlotConfiguration {
329            summary_scale: other.summary_scale.into(),
330        }
331    }
332}
333
/// Serializable stand-in for `std::time::Duration`: whole seconds plus the sub-second
/// remainder in nanoseconds.
#[derive(Debug, Serialize)]
struct Duration {
    secs: u64,
    nanos: u32,
}
339impl From<std::time::Duration> for Duration {
340    fn from(other: std::time::Duration) -> Self {
341        Duration {
342            secs: other.as_secs(),
343            nanos: other.subsec_nanos(),
344        }
345    }
346}
347
/// Serializable copy of the `crate::benchmark::BenchmarkConfig` settings that are sent with
/// `OutgoingMessage::MeasurementComplete`.
#[derive(Debug, Serialize)]
pub struct BenchmarkConfig {
    confidence_level: f64,
    measurement_time: Duration,
    noise_threshold: f64,
    nresamples: usize,
    sample_size: usize,
    significance_level: f64,
    warm_up_time: Duration,
}
358impl From<&crate::benchmark::BenchmarkConfig> for BenchmarkConfig {
359    fn from(other: &crate::benchmark::BenchmarkConfig) -> Self {
360        BenchmarkConfig {
361            confidence_level: other.confidence_level,
362            measurement_time: other.measurement_time.into(),
363            noise_threshold: other.noise_threshold,
364            nresamples: other.nresamples,
365            sample_size: other.sample_size,
366            significance_level: other.significance_level,
367            warm_up_time: other.warm_up_time.into(),
368        }
369    }
370}
371
/// Currently not used; defined for forwards compatibility with cargo-criterion.
// NOTE(review): within this file the type is sent as part of
// `OutgoingMessage::MeasurementComplete`. "Not used" presumably refers to the cargo-criterion
// side; confirm and update the doc line above.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum SamplingMethod {
    Linear,
    Flat,
}
378impl From<crate::ActualSamplingMode> for SamplingMethod {
379    fn from(other: crate::ActualSamplingMode) -> Self {
380        match other {
381            crate::ActualSamplingMode::Flat => SamplingMethod::Flat,
382            crate::ActualSamplingMode::Linear => SamplingMethod::Linear,
383        }
384    }
385}