1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::cmp;
use std::time::Duration;

use anyhow::{bail, Context};
use aws_sdk_kinesis::model::{ScalingType, StreamStatus};

use mz_ore::retry::Retry;

use crate::action::{ControlFlow, State};
use crate::parser::BuiltinCommand;

pub async fn run_update_shards(
    mut cmd: BuiltinCommand,
    state: &mut State,
) -> Result<ControlFlow, anyhow::Error> {
    let stream_name = format!("testdrive-{}", cmd.args.string("stream")?);
    let target_shard_count = cmd.args.parse("shards")?;
    cmd.args.done()?;

    let stream_name = format!("{}-{}", stream_name, state.seed);
    println!(
        "Updating Kinesis stream {} to have {} shards",
        stream_name, target_shard_count
    );

    state
        .kinesis_client
        .update_shard_count()
        .scaling_type(ScalingType::UniformScaling)
        .stream_name(&stream_name)
        .target_shard_count(target_shard_count)
        .send()
        .await
        .with_context(|| format!("adding shards to stream {}", &stream_name))?;

    // Verify the current shard count.
    Retry::default()
        .max_duration(cmp::max(state.default_timeout, Duration::from_secs(60)))
        .retry_async_canceling(|_| async {
            // Wait for shards to stop updating.
            let description = state
                .kinesis_client
                .describe_stream()
                .stream_name(&stream_name)
                .send()
                .await
                .context("getting current shard count")?
                .stream_description
                .unwrap();
            if description.stream_status != Some(StreamStatus::Active) {
                bail!(
                    "stream {} is not active, is {:?}",
                    stream_name,
                    description.stream_status
                );
            }

            let active_shards_len = i32::try_from(
                description
                    .shards
                    .unwrap()
                    .iter()
                    .filter(|shard| {
                        shard
                            .sequence_number_range
                            .as_ref()
                            .unwrap()
                            .ending_sequence_number
                            .is_none()
                    })
                    .count(),
            )
            .context("converting shard length to i32: {}")?;
            if active_shards_len != target_shard_count {
                bail!(
                    "expected {} shards, found {}",
                    target_shard_count,
                    active_shards_len
                );
            }
            Ok(())
        })
        .await?;
    Ok(ControlFlow::Continue)
}