Skip to main content

mz_ccsr/
client.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12use std::hash::Hash;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{anyhow, bail};
17use proptest_derive::Arbitrary;
18use reqwest::{Method, Response, Url};
19use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21
22use crate::config::Auth;
23
24/// An API client for a Confluent-compatible schema registry.
25#[derive(Clone)]
26pub struct Client {
27    inner: reqwest::Client,
28    url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
29    auth: Option<Auth>,
30    timeout: Duration,
31}
32
33impl fmt::Debug for Client {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        f.debug_struct("Client")
36            .field("inner", &self.inner)
37            .field("url", &"...")
38            .field("auth", &self.auth)
39            .finish()
40    }
41}
42
43impl Client {
44    pub(crate) fn new(
45        inner: reqwest::Client,
46        url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
47        auth: Option<Auth>,
48        timeout: Duration,
49    ) -> Result<Self, anyhow::Error> {
50        if url().cannot_be_a_base() {
51            bail!("cannot construct a CCSR client with a cannot-be-a-base URL");
52        }
53        Ok(Client {
54            inner,
55            url,
56            auth,
57            timeout,
58        })
59    }
60
61    fn make_request<P>(&self, method: Method, path: P) -> reqwest::RequestBuilder
62    where
63        P: IntoIterator,
64        P::Item: AsRef<str>,
65    {
66        let mut url = (self.url)();
67        url.path_segments_mut()
68            .expect("constructor validated URL can be a base")
69            .clear()
70            .extend(path);
71
72        let mut request = self.inner.request(method, url);
73        if let Some(auth) = &self.auth {
74            request = request.basic_auth(&auth.username, auth.password.as_ref());
75        }
76        request
77    }
78
79    pub fn timeout(&self) -> Duration {
80        self.timeout
81    }
82
83    /// Gets the schema with the associated ID.
84    pub async fn get_schema_by_id(&self, id: i32) -> Result<Schema, GetByIdError> {
85        let req = self.make_request(Method::GET, &["schemas", "ids", &id.to_string()]);
86        let res: GetByIdResponse = send_request(req).await?;
87        Ok(Schema {
88            id,
89            raw: res.schema,
90        })
91    }
92
93    /// Gets the latest schema for the specified subject.
94    pub async fn get_schema_by_subject(&self, subject: &str) -> Result<Schema, GetBySubjectError> {
95        self.get_subject_latest(subject).await.map(|s| s.schema)
96    }
97
98    /// Gets the latest version of the specified subject.
99    pub async fn get_subject_latest(&self, subject: &str) -> Result<Subject, GetBySubjectError> {
100        let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
101        let res: GetBySubjectResponse = send_request(req).await?;
102        Ok(Subject {
103            schema: Schema {
104                id: res.id,
105                raw: res.schema,
106            },
107            version: res.version,
108            name: res.subject,
109        })
110    }
111
112    /// Gets the latest version of the specified subject along with its direct references.
113    /// Returns the subject and a list of subject names that this subject directly references.
114    pub async fn get_subject_with_references(
115        &self,
116        subject: &str,
117    ) -> Result<(Subject, Vec<SubjectVersion>), GetBySubjectError> {
118        let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
119        let res: GetBySubjectResponse = send_request(req).await?;
120        let referenced_subjects: Vec<_> = res
121            .references
122            .into_iter()
123            .map(|r| SubjectVersion {
124                subject: r.subject,
125                version: r.version,
126            })
127            .collect();
128        Ok((
129            Subject {
130                schema: Schema {
131                    id: res.id,
132                    raw: res.schema,
133                },
134                version: res.version,
135                name: res.subject,
136            },
137            referenced_subjects,
138        ))
139    }
140
141    /// Gets the config set for the specified subject
142    pub async fn get_subject_config(
143        &self,
144        subject: &str,
145    ) -> Result<SubjectConfig, GetSubjectConfigError> {
146        let req = self.make_request(Method::GET, &["config", subject]);
147        let res: SubjectConfig = send_request(req).await?;
148        Ok(res)
149    }
150
151    /// Gets the latest version of the specified subject as well as all other
152    /// subjects referenced by that subject (recursively).
153    ///
154    /// The dependencies are returned in dependency order, with dependencies first.
155    pub async fn get_subject_and_references(
156        &self,
157        subject: &str,
158    ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
159        self.get_subject_and_references_by_version(subject, "latest".to_owned())
160            .await
161    }
162
163    /// Gets a subject and all other subjects referenced by that subject (recursively)
164    ///
165    /// The dependencies are returned in dependency order, with dependencies first.
166    #[allow(clippy::disallowed_types)]
167    async fn get_subject_and_references_by_version(
168        &self,
169        subject: &str,
170        version: String,
171    ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
172        let mut subjects = vec![];
173        // HashMap are used as we strictly need lookup, not ordering.
174        let mut graph = std::collections::HashMap::new();
175        let mut subjects_queue = vec![(subject.to_owned(), version)];
176        while let Some((subject, version)) = subjects_queue.pop() {
177            let req = self.make_request(Method::GET, &["subjects", &subject, "versions", &version]);
178            let res: GetBySubjectResponse = send_request(req).await?;
179            subjects.push(Subject {
180                schema: Schema {
181                    id: res.id,
182                    raw: res.schema,
183                },
184                version: res.version,
185                name: res.subject.clone(),
186            });
187            let subject_key = SubjectVersion {
188                subject: res.subject,
189                version: res.version,
190            };
191
192            let dependents: Vec<_> = res
193                .references
194                .into_iter()
195                .map(|reference| SubjectVersion {
196                    subject: reference.subject,
197                    version: reference.version,
198                })
199                .collect();
200
201            graph
202                .entry(subject_key)
203                .or_insert_with(Vec::new)
204                .extend(dependents.iter().cloned());
205            subjects_queue.extend(
206                dependents
207                    .into_iter()
208                    // Add dependents into the graph before adding to the queue as the same
209                    // named type may be encountered multiple times, but we only want to look it up
210                    // once.
211                    .filter(|dep| match graph.entry(dep.clone()) {
212                        std::collections::hash_map::Entry::Occupied(_) => false,
213                        std::collections::hash_map::Entry::Vacant(vacant) => {
214                            vacant.insert(vec![]);
215                            true
216                        }
217                    })
218                    .map(|dep| (dep.subject, dep.version.to_string())),
219            );
220        }
221        assert!(subjects.len() > 0, "Request should error if no subjects");
222
223        let primary = subjects.remove(0);
224
225        let ordered =
226            topological_sort(&graph).map_err(|_| GetBySubjectError::SchemaReferenceCycle)?;
227
228        subjects.sort_by(|a, b| {
229            let a = SubjectVersion {
230                subject: a.name.clone(),
231                version: a.version,
232            };
233            let b = SubjectVersion {
234                subject: b.name.clone(),
235                version: b.version,
236            };
237            ordered
238                .get(&b)
239                .unwrap_or_else(|| panic!("b {b:?}"))
240                .cmp(ordered.get(&a).unwrap_or_else(|| panic!("a {a:?}")))
241        });
242
243        Ok((primary, subjects))
244    }
245
246    /// Publishes a new schema for the specified subject. The ID of the new
247    /// schema is returned.
248    ///
249    /// Note that if a schema that is identical to an existing schema for the
250    /// same subject is published, the ID of the existing schema will be
251    /// returned.
252    pub async fn publish_schema(
253        &self,
254        subject: &str,
255        schema: &str,
256        schema_type: SchemaType,
257        references: &[SchemaReference],
258    ) -> Result<i32, PublishError> {
259        let req = self.make_request(Method::POST, &["subjects", subject, "versions"]);
260        let req = req.json(&PublishRequest {
261            schema,
262            schema_type,
263            references,
264        });
265        let res: PublishResponse = send_request(req).await?;
266        Ok(res.id)
267    }
268
269    /// Sets the compatibility level for the specified subject.
270    pub async fn set_subject_compatibility_level(
271        &self,
272        subject: &str,
273        compatibility_level: CompatibilityLevel,
274    ) -> Result<(), SetCompatibilityLevelError> {
275        let req = self.make_request(Method::PUT, &["config", subject]);
276        let req = req.json(&CompatibilityLevelRequest {
277            compatibility: compatibility_level,
278        });
279        send_request_raw(req).await?;
280        Ok(())
281    }
282
283    /// Lists the names of all subjects that the schema registry is aware of.
284    pub async fn list_subjects(&self) -> Result<Vec<String>, ListError> {
285        let req = self.make_request(Method::GET, &["subjects"]);
286        Ok(send_request(req).await?)
287    }
288
289    /// Deletes all schema versions associated with the specified subject.
290    ///
291    /// This API is only intended to be used in development environments.
292    /// Deleting schemas only allows new, potentially incompatible schemas to
293    /// be registered under the same subject. It does not allow the schema ID
294    /// to be reused.
295    pub async fn delete_subject(&self, subject: &str) -> Result<(), DeleteError> {
296        let req = self.make_request(Method::DELETE, &["subjects", subject]);
297        send_request_raw(req).await?;
298        Ok(())
299    }
300
301    /// Gets the latest version of the first subject found associated with the scheme with
302    /// the given id, as well as all other subjects referenced by that subject (recursively).
303    ///
304    /// The dependencies are returned in dependency order, with dependencies first.
305    pub async fn get_subject_and_references_by_id(
306        &self,
307        id: i32,
308    ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
309        let req = self.make_request(
310            Method::GET,
311            &["schemas", "ids", &id.to_string(), "versions"],
312        );
313        let res: Vec<SubjectVersion> = send_request(req).await?;
314
315        // NOTE NOTE NOTE
316        // We take the FIRST subject that matches this schema id. This could be DIFFERENT
317        // than the actual subject we are interested in (it could even be from a different test
318        // run), but we are trusting the schema registry to only output the same schema id for
319        // identical subjects.
320        // This was validated by publishing 2 empty schemas (i.e., identical), with different
321        // references (one empty, one with a random reference), and they were not linked to the
322        // same schema id.
323        //
324        // See https://docs.confluent.io/platform/current/schema-registry/develop/api.html#post--subjects-(string-%20subject)-versions
325        // for more info.
326        match res.as_slice() {
327            [first, ..] => {
328                self.get_subject_and_references_by_version(
329                    &first.subject,
330                    first.version.to_string(),
331                )
332                .await
333            }
334            _ => Err(GetBySubjectError::SubjectNotFound),
335        }
336    }
337}
338
339/// Generates a topological ordering for a DAG.  If a cycle is detected in any returns an error.
340/// This can operator on a disconnected graph containing multiple DAGs.
341#[allow(clippy::disallowed_types)]
342pub fn topological_sort<T: Hash + Eq>(
343    graph: &std::collections::HashMap<T, Vec<T>>,
344) -> Result<std::collections::HashMap<&T, i32>, anyhow::Error> {
345    let mut referenced_by: std::collections::HashMap<&T, std::collections::HashSet<&T>> =
346        std::collections::HashMap::new();
347    for (subject, references) in graph.iter() {
348        for reference in references {
349            referenced_by.entry(reference).or_default().insert(subject);
350        }
351    }
352
353    // Start with nodes that have no incoming edges (empty referenced_by sets).
354    // Also include nodes in graph that aren't in referenced_by at all (roots).
355    let mut queue: Vec<_> = graph
356        .keys()
357        .filter(|key| {
358            referenced_by
359                .get(*key)
360                .map_or(true, |subjects| subjects.is_empty())
361        })
362        .collect();
363
364    let mut ordered = std::collections::HashMap::new();
365    let mut n = 0;
366    while let Some(subj_ver) = queue.pop() {
367        if let Some(refs) = graph.get(subj_ver) {
368            for ref_ver in refs {
369                let Some(subjects) = referenced_by.get_mut(ref_ver) else {
370                    continue;
371                };
372                subjects.remove(&subj_ver);
373                if subjects.is_empty() {
374                    referenced_by.remove_entry(ref_ver);
375                    queue.push(ref_ver);
376                }
377            }
378        }
379        ordered.insert(subj_ver, n);
380        n += 1;
381    }
382
383    if referenced_by.is_empty() {
384        Ok(ordered)
385    } else {
386        Err(anyhow!("Cycled detected during topoligical sort"))
387    }
388}
389
390async fn send_request<T>(req: reqwest::RequestBuilder) -> Result<T, UnhandledError>
391where
392    T: DeserializeOwned,
393{
394    let res = send_request_raw(req).await?;
395    Ok(res.json().await?)
396}
397
398async fn send_request_raw(req: reqwest::RequestBuilder) -> Result<Response, UnhandledError> {
399    let res = req.send().await?;
400    let status = res.status();
401    if status.is_success() {
402        Ok(res)
403    } else {
404        match res.json::<ErrorResponse>().await {
405            Ok(err_res) => Err(UnhandledError::Api {
406                code: err_res.error_code,
407                message: err_res.message,
408            }),
409            Err(_) => Err(UnhandledError::Api {
410                code: i32::from(status.as_u16()),
411                message: "unable to decode error details".into(),
412            }),
413        }
414    }
415}
416
/// The type of a schema stored by a schema registry.
///
/// Serialized in upper case: `"AVRO"`, `"PROTOBUF"`, or `"JSON"`.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum SchemaType {
    /// An Avro schema.
    Avro,
    /// A Protobuf schema.
    Protobuf,
    /// A JSON schema.
    Json,
}
428
429impl SchemaType {
430    fn is_default(&self) -> bool {
431        matches!(self, SchemaType::Avro)
432    }
433}
434
/// A schema stored by a schema registry.
#[derive(Debug, Eq, PartialEq)]
pub struct Schema {
    /// The ID the registry assigned to the schema.
    pub id: i32,
    /// The raw text representing the schema.
    pub raw: String,
}
443
/// A subject stored by a schema registry, at one particular version.
#[derive(Debug, Eq, PartialEq)]
pub struct Subject {
    /// The version of this subject that `schema` belongs to.
    pub version: i32,
    /// The name of the subject.
    pub name: String,
    /// The schema registered under `version` of this subject.
    pub schema: Schema,
}
454
/// A reference from one schema in a schema registry to another.
///
/// Sent with published schemas and received in subject-version responses.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SchemaReference {
    /// The name of the reference.
    pub name: String,
    /// The subject under which the referenced schema is registered.
    pub subject: String,
    /// The version of the referenced schema within `subject`.
    pub version: i32,
}
466
/// Response body of `GET /schemas/ids/{id}`; only the schema text is
/// deserialized.
#[derive(Debug, Deserialize)]
struct GetByIdResponse {
    /// The raw schema text.
    schema: String,
}
471
/// Errors for schema lookups by ID.
#[derive(Debug)]
pub enum GetByIdError {
    /// No schema with the requested ID exists (registry error code 40403).
    SchemaNotFound,
    /// The underlying HTTP transport failed.
    Transport(reqwest::Error),
    /// Any other error reported by the registry, with its error code and
    /// message.
    Server { code: i32, message: String },
}
482
/// A subject name paired with one of its version numbers.
///
/// Deserialized from the registry's `GET /schemas/ids/{id}/versions` listing.
/// Also used as the node key when resolving schema references.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SubjectVersion {
    /// The name of the subject.
    pub subject: String,
    /// The version number within the subject.
    pub version: i32,
}
490
491impl From<UnhandledError> for GetByIdError {
492    fn from(err: UnhandledError) -> GetByIdError {
493        match err {
494            UnhandledError::Transport(err) => GetByIdError::Transport(err),
495            UnhandledError::Api { code, message } => match code {
496                40403 => GetByIdError::SchemaNotFound,
497                _ => GetByIdError::Server { code, message },
498            },
499        }
500    }
501}
502
503impl Error for GetByIdError {
504    fn source(&self) -> Option<&(dyn Error + 'static)> {
505        match self {
506            GetByIdError::SchemaNotFound | GetByIdError::Server { .. } => None,
507            GetByIdError::Transport(err) => Some(err),
508        }
509    }
510}
511
512impl fmt::Display for GetByIdError {
513    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
514        match self {
515            GetByIdError::SchemaNotFound => write!(f, "schema not found"),
516            GetByIdError::Transport(err) => write!(f, "transport: {}", err),
517            GetByIdError::Server { code, message } => {
518                write!(f, "server error {}: {}", code, message)
519            }
520        }
521    }
522}
523
/// Subject-level configuration returned by `GET /config/{subject}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SubjectConfig {
    /// The subject's compatibility level (JSON key `compatibilityLevel`).
    pub compatibility_level: CompatibilityLevel,
    // There are other fields to include if we need them.
}
530
531/// Errors for schema lookups by subject.
532#[derive(Debug)]
533pub enum GetSubjectConfigError {
534    /// The requested subject does not exist.
535    SubjectNotFound,
536    /// The compatibility level for the subject has not been set.
537    SubjectCompatibilityLevelNotSet,
538    /// The underlying HTTP transport failed.
539    Transport(reqwest::Error),
540    /// An internal server error occurred.
541    Server { code: i32, message: String },
542}
543
544impl From<UnhandledError> for GetSubjectConfigError {
545    fn from(err: UnhandledError) -> GetSubjectConfigError {
546        match err {
547            UnhandledError::Transport(err) => GetSubjectConfigError::Transport(err),
548            UnhandledError::Api { code, message } => match code {
549                404 => GetSubjectConfigError::SubjectNotFound,
550                40408 => GetSubjectConfigError::SubjectCompatibilityLevelNotSet,
551                _ => GetSubjectConfigError::Server { code, message },
552            },
553        }
554    }
555}
556
557impl Error for GetSubjectConfigError {
558    fn source(&self) -> Option<&(dyn Error + 'static)> {
559        match self {
560            GetSubjectConfigError::SubjectNotFound
561            | GetSubjectConfigError::SubjectCompatibilityLevelNotSet
562            | GetSubjectConfigError::Server { .. } => None,
563            GetSubjectConfigError::Transport(err) => Some(err),
564        }
565    }
566}
567
568impl fmt::Display for GetSubjectConfigError {
569    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
570        match self {
571            GetSubjectConfigError::SubjectNotFound => write!(f, "subject not found"),
572            GetSubjectConfigError::SubjectCompatibilityLevelNotSet => {
573                write!(f, "subject level compatibility not set")
574            }
575            GetSubjectConfigError::Transport(err) => write!(f, "transport: {}", err),
576            GetSubjectConfigError::Server { code, message } => {
577                write!(f, "server error {}: {}", code, message)
578            }
579        }
580    }
581}
582
/// Response body of `GET /subjects/{subject}/versions/{version}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GetBySubjectResponse {
    /// The ID of the schema registered under this subject version.
    id: i32,
    /// The raw schema text.
    schema: String,
    /// The numeric version, also when `latest` was requested.
    version: i32,
    /// The subject name.
    subject: String,
    /// Schemas this one references; empty when the field is absent.
    #[serde(default)]
    references: Vec<SchemaReference>,
}
593
594/// Errors for schema lookups by subject.
595#[derive(Debug)]
596pub enum GetBySubjectError {
597    /// The requested subject does not exist.
598    SubjectNotFound,
599    /// The requested version does not exist.
600    VersionNotFound(String),
601    /// The underlying HTTP transport failed.
602    Transport(reqwest::Error),
603    /// An internal server error occurred.
604    Server { code: i32, message: String },
605    /// Cycle detected in schemas
606    SchemaReferenceCycle,
607}
608
609impl From<UnhandledError> for GetBySubjectError {
610    fn from(err: UnhandledError) -> GetBySubjectError {
611        match err {
612            UnhandledError::Transport(err) => GetBySubjectError::Transport(err),
613            UnhandledError::Api { code, message } => match code {
614                40401 => GetBySubjectError::SubjectNotFound,
615                40402 => GetBySubjectError::VersionNotFound(message),
616                _ => GetBySubjectError::Server { code, message },
617            },
618        }
619    }
620}
621
622impl Error for GetBySubjectError {
623    fn source(&self) -> Option<&(dyn Error + 'static)> {
624        match self {
625            GetBySubjectError::SubjectNotFound
626            | GetBySubjectError::VersionNotFound(_)
627            | GetBySubjectError::Server { .. }
628            | GetBySubjectError::SchemaReferenceCycle => None,
629            GetBySubjectError::Transport(err) => Some(err),
630        }
631    }
632}
633
634impl fmt::Display for GetBySubjectError {
635    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
636        match self {
637            GetBySubjectError::SubjectNotFound => write!(f, "subject not found"),
638            GetBySubjectError::VersionNotFound(message) => {
639                write!(f, "version not found: {}", message)
640            }
641            GetBySubjectError::Transport(err) => write!(f, "transport: {}", err),
642            GetBySubjectError::Server { code, message } => {
643                write!(f, "server error {}: {}", code, message)
644            }
645            GetBySubjectError::SchemaReferenceCycle => {
646                write!(f, "cycle detected in schema references")
647            }
648        }
649    }
650}
651
/// Request body for `POST /subjects/{subject}/versions`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct PublishRequest<'a> {
    /// The raw schema text.
    schema: &'a str,
    // Omitting the following fields when they're set to their defaults provides
    // compatibility with old versions of the schema registry that don't
    // understand these fields.
    /// Serialized as `schemaType`; omitted when it is Avro.
    #[serde(skip_serializing_if = "SchemaType::is_default")]
    schema_type: SchemaType,
    /// Omitted when empty.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    references: &'a [SchemaReference],
}
664
/// Response body of a publish request.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PublishResponse {
    /// The ID of the published schema, which may be an existing identical one.
    id: i32,
}
670
/// Request body for `PUT /config/{subject}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CompatibilityLevelRequest {
    /// The compatibility level to set for the subject.
    compatibility: CompatibilityLevel,
}
676
/// A schema registry compatibility level.
///
/// Serialized in SCREAMING_SNAKE_CASE (e.g. `BackwardTransitive` as
/// `"BACKWARD_TRANSITIVE"`). These are the same spellings accepted by its
/// `TryFrom<&str>` impl and produced by its `Display` impl. Also derives
/// proptest's `Arbitrary` so values can be generated in property tests.
#[derive(Arbitrary, Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CompatibilityLevel {
    Backward,
    BackwardTransitive,
    Forward,
    ForwardTransitive,
    Full,
    FullTransitive,
    None,
}
688
689impl TryFrom<&str> for CompatibilityLevel {
690    type Error = String;
691
692    fn try_from(value: &str) -> Result<Self, Self::Error> {
693        match value {
694            "BACKWARD" => Ok(CompatibilityLevel::Backward),
695            "BACKWARD_TRANSITIVE" => Ok(CompatibilityLevel::BackwardTransitive),
696            "FORWARD" => Ok(CompatibilityLevel::Forward),
697            "FORWARD_TRANSITIVE" => Ok(CompatibilityLevel::ForwardTransitive),
698            "FULL" => Ok(CompatibilityLevel::Full),
699            "FULL_TRANSITIVE" => Ok(CompatibilityLevel::FullTransitive),
700            "NONE" => Ok(CompatibilityLevel::None),
701            _ => Err(format!("invalid compatibility level: {}", value)),
702        }
703    }
704}
705
706impl fmt::Display for CompatibilityLevel {
707    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
708        match self {
709            CompatibilityLevel::Backward => write!(f, "BACKWARD"),
710            CompatibilityLevel::BackwardTransitive => write!(f, "BACKWARD_TRANSITIVE"),
711            CompatibilityLevel::Forward => write!(f, "FORWARD"),
712            CompatibilityLevel::ForwardTransitive => write!(f, "FORWARD_TRANSITIVE"),
713            CompatibilityLevel::Full => write!(f, "FULL"),
714            CompatibilityLevel::FullTransitive => write!(f, "FULL_TRANSITIVE"),
715            CompatibilityLevel::None => write!(f, "NONE"),
716        }
717    }
718}
719
720/// Errors for publish operations.
721#[derive(Debug)]
722pub enum PublishError {
723    /// The provided schema was not compatible with existing schemas for that
724    /// subject, according to the subject's forwards- or backwards-compatibility
725    /// requirements.
726    IncompatibleSchema,
727    /// The provided schema was invalid.
728    InvalidSchema { message: String },
729    /// The underlying HTTP transport failed.
730    Transport(reqwest::Error),
731    /// An internal server error occurred.
732    Server { code: i32, message: String },
733}
734
735impl From<UnhandledError> for PublishError {
736    fn from(err: UnhandledError) -> PublishError {
737        match err {
738            UnhandledError::Transport(err) => PublishError::Transport(err),
739            UnhandledError::Api { code, message } => match code {
740                // Confluent Schema Registry 8.0+ returns the more specific
741                // 40901 subcode in addition to the legacy 409.
742                409 | 40901 => PublishError::IncompatibleSchema,
743                42201 => PublishError::InvalidSchema { message },
744                _ => PublishError::Server { code, message },
745            },
746        }
747    }
748}
749
750impl Error for PublishError {
751    fn source(&self) -> Option<&(dyn Error + 'static)> {
752        match self {
753            PublishError::IncompatibleSchema
754            | PublishError::InvalidSchema { .. }
755            | PublishError::Server { .. } => None,
756            PublishError::Transport(err) => Some(err),
757        }
758    }
759}
760
761impl fmt::Display for PublishError {
762    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
763        match self {
764            // The error descriptions for IncompatibleSchema and InvalidSchema
765            // are copied from the schema registry itself.
766            PublishError::IncompatibleSchema => write!(
767                f,
768                "schema being registered is incompatible with an earlier schema"
769            ),
770            PublishError::InvalidSchema { message } => write!(f, "{}", message),
771            PublishError::Transport(err) => write!(f, "transport: {}", err),
772            PublishError::Server { code, message } => {
773                write!(f, "server error {}: {}", code, message)
774            }
775        }
776    }
777}
778
779/// Errors for list operations.
780#[derive(Debug)]
781pub enum ListError {
782    /// The underlying HTTP transport failed.
783    Transport(reqwest::Error),
784    /// An internal server error occurred.
785    Server { code: i32, message: String },
786}
787
788impl From<UnhandledError> for ListError {
789    fn from(err: UnhandledError) -> ListError {
790        match err {
791            UnhandledError::Transport(err) => ListError::Transport(err),
792            UnhandledError::Api { code, message } => ListError::Server { code, message },
793        }
794    }
795}
796
797impl Error for ListError {
798    fn source(&self) -> Option<&(dyn Error + 'static)> {
799        match self {
800            ListError::Server { .. } => None,
801            ListError::Transport(err) => Some(err),
802        }
803    }
804}
805
806impl fmt::Display for ListError {
807    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
808        match self {
809            ListError::Transport(err) => write!(f, "transport: {}", err),
810            ListError::Server { code, message } => write!(f, "server error {}: {}", code, message),
811        }
812    }
813}
814
815/// Errors for delete operations.
816#[derive(Debug)]
817pub enum DeleteError {
818    /// The specified subject does not exist.
819    SubjectNotFound,
820    /// The underlying HTTP transport failed.
821    Transport(reqwest::Error),
822    /// An internal server error occurred.
823    Server { code: i32, message: String },
824}
825
826impl From<UnhandledError> for DeleteError {
827    fn from(err: UnhandledError) -> DeleteError {
828        match err {
829            UnhandledError::Transport(err) => DeleteError::Transport(err),
830            UnhandledError::Api { code, message } => match code {
831                40401 => DeleteError::SubjectNotFound,
832                _ => DeleteError::Server { code, message },
833            },
834        }
835    }
836}
837
838impl Error for DeleteError {
839    fn source(&self) -> Option<&(dyn Error + 'static)> {
840        match self {
841            DeleteError::SubjectNotFound | DeleteError::Server { .. } => None,
842            DeleteError::Transport(err) => Some(err),
843        }
844    }
845}
846
847impl fmt::Display for DeleteError {
848    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
849        match self {
850            DeleteError::SubjectNotFound => write!(f, "subject not found"),
851            DeleteError::Transport(err) => write!(f, "transport: {}", err),
852            DeleteError::Server { code, message } => {
853                write!(f, "server error {}: {}", code, message)
854            }
855        }
856    }
857}
858
/// Errors for setting compatibility level operations.
///
/// Produced from an `UnhandledError` by the `From` impl for this type, which
/// classifies API errors by their numeric code.
#[derive(Debug)]
pub enum SetCompatibilityLevelError {
    /// The compatibility level is invalid (registry API error code 42203).
    InvalidCompatibilityLevel,
    /// The underlying HTTP transport failed.
    Transport(reqwest::Error),
    /// The registry reported an API error that is not classified more specifically.
    ///
    /// Despite the "server error" wording, this is the catch-all for every API error
    /// code other than 42203; `code` and `message` are copied from `UnhandledError::Api`.
    Server { code: i32, message: String },
}
869
870impl From<UnhandledError> for SetCompatibilityLevelError {
871    fn from(err: UnhandledError) -> SetCompatibilityLevelError {
872        match err {
873            UnhandledError::Transport(err) => SetCompatibilityLevelError::Transport(err),
874            UnhandledError::Api { code, message } => match code {
875                42203 => SetCompatibilityLevelError::InvalidCompatibilityLevel,
876                _ => SetCompatibilityLevelError::Server { code, message },
877            },
878        }
879    }
880}
881
882impl Error for SetCompatibilityLevelError {
883    fn source(&self) -> Option<&(dyn Error + 'static)> {
884        match self {
885            SetCompatibilityLevelError::InvalidCompatibilityLevel
886            | SetCompatibilityLevelError::Server { .. } => None,
887            SetCompatibilityLevelError::Transport(err) => Some(err),
888        }
889    }
890}
891
892impl fmt::Display for SetCompatibilityLevelError {
893    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
894        match self {
895            SetCompatibilityLevelError::InvalidCompatibilityLevel => {
896                write!(f, "invalid compatibility level")
897            }
898            SetCompatibilityLevelError::Transport(err) => write!(f, "transport: {}", err),
899            SetCompatibilityLevelError::Server { code, message } => {
900                write!(f, "server error {}: {}", code, message)
901            }
902        }
903    }
904}
905
/// Error payload deserialized (via serde) from a schema registry response.
///
/// NOTE(review): presumably this is the body of a non-success HTTP response; the code
/// that parses it is outside this view.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    // No `#[serde(rename)]` attributes: the field names themselves are the payload keys.
    /// Numeric error code reported in the payload.
    error_code: i32,
    /// Human-readable error message reported in the payload.
    message: String,
}
911
/// A registry failure that has not yet been mapped to an operation-specific error.
///
/// Operation error types such as `DeleteError` and `SetCompatibilityLevelError`
/// implement `From<UnhandledError>` to classify it by API error code.
#[derive(Debug)]
enum UnhandledError {
    /// The underlying HTTP transport failed.
    Transport(reqwest::Error),
    /// An API-level error, identified by a numeric code plus a message.
    Api { code: i32, message: String },
}
917
918impl From<reqwest::Error> for UnhandledError {
919    fn from(err: reqwest::Error) -> UnhandledError {
920        UnhandledError::Transport(err)
921    }
922}
923
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_types)]
    use super::*;
    use std::collections::HashMap;

    /// Helper to create a SubjectVersion
    fn sv(subject: &str, version: i32) -> SubjectVersion {
        SubjectVersion {
            subject: subject.to_string(),
            version,
        }
    }

    /// Helper to build a dependency graph from a list of edges.
    ///
    /// Each edge is `(from, to)`, meaning `from` depends on `to`. An edge whose target is
    /// `None` registers `from` as a node without adding any dependency; this is how leaves
    /// (nodes that depend on nothing) and isolated nodes are added to the graph.
    ///
    /// Only `from` nodes become keys of the returned map. A node that appears solely as a
    /// dependency target is not inserted, so the tests below list every leaf explicitly
    /// with a `None` edge.
    fn build_graph(
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) -> HashMap<SubjectVersion, Vec<SubjectVersion>> {
        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        // Add edges: from depends on to
        for (from, to) in edges {
            let deps = graph.entry(from.clone()).or_default();
            if let Some(to) = to {
                deps.push(to.clone());
            }
        }

        graph
    }

    /// Verify that all edges are respected in the ordering.
    /// For edge (from, to) where 'from' depends on 'to':
    /// - 'from' should be processed before 'to' (lower order number)
    /// - This is because the algorithm processes from roots (nothing depends on them)
    ///   towards leaves (they don't depend on anything)
    fn verify_edge_ordering(
        ordered: &HashMap<&SubjectVersion, i32>,
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) {
        for (from, to) in edges {
            if let Some(to) = to {
                let from_order = ordered.get(from).expect("from node should be in ordering");
                let to_order = ordered.get(to).expect("to node should be in ordering");

                // 'from' depends on 'to', so 'from' is processed first (lower order number)
                // The algorithm starts at roots (nodes nothing depends on) and works toward leaves
                assert!(
                    from_order < to_order,
                    "{:?} (order {}) depends on {:?} (order {}), so should be processed first",
                    from,
                    from_order,
                    to,
                    to_order
                );
            }
        }
    }

    #[mz_ore::test]
    fn test_topological_sort_empty_graph() {
        let graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert!(ordered.is_empty());
    }

    #[mz_ore::test]
    fn test_topological_sort_single_node() {
        // Single node with no dependencies
        let a = sv("a", 1);
        let graph = build_graph(&[(a.clone(), None)]);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 1);
        assert!(ordered.contains_key(&a));
    }

    #[mz_ore::test]
    fn test_topological_sort_linear_chain() {
        // A -> B -> C -> D
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (b.clone(), Some(c.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 4);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_diamond() {
        // Classic diamond pattern:
        //     A
        //    / \
        //   B   C
        //    \ /
        //     D
        // A depends on B and C, both B and C depend on D
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 4);
        verify_edge_ordering(&ordered, &edges);

        // A must come before B and C, and B and C must come before D
        assert!(ordered[&a] < ordered[&b]);
        assert!(ordered[&a] < ordered[&c]);
        assert!(ordered[&b] < ordered[&d]);
        assert!(ordered[&c] < ordered[&d]);
        // The transitive dependency A -> D must hold as well.
        assert!(ordered[&a] < ordered[&d]);
    }

    #[mz_ore::test]
    fn test_topological_sort_wide_graph() {
        // Wide graph: A depends on B, C, D, E, F (many dependencies)
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (a.clone(), Some(e.clone())),
            (a.clone(), Some(f.clone())),
            (b.clone(), None),
            (c.clone(), None),
            (d.clone(), None),
            (e.clone(), None),
            (f.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 6);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_complex_dag() {
        // Complex DAG:
        //       A
        //      /|\
        //     B C D
        //     |/| |
        //     E F G
        //      \|/
        //       H
        // A -> B, C, D
        // B -> E
        // C -> E, F
        // D -> G
        // E -> H
        // F -> H
        // G -> H
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);
        let h = sv("h", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 8);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_with_versions() {
        // Same subject, different versions
        //        A2
        //       / |
        //      A1 |
        //       \ |
        //        B1
        let a_v1 = sv("a", 1);
        let a_v2 = sv("a", 2);
        let b_v1 = sv("b", 1);

        // a_v2 depends on a_v1, and both depend on b_v1
        let edges = vec![
            (a_v2.clone(), Some(a_v1.clone())),
            (a_v2.clone(), Some(b_v1.clone())),
            (a_v1.clone(), Some(b_v1.clone())),
            (b_v1.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 3);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_multi_level_diamond() {
        // Double diamond:
        //       A
        //      / \
        //     B   C
        //      \ /
        //       D
        //      / \
        //     E   F
        //      \ /
        //       G
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (d.clone(), Some(f.clone())),
            (e.clone(), Some(g.clone())),
            (f.clone(), Some(g.clone())),
            (g.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 7);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_shared_dependency_at_multiple_levels() {
        // Shared dependency accessed at multiple levels:
        //     A
        //    /|\
        //   B C |
        //   |/  |
        //   D   |
        //    \ /
        //     E
        //     |
        //     F
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(e.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (e.clone(), Some(f.clone())),
            (f.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 6);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_lattice_structure() {
        // Lattice structure (more complex than diamond):
        //       A
        //      /|\
        //     B C D
        //     |\|/|
        //     | X |
        //     |/|\|
        //     E F G
        //      \|/
        //       H
        // A -> B, C, D
        // B -> E, F
        // C -> E, F, G
        // D -> F, G
        // E -> H
        // F -> H
        // G -> H
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);
        let h = sv("h", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (b.clone(), Some(f.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 8);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_binary_tree() {
        // Full binary tree structure (inverted, dependencies flow down):
        //        A
        //       / \
        //      B   C
        //     /|   |\
        //    D E   F G
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), None),
            (e.clone(), None),
            (f.clone(), None),
            (g.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("acyclic graph should sort");

        assert_eq!(ordered.len(), 7);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_simple_cycle() {
        // Simple cycle: A -> B -> A
        let a = sv("a", 1);
        let b = sv("b", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_cycle_with_entry_point() {
        // Cycle with entry point: A -> B -> C -> B (B and C form a cycle)
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![b.clone()]); // C points back to B

        let sort_result = topological_sort(&graph);
        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_self_reference() {
        // Self-referencing node: A -> A
        let a = sv("a", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);
        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_three_node_cycle() {
        // Three node cycle: A -> B -> C -> A
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);
        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }
}