Skip to main content

mz_ccsr/
client.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12use std::hash::Hash;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{anyhow, bail};
17use proptest_derive::Arbitrary;
18use reqwest::{Method, Response, Url};
19use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21
22use crate::config::Auth;
23
24/// An API client for a Confluent-compatible schema registry.
25#[derive(Clone)]
26pub struct Client {
27    inner: reqwest::Client,
28    url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
29    auth: Option<Auth>,
30    timeout: Duration,
31}
32
33impl fmt::Debug for Client {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        f.debug_struct("Client")
36            .field("inner", &self.inner)
37            .field("url", &"...")
38            .field("auth", &self.auth)
39            .finish()
40    }
41}
42
43impl Client {
44    pub(crate) fn new(
45        inner: reqwest::Client,
46        url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
47        auth: Option<Auth>,
48        timeout: Duration,
49    ) -> Result<Self, anyhow::Error> {
50        if url().cannot_be_a_base() {
51            bail!("cannot construct a CCSR client with a cannot-be-a-base URL");
52        }
53        Ok(Client {
54            inner,
55            url,
56            auth,
57            timeout,
58        })
59    }
60
61    fn make_request<P>(&self, method: Method, path: P) -> reqwest::RequestBuilder
62    where
63        P: IntoIterator,
64        P::Item: AsRef<str>,
65    {
66        let mut url = (self.url)();
67        url.path_segments_mut()
68            .expect("constructor validated URL can be a base")
69            .clear()
70            .extend(path);
71
72        let mut request = self.inner.request(method, url);
73        if let Some(auth) = &self.auth {
74            request = request.basic_auth(&auth.username, auth.password.as_ref());
75        }
76        request
77    }
78
79    pub fn timeout(&self) -> Duration {
80        self.timeout
81    }
82
83    /// Gets the schema with the associated ID.
84    pub async fn get_schema_by_id(&self, id: i32) -> Result<Schema, GetByIdError> {
85        let req = self.make_request(Method::GET, &["schemas", "ids", &id.to_string()]);
86        let res: GetByIdResponse = send_request(req).await?;
87        Ok(Schema {
88            id,
89            raw: res.schema,
90        })
91    }
92
93    /// Gets the latest schema for the specified subject.
94    pub async fn get_schema_by_subject(&self, subject: &str) -> Result<Schema, GetBySubjectError> {
95        self.get_subject_latest(subject).await.map(|s| s.schema)
96    }
97
98    /// Gets the latest version of the specified subject.
99    pub async fn get_subject_latest(&self, subject: &str) -> Result<Subject, GetBySubjectError> {
100        let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
101        let res: GetBySubjectResponse = send_request(req).await?;
102        Ok(Subject {
103            schema: Schema {
104                id: res.id,
105                raw: res.schema,
106            },
107            version: res.version,
108            name: res.subject,
109        })
110    }
111
112    /// Gets the latest version of the specified subject along with its direct references.
113    /// Returns the subject and a list of subject names that this subject directly references.
114    pub async fn get_subject_with_references(
115        &self,
116        subject: &str,
117    ) -> Result<(Subject, Vec<SubjectVersion>), GetBySubjectError> {
118        let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
119        let res: GetBySubjectResponse = send_request(req).await?;
120        let referenced_subjects: Vec<_> = res
121            .references
122            .into_iter()
123            .map(|r| SubjectVersion {
124                subject: r.subject,
125                version: r.version,
126            })
127            .collect();
128        Ok((
129            Subject {
130                schema: Schema {
131                    id: res.id,
132                    raw: res.schema,
133                },
134                version: res.version,
135                name: res.subject,
136            },
137            referenced_subjects,
138        ))
139    }
140
141    /// Gets the config set for the specified subject
142    pub async fn get_subject_config(
143        &self,
144        subject: &str,
145    ) -> Result<SubjectConfig, GetSubjectConfigError> {
146        let req = self.make_request(Method::GET, &["config", subject]);
147        let res: SubjectConfig = send_request(req).await?;
148        Ok(res)
149    }
150
151    /// Gets the latest version of the specified subject as well as all other
152    /// subjects referenced by that subject (recursively).
153    ///
154    /// The dependencies are returned in dependency order, with dependencies first.
155    pub async fn get_subject_and_references(
156        &self,
157        subject: &str,
158    ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
159        self.get_subject_and_references_by_version(subject, "latest".to_owned())
160            .await
161    }
162
163    /// Gets a subject and all other subjects referenced by that subject (recursively)
164    ///
165    /// The dependencies are returned in dependency order, with dependencies first.
166    #[allow(clippy::disallowed_types)]
167    async fn get_subject_and_references_by_version(
168        &self,
169        subject: &str,
170        version: String,
171    ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
172        let mut subjects = vec![];
173        // HashMap are used as we strictly need lookup, not ordering.
174        let mut graph = std::collections::HashMap::new();
175        let mut subjects_queue = vec![(subject.to_owned(), version)];
176        while let Some((subject, version)) = subjects_queue.pop() {
177            let req = self.make_request(Method::GET, &["subjects", &subject, "versions", &version]);
178            let res: GetBySubjectResponse = send_request(req).await?;
179            subjects.push(Subject {
180                schema: Schema {
181                    id: res.id,
182                    raw: res.schema,
183                },
184                version: res.version,
185                name: res.subject.clone(),
186            });
187            let subject_key = SubjectVersion {
188                subject: res.subject,
189                version: res.version,
190            };
191
192            let dependents: Vec<_> = res
193                .references
194                .into_iter()
195                .map(|reference| SubjectVersion {
196                    subject: reference.subject,
197                    version: reference.version,
198                })
199                .collect();
200
201            graph
202                .entry(subject_key)
203                .or_insert_with(Vec::new)
204                .extend(dependents.iter().cloned());
205            subjects_queue.extend(
206                dependents
207                    .into_iter()
208                    // Add dependents into the graph before adding to the queue as the same
209                    // named type may be encountered multiple times, but we only want to look it up
210                    // once.
211                    .filter(|dep| match graph.entry(dep.clone()) {
212                        std::collections::hash_map::Entry::Occupied(_) => false,
213                        std::collections::hash_map::Entry::Vacant(vacant) => {
214                            vacant.insert(vec![]);
215                            true
216                        }
217                    })
218                    .map(|dep| (dep.subject, dep.version.to_string())),
219            );
220        }
221        assert!(subjects.len() > 0, "Request should error if no subjects");
222
223        let primary = subjects.remove(0);
224
225        let ordered =
226            topological_sort(&graph).map_err(|_| GetBySubjectError::SchemaReferenceCycle)?;
227
228        subjects.sort_by(|a, b| {
229            let a = SubjectVersion {
230                subject: a.name.clone(),
231                version: a.version,
232            };
233            let b = SubjectVersion {
234                subject: b.name.clone(),
235                version: b.version,
236            };
237            ordered
238                .get(&b)
239                .unwrap_or_else(|| panic!("b {b:?}"))
240                .cmp(ordered.get(&a).unwrap_or_else(|| panic!("a {a:?}")))
241        });
242
243        Ok((primary, subjects))
244    }
245
246    /// Publishes a new schema for the specified subject. The ID of the new
247    /// schema is returned.
248    ///
249    /// Note that if a schema that is identical to an existing schema for the
250    /// same subject is published, the ID of the existing schema will be
251    /// returned.
252    pub async fn publish_schema(
253        &self,
254        subject: &str,
255        schema: &str,
256        schema_type: SchemaType,
257        references: &[SchemaReference],
258    ) -> Result<i32, PublishError> {
259        let req = self.make_request(Method::POST, &["subjects", subject, "versions"]);
260        let req = req.json(&PublishRequest {
261            schema,
262            schema_type,
263            references,
264        });
265        let res: PublishResponse = send_request(req).await?;
266        Ok(res.id)
267    }
268
269    /// Sets the compatibility level for the specified subject.
270    pub async fn set_subject_compatibility_level(
271        &self,
272        subject: &str,
273        compatibility_level: CompatibilityLevel,
274    ) -> Result<(), SetCompatibilityLevelError> {
275        let req = self.make_request(Method::PUT, &["config", subject]);
276        let req = req.json(&CompatibilityLevelRequest {
277            compatibility: compatibility_level,
278        });
279        send_request_raw(req).await?;
280        Ok(())
281    }
282
283    /// Lists the names of all subjects that the schema registry is aware of.
284    pub async fn list_subjects(&self) -> Result<Vec<String>, ListError> {
285        let req = self.make_request(Method::GET, &["subjects"]);
286        Ok(send_request(req).await?)
287    }
288
289    /// Deletes all schema versions associated with the specified subject.
290    ///
291    /// This API is only intended to be used in development environments.
292    /// Deleting schemas only allows new, potentially incompatible schemas to
293    /// be registered under the same subject. It does not allow the schema ID
294    /// to be reused.
295    pub async fn delete_subject(&self, subject: &str) -> Result<(), DeleteError> {
296        let req = self.make_request(Method::DELETE, &["subjects", subject]);
297        send_request_raw(req).await?;
298        Ok(())
299    }
300
301    /// Gets the latest version of the first subject found associated with the scheme with
302    /// the given id, as well as all other subjects referenced by that subject (recursively).
303    ///
304    /// The dependencies are returned in dependency order, with dependencies first.
305    pub async fn get_subject_and_references_by_id(
306        &self,
307        id: i32,
308    ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
309        let req = self.make_request(
310            Method::GET,
311            &["schemas", "ids", &id.to_string(), "versions"],
312        );
313        let res: Vec<SubjectVersion> = send_request(req).await?;
314
315        // NOTE NOTE NOTE
316        // We take the FIRST subject that matches this schema id. This could be DIFFERENT
317        // than the actual subject we are interested in (it could even be from a different test
318        // run), but we are trusting the schema registry to only output the same schema id for
319        // identical subjects.
320        // This was validated by publishing 2 empty schemas (i.e., identical), with different
321        // references (one empty, one with a random reference), and they were not linked to the
322        // same schema id.
323        //
324        // See https://docs.confluent.io/platform/current/schema-registry/develop/api.html#post--subjects-(string-%20subject)-versions
325        // for more info.
326        match res.as_slice() {
327            [first, ..] => {
328                self.get_subject_and_references_by_version(
329                    &first.subject,
330                    first.version.to_string(),
331                )
332                .await
333            }
334            _ => Err(GetBySubjectError::SubjectNotFound),
335        }
336    }
337}
338
339/// Generates a topological ordering for a DAG.  If a cycle is detected in any returns an error.
340/// This can operator on a disconnected graph containing multiple DAGs.
341#[allow(clippy::disallowed_types)]
342pub fn topological_sort<T: Hash + Eq>(
343    graph: &std::collections::HashMap<T, Vec<T>>,
344) -> Result<std::collections::HashMap<&T, i32>, anyhow::Error> {
345    let mut referenced_by: std::collections::HashMap<&T, std::collections::HashSet<&T>> =
346        std::collections::HashMap::new();
347    for (subject, references) in graph.iter() {
348        for reference in references {
349            referenced_by.entry(reference).or_default().insert(subject);
350        }
351    }
352
353    // Start with nodes that have no incoming edges (empty referenced_by sets).
354    // Also include nodes in graph that aren't in referenced_by at all (roots).
355    let mut queue: Vec<_> = graph
356        .keys()
357        .filter(|key| {
358            referenced_by
359                .get(*key)
360                .map_or(true, |subjects| subjects.is_empty())
361        })
362        .collect();
363
364    let mut ordered = std::collections::HashMap::new();
365    let mut n = 0;
366    while let Some(subj_ver) = queue.pop() {
367        if let Some(refs) = graph.get(subj_ver) {
368            for ref_ver in refs {
369                let Some(subjects) = referenced_by.get_mut(ref_ver) else {
370                    continue;
371                };
372                subjects.remove(&subj_ver);
373                if subjects.is_empty() {
374                    referenced_by.remove_entry(ref_ver);
375                    queue.push(ref_ver);
376                }
377            }
378        }
379        ordered.insert(subj_ver, n);
380        n += 1;
381    }
382
383    if referenced_by.is_empty() {
384        Ok(ordered)
385    } else {
386        Err(anyhow!("Cycled detected during topoligical sort"))
387    }
388}
389
390async fn send_request<T>(req: reqwest::RequestBuilder) -> Result<T, UnhandledError>
391where
392    T: DeserializeOwned,
393{
394    let res = send_request_raw(req).await?;
395    Ok(res.json().await?)
396}
397
398async fn send_request_raw(req: reqwest::RequestBuilder) -> Result<Response, UnhandledError> {
399    let res = req.send().await?;
400    let status = res.status();
401    if status.is_success() {
402        Ok(res)
403    } else {
404        match res.json::<ErrorResponse>().await {
405            Ok(err_res) => Err(UnhandledError::Api {
406                code: err_res.error_code,
407                message: err_res.message,
408            }),
409            Err(_) => Err(UnhandledError::Api {
410                code: i32::from(status.as_u16()),
411                message: "unable to decode error details".into(),
412            }),
413        }
414    }
415}
416
417/// The type of a schema stored by a schema registry.
418#[derive(Clone, Copy, Debug, Serialize)]
419#[serde(rename_all = "UPPERCASE")]
420pub enum SchemaType {
421    /// An Avro schema.
422    Avro,
423    /// A Protobuf schema.
424    Protobuf,
425    /// A JSON schema.
426    Json,
427}
428
429impl SchemaType {
430    fn is_default(&self) -> bool {
431        matches!(self, SchemaType::Avro)
432    }
433}
434
/// A schema stored by a schema registry.
#[derive(Debug, Eq, PartialEq)]
pub struct Schema {
    /// The registry-assigned ID of the schema.
    pub id: i32,
    /// The raw text representing the schema.
    pub raw: String,
}

/// A subject stored by a schema registry, at one particular version.
#[derive(Debug, Eq, PartialEq)]
pub struct Subject {
    /// The version of the subject that `schema` is registered under.
    pub version: i32,
    /// The name of the subject.
    pub name: String,
    /// The schema registered for `version` of this subject.
    pub schema: Schema,
}
454
455/// A reference from one schema in a schema registry to another.
456#[derive(Debug, Serialize, Deserialize)]
457#[serde(rename_all = "camelCase")]
458pub struct SchemaReference {
459    /// The name of the reference.
460    pub name: String,
461    /// The subject under which the referenced schema is registered.
462    pub subject: String,
463    /// The version of the referenced schema.
464    pub version: i32,
465}
466
467#[derive(Debug, Deserialize)]
468struct GetByIdResponse {
469    schema: String,
470}
471
472/// Errors for schema lookups by ID.
473#[derive(Debug)]
474pub enum GetByIdError {
475    /// No schema with the requested ID exists.
476    SchemaNotFound,
477    /// The underlying HTTP transport failed.
478    Transport(reqwest::Error),
479    /// An internal server error occurred.
480    Server { code: i32, message: String },
481}
482
483#[derive(Debug, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
484pub struct SubjectVersion {
485    /// The name of the subject
486    pub subject: String,
487    /// The version of the schema
488    pub version: i32,
489}
490
491impl From<UnhandledError> for GetByIdError {
492    fn from(err: UnhandledError) -> GetByIdError {
493        match err {
494            UnhandledError::Transport(err) => GetByIdError::Transport(err),
495            UnhandledError::Api { code, message } => match code {
496                40403 => GetByIdError::SchemaNotFound,
497                _ => GetByIdError::Server { code, message },
498            },
499        }
500    }
501}
502
503impl Error for GetByIdError {
504    fn source(&self) -> Option<&(dyn Error + 'static)> {
505        match self {
506            GetByIdError::SchemaNotFound | GetByIdError::Server { .. } => None,
507            GetByIdError::Transport(err) => Some(err),
508        }
509    }
510}
511
512impl fmt::Display for GetByIdError {
513    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
514        match self {
515            GetByIdError::SchemaNotFound => write!(f, "schema not found"),
516            GetByIdError::Transport(err) => write!(f, "transport: {}", err),
517            GetByIdError::Server { code, message } => {
518                write!(f, "server error {}: {}", code, message)
519            }
520        }
521    }
522}
523
524#[derive(Debug, Deserialize)]
525#[serde(rename_all = "camelCase")]
526pub struct SubjectConfig {
527    pub compatibility_level: CompatibilityLevel,
528    // There are other fields to include if we need them.
529}
530
531/// Errors for schema lookups by subject.
532#[derive(Debug)]
533pub enum GetSubjectConfigError {
534    /// The requested subject does not exist.
535    SubjectNotFound,
536    /// The compatibility level for the subject has not been set.
537    SubjectCompatibilityLevelNotSet,
538    /// The underlying HTTP transport failed.
539    Transport(reqwest::Error),
540    /// An internal server error occurred.
541    Server { code: i32, message: String },
542}
543
544impl From<UnhandledError> for GetSubjectConfigError {
545    fn from(err: UnhandledError) -> GetSubjectConfigError {
546        match err {
547            UnhandledError::Transport(err) => GetSubjectConfigError::Transport(err),
548            UnhandledError::Api { code, message } => match code {
549                404 => GetSubjectConfigError::SubjectNotFound,
550                40408 => GetSubjectConfigError::SubjectCompatibilityLevelNotSet,
551                _ => GetSubjectConfigError::Server { code, message },
552            },
553        }
554    }
555}
556
557impl Error for GetSubjectConfigError {
558    fn source(&self) -> Option<&(dyn Error + 'static)> {
559        match self {
560            GetSubjectConfigError::SubjectNotFound
561            | GetSubjectConfigError::SubjectCompatibilityLevelNotSet
562            | GetSubjectConfigError::Server { .. } => None,
563            GetSubjectConfigError::Transport(err) => Some(err),
564        }
565    }
566}
567
568impl fmt::Display for GetSubjectConfigError {
569    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
570        match self {
571            GetSubjectConfigError::SubjectNotFound => write!(f, "subject not found"),
572            GetSubjectConfigError::SubjectCompatibilityLevelNotSet => {
573                write!(f, "subject level compatibility not set")
574            }
575            GetSubjectConfigError::Transport(err) => write!(f, "transport: {}", err),
576            GetSubjectConfigError::Server { code, message } => {
577                write!(f, "server error {}: {}", code, message)
578            }
579        }
580    }
581}
582
583#[derive(Debug, Deserialize)]
584#[serde(rename_all = "camelCase")]
585struct GetBySubjectResponse {
586    id: i32,
587    schema: String,
588    version: i32,
589    subject: String,
590    #[serde(default)]
591    references: Vec<SchemaReference>,
592}
593
594/// Errors for schema lookups by subject.
595#[derive(Debug)]
596pub enum GetBySubjectError {
597    /// The requested subject does not exist.
598    SubjectNotFound,
599    /// The requested version does not exist.
600    VersionNotFound(String),
601    /// The underlying HTTP transport failed.
602    Transport(reqwest::Error),
603    /// An internal server error occurred.
604    Server { code: i32, message: String },
605    /// Cycle detected in schemas
606    SchemaReferenceCycle,
607}
608
609impl From<UnhandledError> for GetBySubjectError {
610    fn from(err: UnhandledError) -> GetBySubjectError {
611        match err {
612            UnhandledError::Transport(err) => GetBySubjectError::Transport(err),
613            UnhandledError::Api { code, message } => match code {
614                40401 => GetBySubjectError::SubjectNotFound,
615                40402 => GetBySubjectError::VersionNotFound(message),
616                _ => GetBySubjectError::Server { code, message },
617            },
618        }
619    }
620}
621
622impl Error for GetBySubjectError {
623    fn source(&self) -> Option<&(dyn Error + 'static)> {
624        match self {
625            GetBySubjectError::SubjectNotFound
626            | GetBySubjectError::VersionNotFound(_)
627            | GetBySubjectError::Server { .. }
628            | GetBySubjectError::SchemaReferenceCycle => None,
629            GetBySubjectError::Transport(err) => Some(err),
630        }
631    }
632}
633
634impl fmt::Display for GetBySubjectError {
635    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
636        match self {
637            GetBySubjectError::SubjectNotFound => write!(f, "subject not found"),
638            GetBySubjectError::VersionNotFound(message) => {
639                write!(f, "version not found: {}", message)
640            }
641            GetBySubjectError::Transport(err) => write!(f, "transport: {}", err),
642            GetBySubjectError::Server { code, message } => {
643                write!(f, "server error {}: {}", code, message)
644            }
645            GetBySubjectError::SchemaReferenceCycle => {
646                write!(f, "cycle detected in schema references")
647            }
648        }
649    }
650}
651
652#[derive(Debug, Serialize)]
653#[serde(rename_all = "camelCase")]
654struct PublishRequest<'a> {
655    schema: &'a str,
656    // Omitting the following fields when they're set to their defaults provides
657    // compatibility with old versions of the schema registry that don't
658    // understand these fields.
659    #[serde(skip_serializing_if = "SchemaType::is_default")]
660    schema_type: SchemaType,
661    #[serde(skip_serializing_if = "<[_]>::is_empty")]
662    references: &'a [SchemaReference],
663}
664
665#[derive(Debug, Deserialize)]
666#[serde(rename_all = "camelCase")]
667struct PublishResponse {
668    id: i32,
669}
670
671#[derive(Debug, Serialize)]
672#[serde(rename_all = "camelCase")]
673struct CompatibilityLevelRequest {
674    compatibility: CompatibilityLevel,
675}
676
677#[derive(Arbitrary, Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
678#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
679pub enum CompatibilityLevel {
680    Backward,
681    BackwardTransitive,
682    Forward,
683    ForwardTransitive,
684    Full,
685    FullTransitive,
686    None,
687}
688
689impl TryFrom<&str> for CompatibilityLevel {
690    type Error = String;
691
692    fn try_from(value: &str) -> Result<Self, Self::Error> {
693        match value {
694            "BACKWARD" => Ok(CompatibilityLevel::Backward),
695            "BACKWARD_TRANSITIVE" => Ok(CompatibilityLevel::BackwardTransitive),
696            "FORWARD" => Ok(CompatibilityLevel::Forward),
697            "FORWARD_TRANSITIVE" => Ok(CompatibilityLevel::ForwardTransitive),
698            "FULL" => Ok(CompatibilityLevel::Full),
699            "FULL_TRANSITIVE" => Ok(CompatibilityLevel::FullTransitive),
700            "NONE" => Ok(CompatibilityLevel::None),
701            _ => Err(format!("invalid compatibility level: {}", value)),
702        }
703    }
704}
705
706impl fmt::Display for CompatibilityLevel {
707    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
708        match self {
709            CompatibilityLevel::Backward => write!(f, "BACKWARD"),
710            CompatibilityLevel::BackwardTransitive => write!(f, "BACKWARD_TRANSITIVE"),
711            CompatibilityLevel::Forward => write!(f, "FORWARD"),
712            CompatibilityLevel::ForwardTransitive => write!(f, "FORWARD_TRANSITIVE"),
713            CompatibilityLevel::Full => write!(f, "FULL"),
714            CompatibilityLevel::FullTransitive => write!(f, "FULL_TRANSITIVE"),
715            CompatibilityLevel::None => write!(f, "NONE"),
716        }
717    }
718}
719
720/// Errors for publish operations.
721#[derive(Debug)]
722pub enum PublishError {
723    /// The provided schema was not compatible with existing schemas for that
724    /// subject, according to the subject's forwards- or backwards-compatibility
725    /// requirements.
726    IncompatibleSchema,
727    /// The provided schema was invalid.
728    InvalidSchema { message: String },
729    /// The underlying HTTP transport failed.
730    Transport(reqwest::Error),
731    /// An internal server error occurred.
732    Server { code: i32, message: String },
733}
734
735impl From<UnhandledError> for PublishError {
736    fn from(err: UnhandledError) -> PublishError {
737        match err {
738            UnhandledError::Transport(err) => PublishError::Transport(err),
739            UnhandledError::Api { code, message } => match code {
740                409 => PublishError::IncompatibleSchema,
741                42201 => PublishError::InvalidSchema { message },
742                _ => PublishError::Server { code, message },
743            },
744        }
745    }
746}
747
748impl Error for PublishError {
749    fn source(&self) -> Option<&(dyn Error + 'static)> {
750        match self {
751            PublishError::IncompatibleSchema
752            | PublishError::InvalidSchema { .. }
753            | PublishError::Server { .. } => None,
754            PublishError::Transport(err) => Some(err),
755        }
756    }
757}
758
759impl fmt::Display for PublishError {
760    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
761        match self {
762            // The error descriptions for IncompatibleSchema and InvalidSchema
763            // are copied from the schema registry itself.
764            PublishError::IncompatibleSchema => write!(
765                f,
766                "schema being registered is incompatible with an earlier schema"
767            ),
768            PublishError::InvalidSchema { message } => write!(f, "{}", message),
769            PublishError::Transport(err) => write!(f, "transport: {}", err),
770            PublishError::Server { code, message } => {
771                write!(f, "server error {}: {}", code, message)
772            }
773        }
774    }
775}
776
777/// Errors for list operations.
778#[derive(Debug)]
779pub enum ListError {
780    /// The underlying HTTP transport failed.
781    Transport(reqwest::Error),
782    /// An internal server error occurred.
783    Server { code: i32, message: String },
784}
785
786impl From<UnhandledError> for ListError {
787    fn from(err: UnhandledError) -> ListError {
788        match err {
789            UnhandledError::Transport(err) => ListError::Transport(err),
790            UnhandledError::Api { code, message } => ListError::Server { code, message },
791        }
792    }
793}
794
795impl Error for ListError {
796    fn source(&self) -> Option<&(dyn Error + 'static)> {
797        match self {
798            ListError::Server { .. } => None,
799            ListError::Transport(err) => Some(err),
800        }
801    }
802}
803
804impl fmt::Display for ListError {
805    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
806        match self {
807            ListError::Transport(err) => write!(f, "transport: {}", err),
808            ListError::Server { code, message } => write!(f, "server error {}: {}", code, message),
809        }
810    }
811}
812
813/// Errors for delete operations.
814#[derive(Debug)]
815pub enum DeleteError {
816    /// The specified subject does not exist.
817    SubjectNotFound,
818    /// The underlying HTTP transport failed.
819    Transport(reqwest::Error),
820    /// An internal server error occurred.
821    Server { code: i32, message: String },
822}
823
824impl From<UnhandledError> for DeleteError {
825    fn from(err: UnhandledError) -> DeleteError {
826        match err {
827            UnhandledError::Transport(err) => DeleteError::Transport(err),
828            UnhandledError::Api { code, message } => match code {
829                40401 => DeleteError::SubjectNotFound,
830                _ => DeleteError::Server { code, message },
831            },
832        }
833    }
834}
835
836impl Error for DeleteError {
837    fn source(&self) -> Option<&(dyn Error + 'static)> {
838        match self {
839            DeleteError::SubjectNotFound | DeleteError::Server { .. } => None,
840            DeleteError::Transport(err) => Some(err),
841        }
842    }
843}
844
845impl fmt::Display for DeleteError {
846    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
847        match self {
848            DeleteError::SubjectNotFound => write!(f, "subject not found"),
849            DeleteError::Transport(err) => write!(f, "transport: {}", err),
850            DeleteError::Server { code, message } => {
851                write!(f, "server error {}: {}", code, message)
852            }
853        }
854    }
855}
856
857/// Errors for setting compatibility level operations.
858#[derive(Debug)]
859pub enum SetCompatibilityLevelError {
860    /// The compatibility level is invalid.
861    InvalidCompatibilityLevel,
862    /// The underlying HTTP transport failed.
863    Transport(reqwest::Error),
864    /// An internal server error occurred.
865    Server { code: i32, message: String },
866}
867
868impl From<UnhandledError> for SetCompatibilityLevelError {
869    fn from(err: UnhandledError) -> SetCompatibilityLevelError {
870        match err {
871            UnhandledError::Transport(err) => SetCompatibilityLevelError::Transport(err),
872            UnhandledError::Api { code, message } => match code {
873                42203 => SetCompatibilityLevelError::InvalidCompatibilityLevel,
874                _ => SetCompatibilityLevelError::Server { code, message },
875            },
876        }
877    }
878}
879
880impl Error for SetCompatibilityLevelError {
881    fn source(&self) -> Option<&(dyn Error + 'static)> {
882        match self {
883            SetCompatibilityLevelError::InvalidCompatibilityLevel
884            | SetCompatibilityLevelError::Server { .. } => None,
885            SetCompatibilityLevelError::Transport(err) => Some(err),
886        }
887    }
888}
889
890impl fmt::Display for SetCompatibilityLevelError {
891    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
892        match self {
893            SetCompatibilityLevelError::InvalidCompatibilityLevel => {
894                write!(f, "invalid compatibility level")
895            }
896            SetCompatibilityLevelError::Transport(err) => write!(f, "transport: {}", err),
897            SetCompatibilityLevelError::Server { code, message } => {
898                write!(f, "server error {}: {}", code, message)
899            }
900        }
901    }
902}
903
904#[derive(Debug, Deserialize)]
905struct ErrorResponse {
906    error_code: i32,
907    message: String,
908}
909
910#[derive(Debug)]
911enum UnhandledError {
912    Transport(reqwest::Error),
913    Api { code: i32, message: String },
914}
915
916impl From<reqwest::Error> for UnhandledError {
917    fn from(err: reqwest::Error) -> UnhandledError {
918        UnhandledError::Transport(err)
919    }
920}
921
#[cfg(test)]
mod tests {
    // The graphs below are built with `std::collections::HashMap` so they can be
    // passed to `topological_sort`; the workspace clippy config presumably
    // disallows that type by default.
    #![allow(clippy::disallowed_types)]
    use super::*;
    use std::collections::HashMap;

    /// Helper to create a SubjectVersion
    fn sv(subject: &str, version: i32) -> SubjectVersion {
        SubjectVersion {
            subject: subject.to_string(),
            version,
        }
    }

    /// Helper to build a dependency graph from a list of entries.
    ///
    /// An entry `(from, Some(to))` records that `from` depends on `to`. An entry
    /// `(node, None)` adds `node` as a key with no dependencies of its own. This is
    /// how leaf nodes, or a lone node in a graph with no edges, get registered.
    fn build_graph(
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) -> HashMap<SubjectVersion, Vec<SubjectVersion>> {
        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        // Add edges: from depends on to
        for (from, to) in edges {
            let deps = graph.entry(from.clone()).or_default();
            if let Some(to) = to {
                deps.push(to.clone());
            }
        }

        graph
    }

    /// Verify that all edges are respected in the ordering.
    /// For edge (from, to) where 'from' depends on 'to':
    /// - 'from' should be processed before 'to' (lower order number)
    /// - This is because the algorithm processes from roots (nothing depends on them)
    ///   towards leaves (they don't depend on anything)
    fn verify_edge_ordering(
        ordered: &HashMap<&SubjectVersion, i32>,
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) {
        for (from, to) in edges {
            if let Some(to) = to {
                let from_order = ordered.get(from).expect("from node should be in ordering");
                let to_order = ordered.get(to).expect("to node should be in ordering");

                // 'from' depends on 'to', so 'from' is processed first (lower order number)
                // The algorithm starts at roots (nodes nothing depends on) and works toward leaves
                assert!(
                    from_order < to_order,
                    "{:?} (order {}) depends on {:?} (order {}), so should be processed first",
                    from,
                    from_order,
                    to,
                    to_order
                );
            }
        }
    }

    #[mz_ore::test]
    fn test_topological_sort_empty_graph() {
        let graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        let ordered = topological_sort(&graph).unwrap();

        assert!(ordered.is_empty());
    }

    #[mz_ore::test]
    fn test_topological_sort_single_node() {
        // Single node with no dependencies
        let a = sv("a", 1);
        let graph = build_graph(&[(a.clone(), None)]);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 1);
        assert!(ordered.contains_key(&a));
    }

    #[mz_ore::test]
    fn test_topological_sort_linear_chain() {
        // A -> B -> C -> D
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (b.clone(), Some(c.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 4);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_diamond() {
        // Classic diamond pattern:
        //     A
        //    / \
        //   B   C
        //    \ /
        //     D
        // A depends on B and C, both B and C depend on D
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 4);
        verify_edge_ordering(&ordered, &edges);

        // A must come before B and C, and B and C must come before D
        assert!(ordered[&a] < ordered[&b]);
        assert!(ordered[&a] < ordered[&c]);
        assert!(ordered[&b] < ordered[&d]);
        assert!(ordered[&c] < ordered[&d]);
        assert!(ordered[&a] < ordered[&d]);
    }

    #[mz_ore::test]
    fn test_topological_sort_wide_graph() {
        // Wide graph: A depends on B, C, D, E, F (many dependencies)
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (a.clone(), Some(e.clone())),
            (a.clone(), Some(f.clone())),
            (b.clone(), None),
            (c.clone(), None),
            (d.clone(), None),
            (e.clone(), None),
            (f.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 6);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_complex_dag() {
        // Complex DAG:
        //       A
        //      /|\
        //     B C D
        //     |/| |
        //     E F G
        //      \|/
        //       H
        // A -> B, C, D
        // B -> E
        // C -> E, F
        // D -> G
        // E -> H
        // F -> H
        // G -> H
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);
        let h = sv("h", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 8);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_with_versions() {
        // Same subject, different versions
        //        A2
        //       / |
        //      A1 |
        //       \ |
        //        B1
        let a_v1 = sv("a", 1);
        let a_v2 = sv("a", 2);
        let b_v1 = sv("b", 1);

        // a_v2 depends on a_v1, and both depend on b_v1
        let edges = vec![
            (a_v2.clone(), Some(a_v1.clone())),
            (a_v2.clone(), Some(b_v1.clone())),
            (a_v1.clone(), Some(b_v1.clone())),
            (b_v1.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 3);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_multi_level_diamond() {
        // Double diamond:
        //       A
        //      / \
        //     B   C
        //      \ /
        //       D
        //      / \
        //     E   F
        //      \ /
        //       G
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (d.clone(), Some(f.clone())),
            (e.clone(), Some(g.clone())),
            (f.clone(), Some(g.clone())),
            (g.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 7);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_shared_dependency_at_multiple_levels() {
        // Shared dependency accessed at multiple levels:
        //     A
        //    /|\
        //   B C |
        //   |/  |
        //   D   |
        //    \ /
        //     E
        //     |
        //     F
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(e.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (e.clone(), Some(f.clone())),
            (f.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("no cycle");

        assert_eq!(ordered.len(), 6);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_lattice_structure() {
        // Lattice structure (more complex than diamond):
        //       A
        //      /|\
        //     B C D
        //     |\|/|
        //     | X |
        //     |/|\|
        //     E F G
        //      \|/
        //       H
        // A -> B, C, D
        // B -> E, F
        // C -> E, F, G
        // D -> F, G
        // E -> H
        // F -> H
        // G -> H
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);
        let h = sv("h", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (b.clone(), Some(f.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 8);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_binary_tree() {
        // Full binary tree structure (inverted, dependencies flow down):
        //        A
        //       / \
        //      B   C
        //     /|   |\
        //    D E   F G
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), None),
            (e.clone(), None),
            (f.clone(), None),
            (g.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 7);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_simple_cycle() {
        // Simple cycle: A -> B -> A
        let a = sv("a", 1);
        let b = sv("b", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_cycle_with_entry_point() {
        // Cycle with entry point: A -> B -> C -> B (B and C form a cycle)
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![b.clone()]); // C points back to B

        let sort_result = topological_sort(&graph);
        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_self_reference() {
        // Self-referencing node: A -> A
        let a = sv("a", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);
        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_three_node_cycle() {
        // Three node cycle: A -> B -> C -> A
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);
        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }
}