1use std::error::Error;
11use std::fmt;
12use std::hash::Hash;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{anyhow, bail};
17use proptest_derive::Arbitrary;
18use reqwest::{Method, Response, Url};
19use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21
22use crate::config::Auth;
23
24#[derive(Clone)]
26pub struct Client {
27 inner: reqwest::Client,
28 url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
29 auth: Option<Auth>,
30 timeout: Duration,
31}
32
33impl fmt::Debug for Client {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 f.debug_struct("Client")
36 .field("inner", &self.inner)
37 .field("url", &"...")
38 .field("auth", &self.auth)
39 .finish()
40 }
41}
42
43impl Client {
44 pub(crate) fn new(
45 inner: reqwest::Client,
46 url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
47 auth: Option<Auth>,
48 timeout: Duration,
49 ) -> Result<Self, anyhow::Error> {
50 if url().cannot_be_a_base() {
51 bail!("cannot construct a CCSR client with a cannot-be-a-base URL");
52 }
53 Ok(Client {
54 inner,
55 url,
56 auth,
57 timeout,
58 })
59 }
60
61 fn make_request<P>(&self, method: Method, path: P) -> reqwest::RequestBuilder
62 where
63 P: IntoIterator,
64 P::Item: AsRef<str>,
65 {
66 let mut url = (self.url)();
67 url.path_segments_mut()
68 .expect("constructor validated URL can be a base")
69 .clear()
70 .extend(path);
71
72 let mut request = self.inner.request(method, url);
73 if let Some(auth) = &self.auth {
74 request = request.basic_auth(&auth.username, auth.password.as_ref());
75 }
76 request
77 }
78
79 pub fn timeout(&self) -> Duration {
80 self.timeout
81 }
82
83 pub async fn get_schema_by_id(&self, id: i32) -> Result<Schema, GetByIdError> {
85 let req = self.make_request(Method::GET, &["schemas", "ids", &id.to_string()]);
86 let res: GetByIdResponse = send_request(req).await?;
87 Ok(Schema {
88 id,
89 raw: res.schema,
90 })
91 }
92
93 pub async fn get_schema_by_subject(&self, subject: &str) -> Result<Schema, GetBySubjectError> {
95 self.get_subject_latest(subject).await.map(|s| s.schema)
96 }
97
98 pub async fn get_subject_latest(&self, subject: &str) -> Result<Subject, GetBySubjectError> {
100 let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
101 let res: GetBySubjectResponse = send_request(req).await?;
102 Ok(Subject {
103 schema: Schema {
104 id: res.id,
105 raw: res.schema,
106 },
107 version: res.version,
108 name: res.subject,
109 })
110 }
111
112 pub async fn get_subject_with_references(
115 &self,
116 subject: &str,
117 ) -> Result<(Subject, Vec<SubjectVersion>), GetBySubjectError> {
118 let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
119 let res: GetBySubjectResponse = send_request(req).await?;
120 let referenced_subjects: Vec<_> = res
121 .references
122 .into_iter()
123 .map(|r| SubjectVersion {
124 subject: r.subject,
125 version: r.version,
126 })
127 .collect();
128 Ok((
129 Subject {
130 schema: Schema {
131 id: res.id,
132 raw: res.schema,
133 },
134 version: res.version,
135 name: res.subject,
136 },
137 referenced_subjects,
138 ))
139 }
140
141 pub async fn get_subject_config(
143 &self,
144 subject: &str,
145 ) -> Result<SubjectConfig, GetSubjectConfigError> {
146 let req = self.make_request(Method::GET, &["config", subject]);
147 let res: SubjectConfig = send_request(req).await?;
148 Ok(res)
149 }
150
151 pub async fn get_subject_and_references(
156 &self,
157 subject: &str,
158 ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
159 self.get_subject_and_references_by_version(subject, "latest".to_owned())
160 .await
161 }
162
163 #[allow(clippy::disallowed_types)]
167 async fn get_subject_and_references_by_version(
168 &self,
169 subject: &str,
170 version: String,
171 ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
172 let mut subjects = vec![];
173 let mut graph = std::collections::HashMap::new();
175 let mut subjects_queue = vec![(subject.to_owned(), version)];
176 while let Some((subject, version)) = subjects_queue.pop() {
177 let req = self.make_request(Method::GET, &["subjects", &subject, "versions", &version]);
178 let res: GetBySubjectResponse = send_request(req).await?;
179 subjects.push(Subject {
180 schema: Schema {
181 id: res.id,
182 raw: res.schema,
183 },
184 version: res.version,
185 name: res.subject.clone(),
186 });
187 let subject_key = SubjectVersion {
188 subject: res.subject,
189 version: res.version,
190 };
191
192 let dependents: Vec<_> = res
193 .references
194 .into_iter()
195 .map(|reference| SubjectVersion {
196 subject: reference.subject,
197 version: reference.version,
198 })
199 .collect();
200
201 graph
202 .entry(subject_key)
203 .or_insert_with(Vec::new)
204 .extend(dependents.iter().cloned());
205 subjects_queue.extend(
206 dependents
207 .into_iter()
208 .filter(|dep| match graph.entry(dep.clone()) {
212 std::collections::hash_map::Entry::Occupied(_) => false,
213 std::collections::hash_map::Entry::Vacant(vacant) => {
214 vacant.insert(vec![]);
215 true
216 }
217 })
218 .map(|dep| (dep.subject, dep.version.to_string())),
219 );
220 }
221 assert!(subjects.len() > 0, "Request should error if no subjects");
222
223 let primary = subjects.remove(0);
224
225 let ordered =
226 topological_sort(&graph).map_err(|_| GetBySubjectError::SchemaReferenceCycle)?;
227
228 subjects.sort_by(|a, b| {
229 let a = SubjectVersion {
230 subject: a.name.clone(),
231 version: a.version,
232 };
233 let b = SubjectVersion {
234 subject: b.name.clone(),
235 version: b.version,
236 };
237 ordered
238 .get(&b)
239 .unwrap_or_else(|| panic!("b {b:?}"))
240 .cmp(ordered.get(&a).unwrap_or_else(|| panic!("a {a:?}")))
241 });
242
243 Ok((primary, subjects))
244 }
245
246 pub async fn publish_schema(
253 &self,
254 subject: &str,
255 schema: &str,
256 schema_type: SchemaType,
257 references: &[SchemaReference],
258 ) -> Result<i32, PublishError> {
259 let req = self.make_request(Method::POST, &["subjects", subject, "versions"]);
260 let req = req.json(&PublishRequest {
261 schema,
262 schema_type,
263 references,
264 });
265 let res: PublishResponse = send_request(req).await?;
266 Ok(res.id)
267 }
268
269 pub async fn set_subject_compatibility_level(
271 &self,
272 subject: &str,
273 compatibility_level: CompatibilityLevel,
274 ) -> Result<(), SetCompatibilityLevelError> {
275 let req = self.make_request(Method::PUT, &["config", subject]);
276 let req = req.json(&CompatibilityLevelRequest {
277 compatibility: compatibility_level,
278 });
279 send_request_raw(req).await?;
280 Ok(())
281 }
282
283 pub async fn list_subjects(&self) -> Result<Vec<String>, ListError> {
285 let req = self.make_request(Method::GET, &["subjects"]);
286 Ok(send_request(req).await?)
287 }
288
289 pub async fn delete_subject(&self, subject: &str) -> Result<(), DeleteError> {
296 let req = self.make_request(Method::DELETE, &["subjects", subject]);
297 send_request_raw(req).await?;
298 Ok(())
299 }
300
301 pub async fn get_subject_and_references_by_id(
306 &self,
307 id: i32,
308 ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
309 let req = self.make_request(
310 Method::GET,
311 &["schemas", "ids", &id.to_string(), "versions"],
312 );
313 let res: Vec<SubjectVersion> = send_request(req).await?;
314
315 match res.as_slice() {
327 [first, ..] => {
328 self.get_subject_and_references_by_version(
329 &first.subject,
330 first.version.to_string(),
331 )
332 .await
333 }
334 _ => Err(GetBySubjectError::SubjectNotFound),
335 }
336 }
337}
338
339#[allow(clippy::disallowed_types)]
342pub fn topological_sort<T: Hash + Eq>(
343 graph: &std::collections::HashMap<T, Vec<T>>,
344) -> Result<std::collections::HashMap<&T, i32>, anyhow::Error> {
345 let mut referenced_by: std::collections::HashMap<&T, std::collections::HashSet<&T>> =
346 std::collections::HashMap::new();
347 for (subject, references) in graph.iter() {
348 for reference in references {
349 referenced_by.entry(reference).or_default().insert(subject);
350 }
351 }
352
353 let mut queue: Vec<_> = graph
356 .keys()
357 .filter(|key| {
358 referenced_by
359 .get(*key)
360 .map_or(true, |subjects| subjects.is_empty())
361 })
362 .collect();
363
364 let mut ordered = std::collections::HashMap::new();
365 let mut n = 0;
366 while let Some(subj_ver) = queue.pop() {
367 if let Some(refs) = graph.get(subj_ver) {
368 for ref_ver in refs {
369 let Some(subjects) = referenced_by.get_mut(ref_ver) else {
370 continue;
371 };
372 subjects.remove(&subj_ver);
373 if subjects.is_empty() {
374 referenced_by.remove_entry(ref_ver);
375 queue.push(ref_ver);
376 }
377 }
378 }
379 ordered.insert(subj_ver, n);
380 n += 1;
381 }
382
383 if referenced_by.is_empty() {
384 Ok(ordered)
385 } else {
386 Err(anyhow!("Cycled detected during topoligical sort"))
387 }
388}
389
390async fn send_request<T>(req: reqwest::RequestBuilder) -> Result<T, UnhandledError>
391where
392 T: DeserializeOwned,
393{
394 let res = send_request_raw(req).await?;
395 Ok(res.json().await?)
396}
397
398async fn send_request_raw(req: reqwest::RequestBuilder) -> Result<Response, UnhandledError> {
399 let res = req.send().await?;
400 let status = res.status();
401 if status.is_success() {
402 Ok(res)
403 } else {
404 match res.json::<ErrorResponse>().await {
405 Ok(err_res) => Err(UnhandledError::Api {
406 code: err_res.error_code,
407 message: err_res.message,
408 }),
409 Err(_) => Err(UnhandledError::Api {
410 code: i32::from(status.as_u16()),
411 message: "unable to decode error details".into(),
412 }),
413 }
414 }
415}
416
417#[derive(Clone, Copy, Debug, Serialize)]
419#[serde(rename_all = "UPPERCASE")]
420pub enum SchemaType {
421 Avro,
423 Protobuf,
425 Json,
427}
428
429impl SchemaType {
430 fn is_default(&self) -> bool {
431 matches!(self, SchemaType::Avro)
432 }
433}
434
/// A schema as stored in the registry.
#[derive(Debug, Eq, PartialEq)]
pub struct Schema {
    /// The registry-assigned ID of the schema.
    pub id: i32,
    /// The raw schema text, as returned by the registry.
    pub raw: String,
}

/// One version of a subject, together with its schema.
#[derive(Debug, Eq, PartialEq)]
pub struct Subject {
    /// The version number of this subject version.
    pub version: i32,
    /// The name of the subject.
    pub name: String,
    /// The schema registered at this version.
    pub schema: Schema,
}
454
/// A reference from one schema to a schema registered under another subject.
/// Sent when publishing and received in subject lookups.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SchemaReference {
    /// The reference's name. NOTE(review): presumably the name by which the
    /// referencing schema imports the referenced one; confirm against the
    /// registry API.
    pub name: String,
    /// The subject under which the referenced schema is registered.
    pub subject: String,
    /// The referenced version of `subject`.
    pub version: i32,
}
466
/// Response body of `GET /schemas/ids/{id}`.
#[derive(Debug, Deserialize)]
struct GetByIdResponse {
    /// The raw schema text.
    schema: String,
}
471
/// Errors for looking up a schema by its ID.
#[derive(Debug)]
pub enum GetByIdError {
    /// No schema with the requested ID exists (registry error code 40403).
    SchemaNotFound,
    /// The HTTP request or response-body handling failed.
    Transport(reqwest::Error),
    /// The registry returned an error not covered by another variant.
    Server { code: i32, message: String },
}
482
/// Identifies one version of a subject.
// Field order matters: it determines the derived `Ord`/`PartialOrd`.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SubjectVersion {
    /// The subject name.
    pub subject: String,
    /// The version number within the subject.
    pub version: i32,
}
490
491impl From<UnhandledError> for GetByIdError {
492 fn from(err: UnhandledError) -> GetByIdError {
493 match err {
494 UnhandledError::Transport(err) => GetByIdError::Transport(err),
495 UnhandledError::Api { code, message } => match code {
496 40403 => GetByIdError::SchemaNotFound,
497 _ => GetByIdError::Server { code, message },
498 },
499 }
500 }
501}
502
503impl Error for GetByIdError {
504 fn source(&self) -> Option<&(dyn Error + 'static)> {
505 match self {
506 GetByIdError::SchemaNotFound | GetByIdError::Server { .. } => None,
507 GetByIdError::Transport(err) => Some(err),
508 }
509 }
510}
511
512impl fmt::Display for GetByIdError {
513 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
514 match self {
515 GetByIdError::SchemaNotFound => write!(f, "schema not found"),
516 GetByIdError::Transport(err) => write!(f, "transport: {}", err),
517 GetByIdError::Server { code, message } => {
518 write!(f, "server error {}: {}", code, message)
519 }
520 }
521 }
522}
523
/// Response body of `GET /config/{subject}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SubjectConfig {
    /// The compatibility level configured for the subject.
    pub compatibility_level: CompatibilityLevel,
}
530
531#[derive(Debug)]
533pub enum GetSubjectConfigError {
534 SubjectNotFound,
536 SubjectCompatibilityLevelNotSet,
538 Transport(reqwest::Error),
540 Server { code: i32, message: String },
542}
543
544impl From<UnhandledError> for GetSubjectConfigError {
545 fn from(err: UnhandledError) -> GetSubjectConfigError {
546 match err {
547 UnhandledError::Transport(err) => GetSubjectConfigError::Transport(err),
548 UnhandledError::Api { code, message } => match code {
549 404 => GetSubjectConfigError::SubjectNotFound,
550 40408 => GetSubjectConfigError::SubjectCompatibilityLevelNotSet,
551 _ => GetSubjectConfigError::Server { code, message },
552 },
553 }
554 }
555}
556
557impl Error for GetSubjectConfigError {
558 fn source(&self) -> Option<&(dyn Error + 'static)> {
559 match self {
560 GetSubjectConfigError::SubjectNotFound
561 | GetSubjectConfigError::SubjectCompatibilityLevelNotSet
562 | GetSubjectConfigError::Server { .. } => None,
563 GetSubjectConfigError::Transport(err) => Some(err),
564 }
565 }
566}
567
568impl fmt::Display for GetSubjectConfigError {
569 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
570 match self {
571 GetSubjectConfigError::SubjectNotFound => write!(f, "subject not found"),
572 GetSubjectConfigError::SubjectCompatibilityLevelNotSet => {
573 write!(f, "subject level compatibility not set")
574 }
575 GetSubjectConfigError::Transport(err) => write!(f, "transport: {}", err),
576 GetSubjectConfigError::Server { code, message } => {
577 write!(f, "server error {}: {}", code, message)
578 }
579 }
580 }
581}
582
/// Response body of `GET /subjects/{subject}/versions/{version}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GetBySubjectResponse {
    /// The registry-assigned schema ID.
    id: i32,
    /// The raw schema text.
    schema: String,
    /// The numeric version, even when the request asked for "latest".
    version: i32,
    /// The subject name.
    subject: String,
    /// Schemas referenced by this one; empty when the field is missing.
    #[serde(default)]
    references: Vec<SchemaReference>,
}
593
594#[derive(Debug)]
596pub enum GetBySubjectError {
597 SubjectNotFound,
599 VersionNotFound(String),
601 Transport(reqwest::Error),
603 Server { code: i32, message: String },
605 SchemaReferenceCycle,
607}
608
609impl From<UnhandledError> for GetBySubjectError {
610 fn from(err: UnhandledError) -> GetBySubjectError {
611 match err {
612 UnhandledError::Transport(err) => GetBySubjectError::Transport(err),
613 UnhandledError::Api { code, message } => match code {
614 40401 => GetBySubjectError::SubjectNotFound,
615 40402 => GetBySubjectError::VersionNotFound(message),
616 _ => GetBySubjectError::Server { code, message },
617 },
618 }
619 }
620}
621
622impl Error for GetBySubjectError {
623 fn source(&self) -> Option<&(dyn Error + 'static)> {
624 match self {
625 GetBySubjectError::SubjectNotFound
626 | GetBySubjectError::VersionNotFound(_)
627 | GetBySubjectError::Server { .. }
628 | GetBySubjectError::SchemaReferenceCycle => None,
629 GetBySubjectError::Transport(err) => Some(err),
630 }
631 }
632}
633
634impl fmt::Display for GetBySubjectError {
635 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
636 match self {
637 GetBySubjectError::SubjectNotFound => write!(f, "subject not found"),
638 GetBySubjectError::VersionNotFound(message) => {
639 write!(f, "version not found: {}", message)
640 }
641 GetBySubjectError::Transport(err) => write!(f, "transport: {}", err),
642 GetBySubjectError::Server { code, message } => {
643 write!(f, "server error {}: {}", code, message)
644 }
645 GetBySubjectError::SchemaReferenceCycle => {
646 write!(f, "cycle detected in schema references")
647 }
648 }
649 }
650}
651
/// Request body for `POST /subjects/{subject}/versions`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct PublishRequest<'a> {
    /// The raw schema text.
    schema: &'a str,
    /// The schema format; omitted when it is the default (Avro).
    #[serde(skip_serializing_if = "SchemaType::is_default")]
    schema_type: SchemaType,
    /// Schemas this one references; omitted entirely when empty.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    references: &'a [SchemaReference],
}
664
/// Response body of a successful `POST /subjects/{subject}/versions`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PublishResponse {
    /// The ID the registry assigned to the published schema.
    id: i32,
}
670
/// Request body for `PUT /config/{subject}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CompatibilityLevelRequest {
    /// The compatibility level to apply to the subject.
    compatibility: CompatibilityLevel,
}
676
677#[derive(Arbitrary, Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
678#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
679pub enum CompatibilityLevel {
680 Backward,
681 BackwardTransitive,
682 Forward,
683 ForwardTransitive,
684 Full,
685 FullTransitive,
686 None,
687}
688
689impl TryFrom<&str> for CompatibilityLevel {
690 type Error = String;
691
692 fn try_from(value: &str) -> Result<Self, Self::Error> {
693 match value {
694 "BACKWARD" => Ok(CompatibilityLevel::Backward),
695 "BACKWARD_TRANSITIVE" => Ok(CompatibilityLevel::BackwardTransitive),
696 "FORWARD" => Ok(CompatibilityLevel::Forward),
697 "FORWARD_TRANSITIVE" => Ok(CompatibilityLevel::ForwardTransitive),
698 "FULL" => Ok(CompatibilityLevel::Full),
699 "FULL_TRANSITIVE" => Ok(CompatibilityLevel::FullTransitive),
700 "NONE" => Ok(CompatibilityLevel::None),
701 _ => Err(format!("invalid compatibility level: {}", value)),
702 }
703 }
704}
705
706impl fmt::Display for CompatibilityLevel {
707 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
708 match self {
709 CompatibilityLevel::Backward => write!(f, "BACKWARD"),
710 CompatibilityLevel::BackwardTransitive => write!(f, "BACKWARD_TRANSITIVE"),
711 CompatibilityLevel::Forward => write!(f, "FORWARD"),
712 CompatibilityLevel::ForwardTransitive => write!(f, "FORWARD_TRANSITIVE"),
713 CompatibilityLevel::Full => write!(f, "FULL"),
714 CompatibilityLevel::FullTransitive => write!(f, "FULL_TRANSITIVE"),
715 CompatibilityLevel::None => write!(f, "NONE"),
716 }
717 }
718}
719
720#[derive(Debug)]
722pub enum PublishError {
723 IncompatibleSchema,
727 InvalidSchema { message: String },
729 Transport(reqwest::Error),
731 Server { code: i32, message: String },
733}
734
735impl From<UnhandledError> for PublishError {
736 fn from(err: UnhandledError) -> PublishError {
737 match err {
738 UnhandledError::Transport(err) => PublishError::Transport(err),
739 UnhandledError::Api { code, message } => match code {
740 409 => PublishError::IncompatibleSchema,
741 42201 => PublishError::InvalidSchema { message },
742 _ => PublishError::Server { code, message },
743 },
744 }
745 }
746}
747
748impl Error for PublishError {
749 fn source(&self) -> Option<&(dyn Error + 'static)> {
750 match self {
751 PublishError::IncompatibleSchema
752 | PublishError::InvalidSchema { .. }
753 | PublishError::Server { .. } => None,
754 PublishError::Transport(err) => Some(err),
755 }
756 }
757}
758
759impl fmt::Display for PublishError {
760 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
761 match self {
762 PublishError::IncompatibleSchema => write!(
765 f,
766 "schema being registered is incompatible with an earlier schema"
767 ),
768 PublishError::InvalidSchema { message } => write!(f, "{}", message),
769 PublishError::Transport(err) => write!(f, "transport: {}", err),
770 PublishError::Server { code, message } => {
771 write!(f, "server error {}: {}", code, message)
772 }
773 }
774 }
775}
776
777#[derive(Debug)]
779pub enum ListError {
780 Transport(reqwest::Error),
782 Server { code: i32, message: String },
784}
785
786impl From<UnhandledError> for ListError {
787 fn from(err: UnhandledError) -> ListError {
788 match err {
789 UnhandledError::Transport(err) => ListError::Transport(err),
790 UnhandledError::Api { code, message } => ListError::Server { code, message },
791 }
792 }
793}
794
795impl Error for ListError {
796 fn source(&self) -> Option<&(dyn Error + 'static)> {
797 match self {
798 ListError::Server { .. } => None,
799 ListError::Transport(err) => Some(err),
800 }
801 }
802}
803
804impl fmt::Display for ListError {
805 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
806 match self {
807 ListError::Transport(err) => write!(f, "transport: {}", err),
808 ListError::Server { code, message } => write!(f, "server error {}: {}", code, message),
809 }
810 }
811}
812
813#[derive(Debug)]
815pub enum DeleteError {
816 SubjectNotFound,
818 Transport(reqwest::Error),
820 Server { code: i32, message: String },
822}
823
824impl From<UnhandledError> for DeleteError {
825 fn from(err: UnhandledError) -> DeleteError {
826 match err {
827 UnhandledError::Transport(err) => DeleteError::Transport(err),
828 UnhandledError::Api { code, message } => match code {
829 40401 => DeleteError::SubjectNotFound,
830 _ => DeleteError::Server { code, message },
831 },
832 }
833 }
834}
835
836impl Error for DeleteError {
837 fn source(&self) -> Option<&(dyn Error + 'static)> {
838 match self {
839 DeleteError::SubjectNotFound | DeleteError::Server { .. } => None,
840 DeleteError::Transport(err) => Some(err),
841 }
842 }
843}
844
845impl fmt::Display for DeleteError {
846 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
847 match self {
848 DeleteError::SubjectNotFound => write!(f, "subject not found"),
849 DeleteError::Transport(err) => write!(f, "transport: {}", err),
850 DeleteError::Server { code, message } => {
851 write!(f, "server error {}: {}", code, message)
852 }
853 }
854 }
855}
856
857#[derive(Debug)]
859pub enum SetCompatibilityLevelError {
860 InvalidCompatibilityLevel,
862 Transport(reqwest::Error),
864 Server { code: i32, message: String },
866}
867
868impl From<UnhandledError> for SetCompatibilityLevelError {
869 fn from(err: UnhandledError) -> SetCompatibilityLevelError {
870 match err {
871 UnhandledError::Transport(err) => SetCompatibilityLevelError::Transport(err),
872 UnhandledError::Api { code, message } => match code {
873 42203 => SetCompatibilityLevelError::InvalidCompatibilityLevel,
874 _ => SetCompatibilityLevelError::Server { code, message },
875 },
876 }
877 }
878}
879
880impl Error for SetCompatibilityLevelError {
881 fn source(&self) -> Option<&(dyn Error + 'static)> {
882 match self {
883 SetCompatibilityLevelError::InvalidCompatibilityLevel
884 | SetCompatibilityLevelError::Server { .. } => None,
885 SetCompatibilityLevelError::Transport(err) => Some(err),
886 }
887 }
888}
889
890impl fmt::Display for SetCompatibilityLevelError {
891 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
892 match self {
893 SetCompatibilityLevelError::InvalidCompatibilityLevel => {
894 write!(f, "invalid compatibility level")
895 }
896 SetCompatibilityLevelError::Transport(err) => write!(f, "transport: {}", err),
897 SetCompatibilityLevelError::Server { code, message } => {
898 write!(f, "server error {}: {}", code, message)
899 }
900 }
901 }
902}
903
/// The JSON error object decoded from the body of a non-success response.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// Registry-specific error code (e.g. 40401), distinct from the HTTP status.
    error_code: i32,
    /// Human-readable error description.
    message: String,
}
909
910#[derive(Debug)]
911enum UnhandledError {
912 Transport(reqwest::Error),
913 Api { code: i32, message: String },
914}
915
916impl From<reqwest::Error> for UnhandledError {
917 fn from(err: reqwest::Error) -> UnhandledError {
918 UnhandledError::Transport(err)
919 }
920}
921
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_types)]
    use super::*;
    use std::collections::HashMap;

    /// Shorthand constructor for a [`SubjectVersion`].
    fn sv(subject: &str, version: i32) -> SubjectVersion {
        SubjectVersion {
            subject: subject.to_string(),
            version,
        }
    }

    /// Builds a reference graph from `(from, to)` edges.
    ///
    /// An edge whose `to` is `None` only ensures that `from` is a key of the
    /// graph, so leaf nodes can be listed explicitly.
    fn build_graph(
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) -> HashMap<SubjectVersion, Vec<SubjectVersion>> {
        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        for (from, to) in edges {
            let deps = graph.entry(from.clone()).or_default();
            if let Some(to) = to {
                deps.push(to.clone());
            }
        }

        graph
    }

    /// Asserts that, for every edge, the referencing node is positioned before
    /// the node it references.
    fn verify_edge_ordering(
        ordered: &HashMap<&SubjectVersion, i32>,
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) {
        for (from, to) in edges {
            if let Some(to) = to {
                let from_order = ordered.get(from).expect("from node should be in ordering");
                let to_order = ordered.get(to).expect("to node should be in ordering");

                assert!(
                    from_order < to_order,
                    "{:?} (order {}) references {:?} (order {}), so it must be ordered first",
                    from,
                    from_order,
                    to,
                    to_order
                );
            }
        }
    }

    /// Sorts the acyclic graph described by `edges`, then checks that the
    /// result contains `expected_len` nodes and respects every edge.
    fn assert_valid_sort(edges: &[(SubjectVersion, Option<SubjectVersion>)], expected_len: usize) {
        let graph = build_graph(edges);
        let ordered = topological_sort(&graph).expect("graph should be acyclic");
        assert_eq!(ordered.len(), expected_len);
        verify_edge_ordering(&ordered, edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_empty_graph() {
        let graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        let ordered = topological_sort(&graph).unwrap();

        assert!(ordered.is_empty());
    }

    #[mz_ore::test]
    fn test_topological_sort_single_node() {
        let a = sv("a", 1);
        let graph = build_graph(&[(a.clone(), None)]);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 1);
        assert!(ordered.contains_key(&a));
    }

    #[mz_ore::test]
    fn test_topological_sort_linear_chain() {
        // a -> b -> c -> d
        let (a, b, c, d) = (sv("a", 1), sv("b", 1), sv("c", 1), sv("d", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (b.clone(), Some(c.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];

        assert_valid_sort(&edges, 4);
    }

    #[mz_ore::test]
    fn test_topological_sort_diamond() {
        // a references b and c, which both reference d.
        let (a, b, c, d) = (sv("a", 1), sv("b", 1), sv("c", 1), sv("d", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 4);
        verify_edge_ordering(&ordered, &edges);
        // The root must also precede the leaf it reaches only indirectly.
        assert!(ordered[&a] < ordered[&d]);
    }

    #[mz_ore::test]
    fn test_topological_sort_wide_graph() {
        // a references five independent leaves.
        let a = sv("a", 1);
        let leaves: Vec<SubjectVersion> = ["b", "c", "d", "e", "f"]
            .iter()
            .map(|name| sv(name, 1))
            .collect();
        let mut edges: Vec<_> = leaves
            .iter()
            .map(|leaf| (a.clone(), Some(leaf.clone())))
            .collect();
        edges.extend(leaves.iter().map(|leaf| (leaf.clone(), None)));

        assert_valid_sort(&edges, 6);
    }

    #[mz_ore::test]
    fn test_topological_sort_complex_dag() {
        let (a, b, c, d) = (sv("a", 1), sv("b", 1), sv("c", 1), sv("d", 1));
        let (e, f, g, h) = (sv("e", 1), sv("f", 1), sv("g", 1), sv("h", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];

        assert_valid_sort(&edges, 8);
    }

    #[mz_ore::test]
    fn test_topological_sort_with_versions() {
        // Different versions of the same subject are distinct nodes.
        let a_v1 = sv("a", 1);
        let a_v2 = sv("a", 2);
        let b_v1 = sv("b", 1);
        let edges = vec![
            (a_v2.clone(), Some(a_v1.clone())),
            (a_v2.clone(), Some(b_v1.clone())),
            (a_v1.clone(), Some(b_v1.clone())),
            (b_v1.clone(), None),
        ];

        assert_valid_sort(&edges, 3);
    }

    #[mz_ore::test]
    fn test_topological_sort_multi_level_diamond() {
        let (a, b, c, d) = (sv("a", 1), sv("b", 1), sv("c", 1), sv("d", 1));
        let (e, f, g) = (sv("e", 1), sv("f", 1), sv("g", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (d.clone(), Some(f.clone())),
            (e.clone(), Some(g.clone())),
            (f.clone(), Some(g.clone())),
            (g.clone(), None),
        ];

        assert_valid_sort(&edges, 7);
    }

    #[mz_ore::test]
    fn test_topological_sort_shared_dependency_at_multiple_levels() {
        // e is referenced both directly by a and indirectly via b/c -> d.
        let (a, b, c) = (sv("a", 1), sv("b", 1), sv("c", 1));
        let (d, e, f) = (sv("d", 1), sv("e", 1), sv("f", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(e.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (e.clone(), Some(f.clone())),
            (f.clone(), None),
        ];

        assert_valid_sort(&edges, 6);
    }

    #[mz_ore::test]
    fn test_topological_sort_lattice_structure() {
        let (a, b, c, d) = (sv("a", 1), sv("b", 1), sv("c", 1), sv("d", 1));
        let (e, f, g, h) = (sv("e", 1), sv("f", 1), sv("g", 1), sv("h", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (b.clone(), Some(f.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];

        assert_valid_sort(&edges, 8);
    }

    #[mz_ore::test]
    fn test_topological_sort_binary_tree() {
        let (a, b, c, d) = (sv("a", 1), sv("b", 1), sv("c", 1), sv("d", 1));
        let (e, f, g) = (sv("e", 1), sv("f", 1), sv("g", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), None),
            (e.clone(), None),
            (f.clone(), None),
            (g.clone(), None),
        ];

        assert_valid_sort(&edges, 7);
    }

    #[mz_ore::test]
    fn test_topological_sort_duplicate_reference() {
        // Listing the same reference twice must not break the sort.
        let (a, b) = (sv("a", 1), sv("b", 1));
        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(b.clone())),
            (b.clone(), None),
        ];

        assert_valid_sort(&edges, 2);
    }

    #[mz_ore::test]
    fn test_topological_sort_reference_only_node() {
        // `b` is referenced but is not itself a key of the graph.
        let (a, b) = (sv("a", 1), sv("b", 1));
        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 2);
        assert!(ordered[&a] < ordered[&b]);
    }

    #[mz_ore::test]
    fn test_topological_sort_simple_cycle() {
        let (a, b) = (sv("a", 1), sv("b", 1));

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_cycle_with_entry_point() {
        // a is acyclic itself but leads into the b <-> c cycle.
        let (a, b, c) = (sv("a", 1), sv("b", 1), sv("c", 1));

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![b.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_self_reference() {
        let a = sv("a", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_three_node_cycle() {
        let (a, b, c) = (sv("a", 1), sv("b", 1), sv("c", 1));

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }
}