1use foreign_types::ForeignTypeRef;
53use std::marker::PhantomData;
54use std::ptr;
55
56use crate::error::ErrorStack;
57use crate::pkey::{HasPrivate, HasPublic, PKeyRef};
58use crate::{cvt, cvt_p};
59use openssl_macros::corresponds;
60
61pub struct Deriver<'a>(*mut ffi::EVP_PKEY_CTX, PhantomData<&'a ()>);
63
64unsafe impl Sync for Deriver<'_> {}
65unsafe impl Send for Deriver<'_> {}
66
67#[allow(clippy::len_without_is_empty)]
68impl<'a> Deriver<'a> {
69 #[corresponds(EVP_PKEY_derive_init)]
71 pub fn new<T>(key: &'a PKeyRef<T>) -> Result<Deriver<'a>, ErrorStack>
72 where
73 T: HasPrivate,
74 {
75 unsafe {
76 cvt_p(ffi::EVP_PKEY_CTX_new(key.as_ptr(), ptr::null_mut()))
77 .map(|p| Deriver(p, PhantomData))
78 .and_then(|ctx| cvt(ffi::EVP_PKEY_derive_init(ctx.0)).map(|_| ctx))
79 }
80 }
81
82 #[corresponds(EVP_PKEY_derive_set_peer)]
84 pub fn set_peer<T>(&mut self, key: &'a PKeyRef<T>) -> Result<(), ErrorStack>
85 where
86 T: HasPublic,
87 {
88 unsafe { cvt(ffi::EVP_PKEY_derive_set_peer(self.0, key.as_ptr())).map(|_| ()) }
89 }
90
91 #[corresponds(EVP_PKEY_derive_set_peer_ex)]
95 #[cfg(ossl300)]
96 pub fn set_peer_ex<T>(
97 &mut self,
98 key: &'a PKeyRef<T>,
99 validate_peer: bool,
100 ) -> Result<(), ErrorStack>
101 where
102 T: HasPublic,
103 {
104 unsafe {
105 cvt(ffi::EVP_PKEY_derive_set_peer_ex(
106 self.0,
107 key.as_ptr(),
108 validate_peer as i32,
109 ))
110 .map(|_| ())
111 }
112 }
113
114 #[corresponds(EVP_PKEY_derive)]
122 pub fn len(&mut self) -> Result<usize, ErrorStack> {
123 unsafe {
124 let mut len = 0;
125 cvt(ffi::EVP_PKEY_derive(self.0, ptr::null_mut(), &mut len)).map(|_| len)
126 }
127 }
128
129 #[corresponds(EVP_PKEY_derive)]
133 pub fn derive(&mut self, buf: &mut [u8]) -> Result<usize, ErrorStack> {
134 #[cfg(any(all(ossl110, not(ossl300)), libressl))]
141 {
142 let required = self.len()?;
143 if required != usize::MAX && buf.len() < required {
144 let mut temp = vec![0u8; required];
145 let mut len = required;
146 unsafe {
147 cvt(ffi::EVP_PKEY_derive(self.0, temp.as_mut_ptr(), &mut len))?;
148 }
149 let copy_len = buf.len().min(len);
150 buf[..copy_len].copy_from_slice(&temp[..copy_len]);
151 return Ok(copy_len);
152 }
153 }
154 let mut len = buf.len();
155 unsafe {
156 cvt(ffi::EVP_PKEY_derive(
157 self.0,
158 buf.as_mut_ptr() as *mut _,
159 &mut len,
160 ))
161 .map(|_| len)
162 }
163 }
164
165 pub fn derive_to_vec(&mut self) -> Result<Vec<u8>, ErrorStack> {
172 let len = self.len()?;
173 let mut buf = vec![0; len];
174 let len = self.derive(&mut buf)?;
175 buf.truncate(len);
176 Ok(buf)
177 }
178}
179
180impl Drop for Deriver<'_> {
181 fn drop(&mut self) {
182 unsafe {
183 ffi::EVP_PKEY_CTX_free(self.0);
184 }
185 }
186}
187
#[cfg(test)]
mod test {
    use super::*;

    use crate::ec::{EcGroup, EcKey};
    use crate::nid::Nid;
    use crate::pkey::PKey;

    /// Deriving before a peer key has been set is expected to fail.
    #[test]
    fn derive_without_peer() {
        let curve = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        let local = PKey::from_ec_key(EcKey::generate(&curve).unwrap()).unwrap();

        let mut deriver = Deriver::new(&local).unwrap();
        deriver.derive_to_vec().unwrap_err();
    }

    /// Two freshly generated P-256 keys yield a non-empty shared secret.
    #[test]
    fn test_ec_key_derive() {
        let curve = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        let fresh_key = || PKey::from_ec_key(EcKey::generate(&curve).unwrap()).unwrap();
        let local = fresh_key();
        let remote = fresh_key();

        let mut deriver = Deriver::new(&local).unwrap();
        deriver.set_peer(&remote).unwrap();

        let secret = deriver.derive_to_vec().unwrap();
        assert!(!secret.is_empty());
    }

    /// Passing a 4-byte buffer to an X25519 derivation.
    ///
    /// On OpenSSL 1.1.x and LibreSSL, `derive` is expected to report 4 bytes written;
    /// on OpenSSL 3.x the call is expected to return an error instead.
    #[test]
    #[cfg(any(ossl111, libressl370))]
    fn derive_undersized_buffer() {
        let local = PKey::generate_x25519().unwrap();
        let remote = PKey::generate_x25519().unwrap();

        let mut deriver = Deriver::new(&local).unwrap();
        deriver.set_peer(&remote).unwrap();

        let mut short_buf = [0u8; 4];
        let outcome = deriver.derive(&mut short_buf);

        #[cfg(any(all(ossl110, not(ossl300)), libressl))]
        assert_eq!(outcome.unwrap(), 4);
        #[cfg(all(ossl300, not(libressl)))]
        assert!(outcome.is_err());
    }

    /// Same as `test_ec_key_derive`, but sets the peer through `set_peer_ex` with
    /// `validate_peer` set to `true`.
    #[test]
    #[cfg(ossl300)]
    fn test_ec_key_derive_ex() {
        let curve = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        let fresh_key = || PKey::from_ec_key(EcKey::generate(&curve).unwrap()).unwrap();
        let local = fresh_key();
        let remote = fresh_key();

        let mut deriver = Deriver::new(&local).unwrap();
        deriver.set_peer_ex(&remote, true).unwrap();

        let secret = deriver.derive_to_vec().unwrap();
        assert!(!secret.is_empty());
    }
}