1use std::convert::TryFrom;
15use std::error::Error;
16use std::ffi::{CStr, CString};
17use std::mem::{self, ManuallyDrop};
18use std::net::{SocketAddr, ToSocketAddrs};
19use std::os::raw::{c_char, c_void};
20use std::ptr::{self, NonNull};
21use std::string::ToString;
22use std::sync::Arc;
23use std::{io, slice};
24
25use libc::addrinfo;
26use rdkafka_sys as rdsys;
27use rdkafka_sys::types::*;
28
29use crate::config::{ClientConfig, NativeClientConfig, RDKafkaLogLevel};
30use crate::consumer::RebalanceProtocol;
31use crate::error::{IsError, KafkaError, KafkaResult, RDKafkaError};
32use crate::groups::GroupList;
33use crate::log::{debug, error, info, trace, warn};
34use crate::metadata::Metadata;
35use crate::statistics::Statistics;
36use crate::util::{self, ErrBuf, KafkaDrop, NativePtr, Timeout};
37
38pub trait ClientContext: Send + Sync {
52 const ENABLE_REFRESH_OAUTH_TOKEN: bool = false;
61
62 fn log(&self, level: RDKafkaLogLevel, fac: &str, log_message: &str) {
70 match level {
71 RDKafkaLogLevel::Emerg
72 | RDKafkaLogLevel::Alert
73 | RDKafkaLogLevel::Critical
74 | RDKafkaLogLevel::Error => {
75 error!(target: "librdkafka", "librdkafka: {} {}", fac, log_message)
76 }
77 RDKafkaLogLevel::Warning => {
78 warn!(target: "librdkafka", "librdkafka: {} {}", fac, log_message)
79 }
80 RDKafkaLogLevel::Notice => {
81 info!(target: "librdkafka", "librdkafka: {} {}", fac, log_message)
82 }
83 RDKafkaLogLevel::Info => {
84 info!(target: "librdkafka", "librdkafka: {} {}", fac, log_message)
85 }
86 RDKafkaLogLevel::Debug => {
87 debug!(target: "librdkafka", "librdkafka: {} {}", fac, log_message)
88 }
89 }
90 }
91
92 fn stats(&self, statistics: Statistics) {
97 info!("Client stats: {:?}", statistics);
98 }
99
100 fn stats_raw(&self, statistics: &[u8]) {
107 match serde_json::from_slice(&statistics) {
108 Ok(stats) => self.stats(stats),
109 Err(e) => error!("Could not parse statistics JSON: {}", e),
110 }
111 }
112
113 fn error(&self, error: KafkaError, reason: &str) {
117 error!("librdkafka: {}: {}", error, reason);
118 }
119
120 fn resolve_broker_addr(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>, io::Error> {
127 (host, port).to_socket_addrs().map(|addrs| addrs.collect())
128 }
129
130 fn generate_oauth_token(
142 &self,
143 _oauthbearer_config: Option<&str>,
144 ) -> Result<OAuthToken, Box<dyn Error>> {
145 Err("Default implementation of generate_oauth_token must be overridden".into())
146 }
147
148 }
153
154#[derive(Clone, Debug, Default)]
159pub struct DefaultClientContext;
160
161impl ClientContext for DefaultClientContext {}
162
163pub struct NativeClient {
171 ptr: NonNull<RDKafka>,
172}
173
174unsafe impl Sync for NativeClient {}
176unsafe impl Send for NativeClient {}
177
178impl NativeClient {
179 pub(crate) unsafe fn from_ptr(ptr: *mut RDKafka) -> NativeClient {
181 NativeClient {
182 ptr: NonNull::new(ptr).unwrap(),
183 }
184 }
185
186 pub fn ptr(&self) -> *mut RDKafka {
188 self.ptr.as_ptr()
189 }
190
191 pub(crate) fn rebalance_protocol(&self) -> RebalanceProtocol {
192 let protocol = unsafe { rdsys::rd_kafka_rebalance_protocol(self.ptr()) };
193 if protocol.is_null() {
194 RebalanceProtocol::None
195 } else {
196 let protocol = unsafe { CStr::from_ptr(protocol) };
197 match protocol.to_bytes() {
198 b"NONE" => RebalanceProtocol::None,
199 b"EAGER" => RebalanceProtocol::Eager,
200 b"COOPERATIVE" => RebalanceProtocol::Cooperative,
201 _ => unreachable!(),
202 }
203 }
204 }
205}
206
207pub struct Client<C: ClientContext + 'static = DefaultClientContext> {
219 native: NativeClient,
220 context: Arc<C>,
221}
222
223impl<C: ClientContext + 'static> Client<C> {
224 pub fn new(
226 config: &ClientConfig,
227 native_config: NativeClientConfig,
228 rd_kafka_type: RDKafkaType,
229 context: C,
230 ) -> KafkaResult<Client<C>> {
231 let mut err_buf = ErrBuf::new();
232 let context = Arc::new(context);
233 unsafe {
234 rdsys::rd_kafka_conf_set_opaque(
235 native_config.ptr(),
236 Arc::as_ptr(&context) as *mut c_void,
237 )
238 };
239 unsafe {
240 rdsys::rd_kafka_conf_set_log_cb(native_config.ptr(), Some(native_log_cb::<C>));
241 rdsys::rd_kafka_conf_set_stats_cb(native_config.ptr(), Some(native_stats_cb::<C>));
242 rdsys::rd_kafka_conf_set_error_cb(native_config.ptr(), Some(native_error_cb::<C>));
243 rdsys::rd_kafka_conf_set_resolve_cb(native_config.ptr(), Some(native_resolve_cb::<C>));
244 }
245 if C::ENABLE_REFRESH_OAUTH_TOKEN {
246 unsafe {
247 rdsys::rd_kafka_conf_set_oauthbearer_token_refresh_cb(
248 native_config.ptr(),
249 Some(native_oauth_refresh_cb::<C>),
250 );
251 rdkafka_sys::rd_kafka_conf_enable_sasl_queue(native_config.ptr(), 1);
252 };
253 }
254
255 let client_ptr = unsafe {
256 let native_config = ManuallyDrop::new(native_config);
257 rdsys::rd_kafka_new(
258 rd_kafka_type,
259 native_config.ptr(),
260 err_buf.as_mut_ptr(),
261 err_buf.capacity(),
262 )
263 };
264 trace!("Create new librdkafka client {:p}", client_ptr);
265
266 if client_ptr.is_null() {
267 return Err(KafkaError::ClientCreation(err_buf.to_string()));
268 }
269
270 unsafe { rdsys::rd_kafka_set_log_level(client_ptr, config.log_level as i32) };
271
272 let sasl_mechanism = config
273 .get("sasl.mechanisms")
274 .or_else(|| config.get("sasl.mechanism"));
275 if sasl_mechanism.map_or(false, |m| m.eq_ignore_ascii_case("OAUTHBEARER")) {
276 let ret = unsafe {
277 RDKafkaError::from_ptr(rdsys::rd_kafka_sasl_background_callbacks_enable(client_ptr))
278 };
279 if ret.is_error() {
280 return Err(KafkaError::OAuthConfig(ret));
281 }
282 }
283
284 Ok(Client {
285 native: unsafe { NativeClient::from_ptr(client_ptr) },
286 context,
287 })
288 }
289
290 pub fn native_client(&self) -> &NativeClient {
292 &self.native
293 }
294
295 pub fn native_ptr(&self) -> *mut RDKafka {
297 self.native.ptr()
298 }
299
300 pub fn context(&self) -> &Arc<C> {
302 &self.context
303 }
304
305 pub fn fetch_metadata<T: Into<Timeout>>(
308 &self,
309 topic: Option<&str>,
310 timeout: T,
311 ) -> KafkaResult<Metadata> {
312 let mut metadata_ptr: *const RDKafkaMetadata = ptr::null_mut();
313 let (flag, native_topic) = if let Some(topic_name) = topic {
314 (0, Some(self.native_topic(topic_name)?))
315 } else {
316 (1, None)
317 };
318 trace!("Starting metadata fetch");
319 let ret = unsafe {
320 rdsys::rd_kafka_metadata(
321 self.native_ptr(),
322 flag,
323 native_topic.map(|t| t.ptr()).unwrap_or_else(ptr::null_mut),
324 &mut metadata_ptr as *mut *const RDKafkaMetadata,
325 timeout.into().as_millis(),
326 )
327 };
328 trace!("Metadata fetch completed");
329 if ret.is_error() {
330 return Err(KafkaError::MetadataFetch(ret.into()));
331 }
332
333 Ok(unsafe { Metadata::from_ptr(metadata_ptr) })
334 }
335
336 pub fn fetch_watermarks<T: Into<Timeout>>(
338 &self,
339 topic: &str,
340 partition: i32,
341 timeout: T,
342 ) -> KafkaResult<(i64, i64)> {
343 let mut low = -1;
344 let mut high = -1;
345 let topic_c = CString::new(topic.to_string())?;
346 let ret = unsafe {
347 rdsys::rd_kafka_query_watermark_offsets(
348 self.native_ptr(),
349 topic_c.as_ptr(),
350 partition,
351 &mut low as *mut i64,
352 &mut high as *mut i64,
353 timeout.into().as_millis(),
354 )
355 };
356 if ret.is_error() {
357 return Err(KafkaError::MetadataFetch(ret.into()));
358 }
359 Ok((low, high))
360 }
361
362 pub fn fetch_cluster_id<T: Into<Timeout>>(&self, timeout: T) -> Option<String> {
364 let cluster_id =
365 unsafe { rdsys::rd_kafka_clusterid(self.native_ptr(), timeout.into().as_millis()) };
366 if cluster_id.is_null() {
367 return None;
368 }
369 let buf = unsafe { CStr::from_ptr(cluster_id).to_bytes() };
370 String::from_utf8(buf.to_vec()).ok()
371 }
372
373 pub fn fetch_group_list<T: Into<Timeout>>(
376 &self,
377 group: Option<&str>,
378 timeout: T,
379 ) -> KafkaResult<GroupList> {
380 let group_c = CString::new(group.map_or("".to_string(), ToString::to_string))?;
382 let group_c_ptr = if group.is_some() {
383 group_c.as_ptr()
384 } else {
385 ptr::null_mut()
386 };
387 let mut group_list_ptr: *const RDKafkaGroupList = ptr::null_mut();
388 trace!("Starting group list fetch");
389 let ret = unsafe {
390 rdsys::rd_kafka_list_groups(
391 self.native_ptr(),
392 group_c_ptr,
393 &mut group_list_ptr as *mut *const RDKafkaGroupList,
394 timeout.into().as_millis(),
395 )
396 };
397 trace!("Group list fetch completed");
398 if ret.is_error() {
399 return Err(KafkaError::GroupListFetch(ret.into()));
400 }
401
402 Ok(unsafe { GroupList::from_ptr(group_list_ptr) })
403 }
404
405 pub fn fatal_error(&self) -> Option<(RDKafkaErrorCode, String)> {
411 let mut err_buf = ErrBuf::new();
412 let code = unsafe {
413 rdsys::rd_kafka_fatal_error(self.native_ptr(), err_buf.as_mut_ptr(), err_buf.capacity())
414 };
415 if code == RDKafkaRespErr::RD_KAFKA_RESP_ERR_NO_ERROR {
416 None
417 } else {
418 Some((code.into(), err_buf.to_string()))
419 }
420 }
421
422 pub(crate) fn native_topic(&self, topic: &str) -> KafkaResult<NativeTopic> {
425 let topic_c = CString::new(topic.to_string())?;
426 Ok(unsafe {
427 NativeTopic::from_ptr(rdsys::rd_kafka_topic_new(
428 self.native_ptr(),
429 topic_c.as_ptr(),
430 ptr::null_mut(),
431 ))
432 .unwrap()
433 })
434 }
435
436 pub(crate) fn new_native_queue(&self) -> NativeQueue {
439 unsafe { NativeQueue::from_ptr(rdsys::rd_kafka_queue_new(self.native_ptr())).unwrap() }
440 }
441
442 pub(crate) fn consumer_queue(&self) -> Option<NativeQueue> {
443 unsafe { NativeQueue::from_ptr(rdsys::rd_kafka_queue_get_consumer(self.native_ptr())) }
444 }
445}
446
447impl<C: ClientContext + 'static> Drop for Client<C> {
448 fn drop(&mut self) {
449 let context = Arc::clone(&self.context);
457 let ptr = self.native_ptr() as usize;
458 std::thread::spawn(move || {
459 unsafe { rdsys::rd_kafka_destroy(ptr as *mut RDKafka) }
460 drop(context);
464 });
465 }
466}
467
468pub(crate) type NativeTopic = NativePtr<RDKafkaTopic>;
469
470unsafe impl KafkaDrop for RDKafkaTopic {
471 const TYPE: &'static str = "native topic";
472 const DROP: unsafe extern "C" fn(*mut Self) = rdsys::rd_kafka_topic_destroy;
473}
474
475unsafe impl Send for NativeTopic {}
476unsafe impl Sync for NativeTopic {}
477
478pub(crate) type NativeQueue = NativePtr<RDKafkaQueue>;
479
480unsafe impl KafkaDrop for RDKafkaQueue {
481 const TYPE: &'static str = "queue";
482 const DROP: unsafe extern "C" fn(*mut Self) = rdsys::rd_kafka_queue_destroy;
483}
484
485unsafe impl Sync for NativeQueue {}
487unsafe impl Send for NativeQueue {}
488
489impl NativeQueue {
490 pub fn poll<T: Into<Timeout>>(&self, t: T) -> *mut RDKafkaEvent {
491 unsafe { rdsys::rd_kafka_queue_poll(self.ptr(), t.into().as_millis()) }
492 }
493}
494
495pub(crate) unsafe extern "C" fn native_log_cb<C: ClientContext>(
496 client: *const RDKafka,
497 level: i32,
498 fac: *const c_char,
499 buf: *const c_char,
500) {
501 let fac = CStr::from_ptr(fac).to_string_lossy();
502 let log_message = CStr::from_ptr(buf).to_string_lossy();
503
504 let context = &mut *(rdsys::rd_kafka_opaque(client) as *mut C);
505 context.log(
506 RDKafkaLogLevel::from_int(level),
507 fac.trim(),
508 log_message.trim(),
509 );
510}
511
512pub(crate) unsafe extern "C" fn native_stats_cb<C: ClientContext>(
513 _conf: *mut RDKafka,
514 json: *mut c_char,
515 json_len: usize,
516 opaque: *mut c_void,
517) -> i32 {
518 let context = &mut *(opaque as *mut C);
519 context.stats_raw(slice::from_raw_parts(json as *mut u8, json_len));
520 0 }
522
523pub(crate) unsafe extern "C" fn native_error_cb<C: ClientContext>(
524 _client: *mut RDKafka,
525 err: i32,
526 reason: *const c_char,
527 opaque: *mut c_void,
528) {
529 let err = RDKafkaRespErr::try_from(err).expect("global error not an rd_kafka_resp_err_t");
530 let error = KafkaError::Global(err.into());
531 let reason = CStr::from_ptr(reason).to_string_lossy();
532
533 let context = &mut *(opaque as *mut C);
534 context.error(error, reason.trim());
535}
536
537pub(crate) unsafe extern "C" fn native_resolve_cb<C: ClientContext>(
538 node: *const c_char,
539 service: *const c_char,
540 hints: *const addrinfo,
541 res: *mut *mut addrinfo,
542 opaque: *mut c_void,
543) -> i32 {
544 if node.is_null() {
545 assert!(service.is_null());
548 assert!(hints.is_null());
549 libc::free(*res as *mut libc::c_void);
550 return 0; }
552
553 let Ok(host) = CStr::from_ptr(node).to_str() else {
555 return libc::EAI_FAIL;
556 };
557 let Ok(port) = CStr::from_ptr(service).to_str() else {
558 return libc::EAI_FAIL;
559 };
560 let Ok(port) = port.parse() else {
561 return libc::EAI_SERVICE;
562 };
563
564 debug!("resolving {host}:{port}");
565
566 let context = &mut *(opaque as *mut C);
568 match context.resolve_broker_addr(host, port) {
569 Ok(addrs) => {
570 debug!("dns resolution succeeded for {host}:{port}: {addrs:?}");
571
572 #[repr(C)]
592 union CSocketAddr {
593 in4: libc::sockaddr_in,
594 in6: libc::sockaddr_in6,
595 }
596
597 #[repr(C)]
598 struct AddrInfoBuf {
599 addr_info: libc::addrinfo,
600 socket_addr: CSocketAddr,
601 }
602
603 let out = libc::calloc(addrs.len(), mem::size_of::<AddrInfoBuf>());
604 let out = out as *mut AddrInfoBuf;
605
606 for (i, addr) in addrs.iter().enumerate() {
607 let ptr = out.add(i);
608 (*ptr).addr_info = libc::addrinfo {
609 ai_addr: &mut (*ptr).socket_addr as *mut _ as *mut libc::sockaddr,
610 ai_addrlen: match addr {
611 SocketAddr::V4(_) => mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
612 SocketAddr::V6(_) => {
613 mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t
614 }
615 },
616 ai_canonname: ptr::null_mut(),
617 ai_family: match addr {
618 SocketAddr::V4(_) => libc::AF_INET,
619 SocketAddr::V6(_) => libc::AF_INET6,
620 },
621 ai_flags: 0,
622 ai_protocol: libc::IPPROTO_TCP,
623 ai_socktype: libc::SOCK_STREAM,
624 ai_next: if i < (addrs.len() - 1) {
625 out.add(i + 1) as *mut libc::addrinfo
626 } else {
627 ptr::null_mut()
628 },
629 };
630 match addr {
631 SocketAddr::V4(addr) => {
632 (*ptr).socket_addr.in4.sin_family = libc::AF_INET as libc::sa_family_t;
633 (*ptr).socket_addr.in4.sin_port = addr.port().to_be();
634 (*ptr).socket_addr.in4.sin_addr.s_addr =
635 u32::from_ne_bytes(addr.ip().octets());
636 }
637 SocketAddr::V6(addr) => {
638 (*ptr).socket_addr.in6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
639 (*ptr).socket_addr.in6.sin6_port = addr.port().to_be();
640 (*ptr).socket_addr.in6.sin6_addr.s6_addr = addr.ip().octets();
641 }
642 };
643 }
644
645 *res = out as *mut libc::addrinfo;
646
647 0
648 }
649 Err(e) => {
650 debug!("dns resolution failed for {host}:{port}: {e}");
651
652 let message = e.to_string();
657 for code in [libc::EAI_NODATA, libc::EAI_NONAME, libc::EAI_AGAIN] {
658 if let Ok(code_str) = CStr::from_ptr(libc::gai_strerror(code)).to_str() {
659 if message.ends_with(code_str) {
660 return code;
661 }
662 }
663 }
664 libc::EAI_FAIL
665 }
666 }
667}
668
/// A SASL/OAUTHBEARER token produced by `ClientContext::generate_oauth_token`
/// and handed to librdkafka via `rd_kafka_oauthbearer_set_token`.
pub struct OAuthToken {
    /// The token value.
    pub token: String,
    /// The principal name associated with the token.
    pub principal_name: String,
    /// Passed as `md_lifetime_ms`. librdkafka documents this as the token's
    /// expiry time in milliseconds since the Unix epoch, not as a duration.
    pub lifetime_ms: i64,
}
684
685pub(crate) unsafe extern "C" fn native_oauth_refresh_cb<C: ClientContext>(
686 client: *mut RDKafka,
687 oauthbearer_config: *const c_char,
688 opaque: *mut c_void,
689) {
690 let res: Result<_, Box<dyn Error>> = (|| {
691 let context = &mut *(opaque as *mut C);
692 let oauthbearer_config = match oauthbearer_config.is_null() {
693 true => None,
694 false => Some(util::cstr_to_owned(oauthbearer_config)),
695 };
696 let token_info = context.generate_oauth_token(oauthbearer_config.as_deref())?;
697 let token = CString::new(token_info.token)?;
698 let principal_name = CString::new(token_info.principal_name)?;
699 Ok((token, principal_name, token_info.lifetime_ms))
700 })();
701 match res {
702 Ok((token, principal_name, lifetime_ms)) => {
703 let mut err_buf = ErrBuf::new();
704 let code = rdkafka_sys::rd_kafka_oauthbearer_set_token(
705 client,
706 token.as_ptr(),
707 lifetime_ms,
708 principal_name.as_ptr(),
709 ptr::null_mut(),
710 0,
711 err_buf.as_mut_ptr(),
712 err_buf.capacity(),
713 );
714 if code == RDKafkaRespErr::RD_KAFKA_RESP_ERR_NO_ERROR {
715 debug!("successfully set refreshed OAuth token");
716 } else {
717 debug!(
718 "failed to set refreshed OAuth token (code {:?}): {}",
719 code, err_buf
720 );
721 rdkafka_sys::rd_kafka_oauthbearer_set_token_failure(client, err_buf.as_mut_ptr());
722 }
723 }
724 Err(e) => {
725 debug!("failed to refresh OAuth token: {}", e);
726 let message = match CString::new(e.to_string()) {
727 Ok(message) => message,
728 Err(e) => {
729 error!("error message generated while refreshing OAuth token has embedded null character: {}", e);
730 CString::new("error while refreshing OAuth token has embedded null character")
731 .expect("known to be a valid CString")
732 }
733 };
734 rdkafka_sys::rd_kafka_oauthbearer_set_token_failure(client, message.as_ptr());
735 }
736 }
737}
738
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ClientConfig;

    /// A producer built from a default `ClientConfig` yields a non-null
    /// librdkafka handle.
    #[test]
    fn test_client() {
        let config = ClientConfig::new();
        let native_config = config.create_native_config().unwrap();
        let client = Client::<DefaultClientContext>::new(
            &config,
            native_config,
            RDKafkaType::RD_KAFKA_PRODUCER,
            DefaultClientContext,
        )
        .unwrap();
        let handle = client.native_ptr();
        assert!(!handle.is_null());
    }
}