1use std::error::Error;
11use std::fmt;
12use std::hash::Hash;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{anyhow, bail};
17use proptest_derive::Arbitrary;
18use reqwest::{Method, Response, Url};
19use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21
22use crate::config::Auth;
23
/// An API client for a Confluent-compatible schema registry (CCSR).
#[derive(Clone)]
pub struct Client {
    /// The HTTP client used to send every request.
    inner: reqwest::Client,
    /// Produces the registry's base URL. It is invoked once at construction
    /// (for validation) and again each time a request is built.
    url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
    /// Optional credentials, attached to each request as HTTP basic auth.
    auth: Option<Auth>,
    /// The timeout this client was configured with, reported by `timeout()`.
    // NOTE(review): nothing in this type applies it to individual requests;
    // presumably it is already configured on `inner` — confirm at the call site.
    timeout: Duration,
}
42
43impl Client {
44 pub(crate) fn new(
45 inner: reqwest::Client,
46 url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
47 auth: Option<Auth>,
48 timeout: Duration,
49 ) -> Result<Self, anyhow::Error> {
50 if url().cannot_be_a_base() {
51 bail!("cannot construct a CCSR client with a cannot-be-a-base URL");
52 }
53 Ok(Client {
54 inner,
55 url,
56 auth,
57 timeout,
58 })
59 }
60
61 fn make_request<P>(&self, method: Method, path: P) -> reqwest::RequestBuilder
62 where
63 P: IntoIterator,
64 P::Item: AsRef<str>,
65 {
66 let mut url = (self.url)();
67 url.path_segments_mut()
68 .expect("constructor validated URL can be a base")
69 .clear()
70 .extend(path);
71
72 let mut request = self.inner.request(method, url);
73 if let Some(auth) = &self.auth {
74 request = request.basic_auth(&auth.username, auth.password.as_ref());
75 }
76 request
77 }
78
79 pub fn timeout(&self) -> Duration {
80 self.timeout
81 }
82
83 pub async fn get_schema_by_id(&self, id: i32) -> Result<Schema, GetByIdError> {
85 let req = self.make_request(Method::GET, &["schemas", "ids", &id.to_string()]);
86 let res: GetByIdResponse = send_request(req).await?;
87 Ok(Schema {
88 id,
89 raw: res.schema,
90 })
91 }
92
93 pub async fn get_schema_by_subject(&self, subject: &str) -> Result<Schema, GetBySubjectError> {
95 self.get_subject_latest(subject).await.map(|s| s.schema)
96 }
97
98 pub async fn get_subject_latest(&self, subject: &str) -> Result<Subject, GetBySubjectError> {
100 let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
101 let res: GetBySubjectResponse = send_request(req).await?;
102 Ok(Subject {
103 schema: Schema {
104 id: res.id,
105 raw: res.schema,
106 },
107 version: res.version,
108 name: res.subject,
109 })
110 }
111
112 pub async fn get_subject_with_references(
115 &self,
116 subject: &str,
117 ) -> Result<(Subject, Vec<SubjectVersion>), GetBySubjectError> {
118 let req = self.make_request(Method::GET, &["subjects", subject, "versions", "latest"]);
119 let res: GetBySubjectResponse = send_request(req).await?;
120 let referenced_subjects: Vec<_> = res
121 .references
122 .into_iter()
123 .map(|r| SubjectVersion {
124 subject: r.subject,
125 version: r.version,
126 })
127 .collect();
128 Ok((
129 Subject {
130 schema: Schema {
131 id: res.id,
132 raw: res.schema,
133 },
134 version: res.version,
135 name: res.subject,
136 },
137 referenced_subjects,
138 ))
139 }
140
141 pub async fn get_subject_config(
143 &self,
144 subject: &str,
145 ) -> Result<SubjectConfig, GetSubjectConfigError> {
146 let req = self.make_request(Method::GET, &["config", subject]);
147 let res: SubjectConfig = send_request(req).await?;
148 Ok(res)
149 }
150
151 pub async fn get_subject_and_references(
156 &self,
157 subject: &str,
158 ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
159 self.get_subject_and_references_by_version(subject, "latest".to_owned())
160 .await
161 }
162
163 #[allow(clippy::disallowed_types)]
167 async fn get_subject_and_references_by_version(
168 &self,
169 subject: &str,
170 version: String,
171 ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
172 let mut subjects = vec![];
173 let mut graph = std::collections::HashMap::new();
175 let mut subjects_queue = vec![(subject.to_owned(), version)];
176 while let Some((subject, version)) = subjects_queue.pop() {
177 let req = self.make_request(Method::GET, &["subjects", &subject, "versions", &version]);
178 let res: GetBySubjectResponse = send_request(req).await?;
179 subjects.push(Subject {
180 schema: Schema {
181 id: res.id,
182 raw: res.schema,
183 },
184 version: res.version,
185 name: res.subject.clone(),
186 });
187 let subject_key = SubjectVersion {
188 subject: res.subject,
189 version: res.version,
190 };
191
192 let dependents: Vec<_> = res
193 .references
194 .into_iter()
195 .map(|reference| SubjectVersion {
196 subject: reference.subject,
197 version: reference.version,
198 })
199 .collect();
200
201 graph
202 .entry(subject_key)
203 .or_insert_with(Vec::new)
204 .extend(dependents.iter().cloned());
205 subjects_queue.extend(
206 dependents
207 .into_iter()
208 .filter(|dep| match graph.entry(dep.clone()) {
212 std::collections::hash_map::Entry::Occupied(_) => false,
213 std::collections::hash_map::Entry::Vacant(vacant) => {
214 vacant.insert(vec![]);
215 true
216 }
217 })
218 .map(|dep| (dep.subject, dep.version.to_string())),
219 );
220 }
221 assert!(subjects.len() > 0, "Request should error if no subjects");
222
223 let primary = subjects.remove(0);
224
225 let ordered =
226 topological_sort(&graph).map_err(|_| GetBySubjectError::SchemaReferenceCycle)?;
227
228 subjects.sort_by(|a, b| {
229 let a = SubjectVersion {
230 subject: a.name.clone(),
231 version: a.version,
232 };
233 let b = SubjectVersion {
234 subject: b.name.clone(),
235 version: b.version,
236 };
237 ordered
238 .get(&b)
239 .unwrap_or_else(|| panic!("b {b:?}"))
240 .cmp(ordered.get(&a).unwrap_or_else(|| panic!("a {a:?}")))
241 });
242
243 Ok((primary, subjects))
244 }
245
246 pub async fn publish_schema(
253 &self,
254 subject: &str,
255 schema: &str,
256 schema_type: SchemaType,
257 references: &[SchemaReference],
258 ) -> Result<i32, PublishError> {
259 let req = self.make_request(Method::POST, &["subjects", subject, "versions"]);
260 let req = req.json(&PublishRequest {
261 schema,
262 schema_type,
263 references,
264 });
265 let res: PublishResponse = send_request(req).await?;
266 Ok(res.id)
267 }
268
269 pub async fn set_subject_compatibility_level(
271 &self,
272 subject: &str,
273 compatibility_level: CompatibilityLevel,
274 ) -> Result<(), SetCompatibilityLevelError> {
275 let req = self.make_request(Method::PUT, &["config", subject]);
276 let req = req.json(&CompatibilityLevelRequest {
277 compatibility: compatibility_level,
278 });
279 send_request_raw(req).await?;
280 Ok(())
281 }
282
283 pub async fn list_subjects(&self) -> Result<Vec<String>, ListError> {
285 let req = self.make_request(Method::GET, &["subjects"]);
286 Ok(send_request(req).await?)
287 }
288
289 pub async fn delete_subject(&self, subject: &str) -> Result<(), DeleteError> {
296 let req = self.make_request(Method::DELETE, &["subjects", subject]);
297 send_request_raw(req).await?;
298 Ok(())
299 }
300
301 pub async fn get_subject_and_references_by_id(
306 &self,
307 id: i32,
308 ) -> Result<(Subject, Vec<Subject>), GetBySubjectError> {
309 let req = self.make_request(
310 Method::GET,
311 &["schemas", "ids", &id.to_string(), "versions"],
312 );
313 let res: Vec<SubjectVersion> = send_request(req).await?;
314
315 match res.as_slice() {
327 [first, ..] => {
328 self.get_subject_and_references_by_version(
329 &first.subject,
330 first.version.to_string(),
331 )
332 .await
333 }
334 _ => Err(GetBySubjectError::SubjectNotFound),
335 }
336 }
337}
338
339#[allow(clippy::disallowed_types)]
342pub fn topological_sort<T: Hash + Eq>(
343 graph: &std::collections::HashMap<T, Vec<T>>,
344) -> Result<std::collections::HashMap<&T, i32>, anyhow::Error> {
345 let mut referenced_by: std::collections::HashMap<&T, std::collections::HashSet<&T>> =
346 std::collections::HashMap::new();
347 for (subject, references) in graph.iter() {
348 for reference in references {
349 referenced_by.entry(reference).or_default().insert(subject);
350 }
351 }
352
353 let mut queue: Vec<_> = graph
356 .keys()
357 .filter(|key| {
358 referenced_by
359 .get(*key)
360 .map_or(true, |subjects| subjects.is_empty())
361 })
362 .collect();
363
364 let mut ordered = std::collections::HashMap::new();
365 let mut n = 0;
366 while let Some(subj_ver) = queue.pop() {
367 if let Some(refs) = graph.get(subj_ver) {
368 for ref_ver in refs {
369 let Some(subjects) = referenced_by.get_mut(ref_ver) else {
370 continue;
371 };
372 subjects.remove(&subj_ver);
373 if subjects.is_empty() {
374 referenced_by.remove_entry(ref_ver);
375 queue.push(ref_ver);
376 }
377 }
378 }
379 ordered.insert(subj_ver, n);
380 n += 1;
381 }
382
383 if referenced_by.is_empty() {
384 Ok(ordered)
385 } else {
386 Err(anyhow!("Cycled detected during topoligical sort"))
387 }
388}
389
390async fn send_request<T>(req: reqwest::RequestBuilder) -> Result<T, UnhandledError>
391where
392 T: DeserializeOwned,
393{
394 let res = send_request_raw(req).await?;
395 Ok(res.json().await?)
396}
397
398async fn send_request_raw(req: reqwest::RequestBuilder) -> Result<Response, UnhandledError> {
399 let res = req.send().await?;
400 let status = res.status();
401 if status.is_success() {
402 Ok(res)
403 } else {
404 match res.json::<ErrorResponse>().await {
405 Ok(err_res) => Err(UnhandledError::Api {
406 code: err_res.error_code,
407 message: err_res.message,
408 }),
409 Err(_) => Err(UnhandledError::Api {
410 code: i32::from(status.as_u16()),
411 message: "unable to decode error details".into(),
412 }),
413 }
414 }
415}
416
/// The format of a schema registered with the registry.
///
/// Serialized in upper case (`"AVRO"`, `"PROTOBUF"`, `"JSON"`).
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum SchemaType {
    /// An Avro schema.
    Avro,
    /// A Protobuf schema.
    Protobuf,
    /// A JSON schema.
    Json,
}
428
429impl SchemaType {
430 fn is_default(&self) -> bool {
431 matches!(self, SchemaType::Avro)
432 }
433}
434
/// A schema fetched from the registry.
#[derive(Debug, Eq, PartialEq)]
pub struct Schema {
    /// The schema's global ID in the registry.
    pub id: i32,
    /// The schema text, exactly as returned by the registry.
    pub raw: String,
}
443
/// A specific version of a subject, together with its schema.
#[derive(Debug, Eq, PartialEq)]
pub struct Subject {
    /// The version number of this schema under the subject.
    pub version: i32,
    /// The subject name.
    pub name: String,
    /// The schema registered at this version.
    pub schema: Schema,
}
454
/// A reference from one schema to a specific version of another subject.
///
/// Deserialized from registry responses and sent when publishing schemas.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SchemaReference {
    /// The name under which the referencing schema refers to the referenced one.
    // NOTE(review): exact meaning is format-specific per the registry API — confirm.
    pub name: String,
    /// The referenced subject.
    pub subject: String,
    /// The referenced version of `subject`.
    pub version: i32,
}
466
/// Response body of `GET /schemas/ids/{id}`; only the schema text is read.
#[derive(Debug, Deserialize)]
struct GetByIdResponse {
    schema: String,
}
471
/// Errors returned by `Client::get_schema_by_id`.
#[derive(Debug)]
pub enum GetByIdError {
    /// The registry reported that no schema has the requested ID.
    SchemaNotFound,
    /// The request could not be sent, or a response body could not be read or decoded.
    Transport(reqwest::Error),
    /// Any other error reported by the registry.
    Server { code: i32, message: String },
}
482
/// Identifies one registered version of a subject.
///
/// Deserialized from `GET /schemas/ids/{id}/versions` responses, and used as
/// the node type of the reference graph. The derived ordering compares
/// `subject` first, then `version`.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SubjectVersion {
    /// The subject name.
    pub subject: String,
    /// The version number within the subject.
    pub version: i32,
}
490
491impl From<UnhandledError> for GetByIdError {
492 fn from(err: UnhandledError) -> GetByIdError {
493 match err {
494 UnhandledError::Transport(err) => GetByIdError::Transport(err),
495 UnhandledError::Api { code, message } => match code {
496 40403 => GetByIdError::SchemaNotFound,
497 _ => GetByIdError::Server { code, message },
498 },
499 }
500 }
501}
502
503impl Error for GetByIdError {
504 fn source(&self) -> Option<&(dyn Error + 'static)> {
505 match self {
506 GetByIdError::SchemaNotFound | GetByIdError::Server { .. } => None,
507 GetByIdError::Transport(err) => Some(err),
508 }
509 }
510}
511
512impl fmt::Display for GetByIdError {
513 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
514 match self {
515 GetByIdError::SchemaNotFound => write!(f, "schema not found"),
516 GetByIdError::Transport(err) => write!(f, "transport: {}", err),
517 GetByIdError::Server { code, message } => {
518 write!(f, "server error {}: {}", code, message)
519 }
520 }
521 }
522}
523
/// Response body of `GET /config/{subject}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SubjectConfig {
    /// The subject's compatibility level (JSON field `compatibilityLevel`).
    pub compatibility_level: CompatibilityLevel,
}
530
/// Errors returned by `Client::get_subject_config`.
#[derive(Debug)]
pub enum GetSubjectConfigError {
    /// The registry reported that the subject was not found.
    SubjectNotFound,
    /// The registry reported (error code 40408) that no subject-level
    /// compatibility level is set.
    SubjectCompatibilityLevelNotSet,
    /// The request could not be sent, or a response body could not be read or decoded.
    Transport(reqwest::Error),
    /// Any other error reported by the registry.
    Server { code: i32, message: String },
}
543
544impl From<UnhandledError> for GetSubjectConfigError {
545 fn from(err: UnhandledError) -> GetSubjectConfigError {
546 match err {
547 UnhandledError::Transport(err) => GetSubjectConfigError::Transport(err),
548 UnhandledError::Api { code, message } => match code {
549 404 => GetSubjectConfigError::SubjectNotFound,
550 40408 => GetSubjectConfigError::SubjectCompatibilityLevelNotSet,
551 _ => GetSubjectConfigError::Server { code, message },
552 },
553 }
554 }
555}
556
557impl Error for GetSubjectConfigError {
558 fn source(&self) -> Option<&(dyn Error + 'static)> {
559 match self {
560 GetSubjectConfigError::SubjectNotFound
561 | GetSubjectConfigError::SubjectCompatibilityLevelNotSet
562 | GetSubjectConfigError::Server { .. } => None,
563 GetSubjectConfigError::Transport(err) => Some(err),
564 }
565 }
566}
567
568impl fmt::Display for GetSubjectConfigError {
569 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
570 match self {
571 GetSubjectConfigError::SubjectNotFound => write!(f, "subject not found"),
572 GetSubjectConfigError::SubjectCompatibilityLevelNotSet => {
573 write!(f, "subject level compatibility not set")
574 }
575 GetSubjectConfigError::Transport(err) => write!(f, "transport: {}", err),
576 GetSubjectConfigError::Server { code, message } => {
577 write!(f, "server error {}: {}", code, message)
578 }
579 }
580 }
581}
582
/// Response body of `GET /subjects/{subject}/versions/{version}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GetBySubjectResponse {
    id: i32,
    schema: String,
    version: i32,
    subject: String,
    // Defaults to empty when the registry omits the field.
    #[serde(default)]
    references: Vec<SchemaReference>,
}
593
/// Errors returned by the subject-lookup methods of `Client`.
#[derive(Debug)]
pub enum GetBySubjectError {
    /// The registry reported (error code 40401) that the subject does not
    /// exist, or a schema ID listed no subject versions.
    SubjectNotFound,
    /// The registry reported (error code 40402) that the version does not
    /// exist; carries the registry's message.
    VersionNotFound(String),
    /// The request could not be sent, or a response body could not be read or decoded.
    Transport(reqwest::Error),
    /// Any other error reported by the registry.
    Server { code: i32, message: String },
    /// The schema references form a cycle.
    SchemaReferenceCycle,
}
608
609impl From<UnhandledError> for GetBySubjectError {
610 fn from(err: UnhandledError) -> GetBySubjectError {
611 match err {
612 UnhandledError::Transport(err) => GetBySubjectError::Transport(err),
613 UnhandledError::Api { code, message } => match code {
614 40401 => GetBySubjectError::SubjectNotFound,
615 40402 => GetBySubjectError::VersionNotFound(message),
616 _ => GetBySubjectError::Server { code, message },
617 },
618 }
619 }
620}
621
622impl Error for GetBySubjectError {
623 fn source(&self) -> Option<&(dyn Error + 'static)> {
624 match self {
625 GetBySubjectError::SubjectNotFound
626 | GetBySubjectError::VersionNotFound(_)
627 | GetBySubjectError::Server { .. }
628 | GetBySubjectError::SchemaReferenceCycle => None,
629 GetBySubjectError::Transport(err) => Some(err),
630 }
631 }
632}
633
634impl fmt::Display for GetBySubjectError {
635 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
636 match self {
637 GetBySubjectError::SubjectNotFound => write!(f, "subject not found"),
638 GetBySubjectError::VersionNotFound(message) => {
639 write!(f, "version not found: {}", message)
640 }
641 GetBySubjectError::Transport(err) => write!(f, "transport: {}", err),
642 GetBySubjectError::Server { code, message } => {
643 write!(f, "server error {}: {}", code, message)
644 }
645 GetBySubjectError::SchemaReferenceCycle => {
646 write!(f, "cycle detected in schema references")
647 }
648 }
649 }
650}
651
/// Request body of `POST /subjects/{subject}/versions`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct PublishRequest<'a> {
    schema: &'a str,
    // Omitted for Avro, the type `SchemaType::is_default` reports as default.
    #[serde(skip_serializing_if = "SchemaType::is_default")]
    schema_type: SchemaType,
    // Omitted entirely when there are no references.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    references: &'a [SchemaReference],
}
664
/// Response body of `POST /subjects/{subject}/versions`: the schema's ID.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PublishResponse {
    id: i32,
}
670
/// Request body of `PUT /config/{subject}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CompatibilityLevelRequest {
    compatibility: CompatibilityLevel,
}
676
/// A schema compatibility level understood by the registry.
///
/// Serialized in SCREAMING_SNAKE_CASE, e.g. `BACKWARD_TRANSITIVE`.
/// `Arbitrary` is derived for property-based tests.
#[derive(Arbitrary, Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CompatibilityLevel {
    Backward,
    BackwardTransitive,
    Forward,
    ForwardTransitive,
    Full,
    FullTransitive,
    None,
}
688
689impl TryFrom<&str> for CompatibilityLevel {
690 type Error = String;
691
692 fn try_from(value: &str) -> Result<Self, Self::Error> {
693 match value {
694 "BACKWARD" => Ok(CompatibilityLevel::Backward),
695 "BACKWARD_TRANSITIVE" => Ok(CompatibilityLevel::BackwardTransitive),
696 "FORWARD" => Ok(CompatibilityLevel::Forward),
697 "FORWARD_TRANSITIVE" => Ok(CompatibilityLevel::ForwardTransitive),
698 "FULL" => Ok(CompatibilityLevel::Full),
699 "FULL_TRANSITIVE" => Ok(CompatibilityLevel::FullTransitive),
700 "NONE" => Ok(CompatibilityLevel::None),
701 _ => Err(format!("invalid compatibility level: {}", value)),
702 }
703 }
704}
705
706impl fmt::Display for CompatibilityLevel {
707 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
708 match self {
709 CompatibilityLevel::Backward => write!(f, "BACKWARD"),
710 CompatibilityLevel::BackwardTransitive => write!(f, "BACKWARD_TRANSITIVE"),
711 CompatibilityLevel::Forward => write!(f, "FORWARD"),
712 CompatibilityLevel::ForwardTransitive => write!(f, "FORWARD_TRANSITIVE"),
713 CompatibilityLevel::Full => write!(f, "FULL"),
714 CompatibilityLevel::FullTransitive => write!(f, "FULL_TRANSITIVE"),
715 CompatibilityLevel::None => write!(f, "NONE"),
716 }
717 }
718}
719
/// Errors returned by `Client::publish_schema`.
#[derive(Debug)]
pub enum PublishError {
    /// The registry rejected the schema as incompatible with an earlier
    /// version (code 409 or 40901).
    IncompatibleSchema,
    /// The registry rejected the schema as invalid (code 42201); carries its message.
    InvalidSchema { message: String },
    /// The request could not be sent, or a response body could not be read or decoded.
    Transport(reqwest::Error),
    /// Any other error reported by the registry.
    Server { code: i32, message: String },
}
734
735impl From<UnhandledError> for PublishError {
736 fn from(err: UnhandledError) -> PublishError {
737 match err {
738 UnhandledError::Transport(err) => PublishError::Transport(err),
739 UnhandledError::Api { code, message } => match code {
740 409 | 40901 => PublishError::IncompatibleSchema,
743 42201 => PublishError::InvalidSchema { message },
744 _ => PublishError::Server { code, message },
745 },
746 }
747 }
748}
749
750impl Error for PublishError {
751 fn source(&self) -> Option<&(dyn Error + 'static)> {
752 match self {
753 PublishError::IncompatibleSchema
754 | PublishError::InvalidSchema { .. }
755 | PublishError::Server { .. } => None,
756 PublishError::Transport(err) => Some(err),
757 }
758 }
759}
760
761impl fmt::Display for PublishError {
762 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
763 match self {
764 PublishError::IncompatibleSchema => write!(
767 f,
768 "schema being registered is incompatible with an earlier schema"
769 ),
770 PublishError::InvalidSchema { message } => write!(f, "{}", message),
771 PublishError::Transport(err) => write!(f, "transport: {}", err),
772 PublishError::Server { code, message } => {
773 write!(f, "server error {}: {}", code, message)
774 }
775 }
776 }
777}
778
/// Errors returned by `Client::list_subjects`.
#[derive(Debug)]
pub enum ListError {
    /// The request could not be sent, or a response body could not be read or decoded.
    Transport(reqwest::Error),
    /// Any error reported by the registry.
    Server { code: i32, message: String },
}
787
788impl From<UnhandledError> for ListError {
789 fn from(err: UnhandledError) -> ListError {
790 match err {
791 UnhandledError::Transport(err) => ListError::Transport(err),
792 UnhandledError::Api { code, message } => ListError::Server { code, message },
793 }
794 }
795}
796
797impl Error for ListError {
798 fn source(&self) -> Option<&(dyn Error + 'static)> {
799 match self {
800 ListError::Server { .. } => None,
801 ListError::Transport(err) => Some(err),
802 }
803 }
804}
805
806impl fmt::Display for ListError {
807 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
808 match self {
809 ListError::Transport(err) => write!(f, "transport: {}", err),
810 ListError::Server { code, message } => write!(f, "server error {}: {}", code, message),
811 }
812 }
813}
814
/// Errors returned by `Client::delete_subject`.
#[derive(Debug)]
pub enum DeleteError {
    /// The registry reported (error code 40401) that the subject does not exist.
    SubjectNotFound,
    /// The request could not be sent.
    Transport(reqwest::Error),
    /// Any other error reported by the registry.
    Server { code: i32, message: String },
}
825
826impl From<UnhandledError> for DeleteError {
827 fn from(err: UnhandledError) -> DeleteError {
828 match err {
829 UnhandledError::Transport(err) => DeleteError::Transport(err),
830 UnhandledError::Api { code, message } => match code {
831 40401 => DeleteError::SubjectNotFound,
832 _ => DeleteError::Server { code, message },
833 },
834 }
835 }
836}
837
838impl Error for DeleteError {
839 fn source(&self) -> Option<&(dyn Error + 'static)> {
840 match self {
841 DeleteError::SubjectNotFound | DeleteError::Server { .. } => None,
842 DeleteError::Transport(err) => Some(err),
843 }
844 }
845}
846
847impl fmt::Display for DeleteError {
848 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
849 match self {
850 DeleteError::SubjectNotFound => write!(f, "subject not found"),
851 DeleteError::Transport(err) => write!(f, "transport: {}", err),
852 DeleteError::Server { code, message } => {
853 write!(f, "server error {}: {}", code, message)
854 }
855 }
856 }
857}
858
/// Errors returned by `Client::set_subject_compatibility_level`.
#[derive(Debug)]
pub enum SetCompatibilityLevelError {
    /// The registry rejected the compatibility level (error code 42203).
    InvalidCompatibilityLevel,
    /// The request could not be sent.
    Transport(reqwest::Error),
    /// Any other error reported by the registry.
    Server { code: i32, message: String },
}
869
870impl From<UnhandledError> for SetCompatibilityLevelError {
871 fn from(err: UnhandledError) -> SetCompatibilityLevelError {
872 match err {
873 UnhandledError::Transport(err) => SetCompatibilityLevelError::Transport(err),
874 UnhandledError::Api { code, message } => match code {
875 42203 => SetCompatibilityLevelError::InvalidCompatibilityLevel,
876 _ => SetCompatibilityLevelError::Server { code, message },
877 },
878 }
879 }
880}
881
882impl Error for SetCompatibilityLevelError {
883 fn source(&self) -> Option<&(dyn Error + 'static)> {
884 match self {
885 SetCompatibilityLevelError::InvalidCompatibilityLevel
886 | SetCompatibilityLevelError::Server { .. } => None,
887 SetCompatibilityLevelError::Transport(err) => Some(err),
888 }
889 }
890}
891
892impl fmt::Display for SetCompatibilityLevelError {
893 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
894 match self {
895 SetCompatibilityLevelError::InvalidCompatibilityLevel => {
896 write!(f, "invalid compatibility level")
897 }
898 SetCompatibilityLevelError::Transport(err) => write!(f, "transport: {}", err),
899 SetCompatibilityLevelError::Server { code, message } => {
900 write!(f, "server error {}: {}", code, message)
901 }
902 }
903 }
904}
905
/// JSON body the registry sends with a non-success status.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// Registry-specific error code, e.g. 40401.
    error_code: i32,
    /// Human-readable error description.
    message: String,
}
911
/// A request failure that has not yet been mapped to an operation-specific
/// error type. Each public error type converts from it via `From`.
#[derive(Debug)]
enum UnhandledError {
    /// Sending the request, or reading or decoding a body, failed.
    Transport(reqwest::Error),
    /// The registry returned a non-success status. `code` is the registry's
    /// error code, or the HTTP status if the error body was undecodable.
    Api { code: i32, message: String },
}
917
918impl From<reqwest::Error> for UnhandledError {
919 fn from(err: reqwest::Error) -> UnhandledError {
920 UnhandledError::Transport(err)
921 }
922}
923
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_types)]
    // Tests for `topological_sort`. Edges point from a referencing node to a
    // node it references, matching how the client builds its reference graph.
    use super::*;
    use std::collections::HashMap;

    /// Shorthand for building a `SubjectVersion`.
    fn sv(subject: &str, version: i32) -> SubjectVersion {
        SubjectVersion {
            subject: subject.to_string(),
            version,
        }
    }

    /// Builds an adjacency map from `(from, to)` edges. A `None` target only
    /// ensures that `from` is present as a key.
    fn build_graph(
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) -> HashMap<SubjectVersion, Vec<SubjectVersion>> {
        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        for (from, to) in edges {
            let deps = graph.entry(from.clone()).or_default();
            if let Some(to) = to {
                deps.push(to.clone());
            }
        }

        graph
    }

    /// Asserts that, for every edge `from -> to`, `from` is ordered before `to`.
    fn verify_edge_ordering(
        ordered: &HashMap<&SubjectVersion, i32>,
        edges: &[(SubjectVersion, Option<SubjectVersion>)],
    ) {
        for (from, to) in edges {
            let Some(to) = to else { continue };
            let from_order = ordered.get(from).expect("from node should be in ordering");
            let to_order = ordered.get(to).expect("to node should be in ordering");

            assert!(
                from_order < to_order,
                "{from:?} (order {from_order}) references {to:?} (order {to_order}), \
                 so it should be ordered first",
            );
        }
    }

    #[mz_ore::test]
    fn test_topological_sort_empty_graph() {
        let graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();

        let ordered = topological_sort(&graph).unwrap();

        assert!(ordered.is_empty());
    }

    #[mz_ore::test]
    fn test_topological_sort_single_node() {
        let a = sv("a", 1);
        let graph = build_graph(&[(a.clone(), None)]);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 1);
        assert!(ordered.contains_key(&a));
    }

    #[mz_ore::test]
    fn test_topological_sort_linear_chain() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (b.clone(), Some(c.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 4);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_diamond() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 4);
        verify_edge_ordering(&ordered, &edges);

        assert!(ordered[&a] < ordered[&d]);
    }

    #[mz_ore::test]
    fn test_topological_sort_wide_graph() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (a.clone(), Some(e.clone())),
            (a.clone(), Some(f.clone())),
            (b.clone(), None),
            (c.clone(), None),
            (d.clone(), None),
            (e.clone(), None),
            (f.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 6);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_complex_dag() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);
        let h = sv("h", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 8);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_with_versions() {
        let a_v1 = sv("a", 1);
        let a_v2 = sv("a", 2);
        let b_v1 = sv("b", 1);

        let edges = vec![
            (a_v2.clone(), Some(a_v1.clone())),
            (a_v2.clone(), Some(b_v1.clone())),
            (a_v1.clone(), Some(b_v1.clone())),
            (b_v1.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 3);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_multi_level_diamond() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (d.clone(), Some(f.clone())),
            (e.clone(), Some(g.clone())),
            (f.clone(), Some(g.clone())),
            (g.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 7);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_shared_dependency_at_multiple_levels() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(e.clone())),
            (b.clone(), Some(d.clone())),
            (c.clone(), Some(d.clone())),
            (d.clone(), Some(e.clone())),
            (e.clone(), Some(f.clone())),
            (f.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).expect("no cycle");

        assert_eq!(ordered.len(), 6);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_lattice_structure() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);
        let h = sv("h", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (a.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (b.clone(), Some(f.clone())),
            (c.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), Some(f.clone())),
            (d.clone(), Some(g.clone())),
            (e.clone(), Some(h.clone())),
            (f.clone(), Some(h.clone())),
            (g.clone(), Some(h.clone())),
            (h.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 8);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_binary_tree() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);
        let d = sv("d", 1);
        let e = sv("e", 1);
        let f = sv("f", 1);
        let g = sv("g", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(c.clone())),
            (b.clone(), Some(d.clone())),
            (b.clone(), Some(e.clone())),
            (c.clone(), Some(f.clone())),
            (c.clone(), Some(g.clone())),
            (d.clone(), None),
            (e.clone(), None),
            (f.clone(), None),
            (g.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 7);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_duplicate_reference() {
        // A node listing the same reference twice must not break the sort.
        let a = sv("a", 1);
        let b = sv("b", 1);

        let edges = vec![
            (a.clone(), Some(b.clone())),
            (a.clone(), Some(b.clone())),
            (b.clone(), None),
        ];
        let graph = build_graph(&edges);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 2);
        verify_edge_ordering(&ordered, &edges);
    }

    #[mz_ore::test]
    fn test_topological_sort_reference_only_node() {
        // `b` appears only as a reference, never as a key of the graph.
        let a = sv("a", 1);
        let b = sv("b", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);

        let ordered = topological_sort(&graph).unwrap();

        assert_eq!(ordered.len(), 2);
        assert!(ordered[&a] < ordered[&b]);
    }

    #[mz_ore::test]
    fn test_topological_sort_simple_cycle() {
        let a = sv("a", 1);
        let b = sv("b", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_cycle_with_entry_point() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![b.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_self_reference() {
        let a = sv("a", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }

    #[mz_ore::test]
    fn test_topological_sort_three_node_cycle() {
        let a = sv("a", 1);
        let b = sv("b", 1);
        let c = sv("c", 1);

        let mut graph: HashMap<SubjectVersion, Vec<SubjectVersion>> = HashMap::new();
        graph.insert(a.clone(), vec![b.clone()]);
        graph.insert(b.clone(), vec![c.clone()]);
        graph.insert(c.clone(), vec![a.clone()]);

        let sort_result = topological_sort(&graph);

        assert!(sort_result.is_err(), "Expected sort to detect cycle");
    }
}