1use std::{marker::PhantomData, ptr};
43
44use crate::error::ErrorStack;
45use crate::hash::MessageDigest;
46use crate::pkey::{HasPrivate, HasPublic, PKeyRef};
47use crate::rsa::Padding;
48use crate::{cvt, cvt_p};
49use foreign_types::ForeignTypeRef;
50use openssl_macros::corresponds;
51
52pub struct Encrypter<'a> {
54 pctx: *mut ffi::EVP_PKEY_CTX,
55 _p: PhantomData<&'a ()>,
56}
57
58unsafe impl Sync for Encrypter<'_> {}
59unsafe impl Send for Encrypter<'_> {}
60
61impl Drop for Encrypter<'_> {
62 fn drop(&mut self) {
63 unsafe {
64 ffi::EVP_PKEY_CTX_free(self.pctx);
65 }
66 }
67}
68
69impl<'a> Encrypter<'a> {
70 #[corresponds(EVP_PKEY_encrypt_init)]
72 pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Encrypter<'a>, ErrorStack>
73 where
74 T: HasPublic,
75 {
76 unsafe {
77 ffi::init();
78
79 let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
80 let r = ffi::EVP_PKEY_encrypt_init(pctx);
81 if r != 1 {
82 ffi::EVP_PKEY_CTX_free(pctx);
83 return Err(ErrorStack::get());
84 }
85
86 Ok(Encrypter {
87 pctx,
88 _p: PhantomData,
89 })
90 }
91 }
92
93 pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
99 unsafe {
100 let mut pad = 0;
101 cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
102 .map(|_| Padding::from_raw(pad))
103 }
104 }
105
106 #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
110 pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
111 unsafe {
112 cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
113 self.pctx,
114 padding.as_raw(),
115 ))
116 .map(|_| ())
117 }
118 }
119
120 #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
124 pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
125 unsafe {
126 cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
127 self.pctx,
128 md.as_ptr() as *mut _,
129 ))
130 .map(|_| ())
131 }
132 }
133
134 #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
138 pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
139 unsafe {
140 cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
141 self.pctx,
142 md.as_ptr() as *mut _,
143 ))
144 .map(|_| ())
145 }
146 }
147
148 #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
152 pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
153 unsafe {
154 let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
155 ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
156
157 cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
158 self.pctx,
159 p.cast(),
160 label.len() as _,
161 ))
162 .map(|_| ())
163 .map_err(|e| {
164 ffi::OPENSSL_free(p);
165 e
166 })
167 }
168 }
169
170 #[corresponds(EVP_PKEY_encrypt)]
201 pub fn encrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
202 let mut written = to.len();
203 unsafe {
204 cvt(ffi::EVP_PKEY_encrypt(
205 self.pctx,
206 to.as_mut_ptr(),
207 &mut written,
208 from.as_ptr(),
209 from.len(),
210 ))?;
211 }
212
213 Ok(written)
214 }
215
216 #[corresponds(EVP_PKEY_encrypt)]
220 pub fn encrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
221 let mut written = 0;
222 unsafe {
223 cvt(ffi::EVP_PKEY_encrypt(
224 self.pctx,
225 ptr::null_mut(),
226 &mut written,
227 from.as_ptr(),
228 from.len(),
229 ))?;
230 }
231
232 Ok(written)
233 }
234}
235
236pub struct Decrypter<'a> {
238 pctx: *mut ffi::EVP_PKEY_CTX,
239 _p: PhantomData<&'a ()>,
240}
241
242unsafe impl Sync for Decrypter<'_> {}
243unsafe impl Send for Decrypter<'_> {}
244
245impl Drop for Decrypter<'_> {
246 fn drop(&mut self) {
247 unsafe {
248 ffi::EVP_PKEY_CTX_free(self.pctx);
249 }
250 }
251}
252
253impl<'a> Decrypter<'a> {
254 #[corresponds(EVP_PKEY_decrypt_init)]
256 pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Decrypter<'a>, ErrorStack>
257 where
258 T: HasPrivate,
259 {
260 unsafe {
261 ffi::init();
262
263 let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
264 let r = ffi::EVP_PKEY_decrypt_init(pctx);
265 if r != 1 {
266 ffi::EVP_PKEY_CTX_free(pctx);
267 return Err(ErrorStack::get());
268 }
269
270 Ok(Decrypter {
271 pctx,
272 _p: PhantomData,
273 })
274 }
275 }
276
277 pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
283 unsafe {
284 let mut pad = 0;
285 cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
286 .map(|_| Padding::from_raw(pad))
287 }
288 }
289
290 #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
294 pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
295 unsafe {
296 cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
297 self.pctx,
298 padding.as_raw(),
299 ))
300 .map(|_| ())
301 }
302 }
303
304 #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
308 pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
309 unsafe {
310 cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
311 self.pctx,
312 md.as_ptr() as *mut _,
313 ))
314 .map(|_| ())
315 }
316 }
317
318 #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
322 pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
323 unsafe {
324 cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
325 self.pctx,
326 md.as_ptr() as *mut _,
327 ))
328 .map(|_| ())
329 }
330 }
331
332 #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
336 pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
337 unsafe {
338 let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
339 ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
340
341 cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
342 self.pctx,
343 p.cast(),
344 label.len() as _,
345 ))
346 .map(|_| ())
347 .map_err(|e| {
348 ffi::OPENSSL_free(p);
349 e
350 })
351 }
352 }
353
354 #[corresponds(EVP_PKEY_decrypt)]
400 pub fn decrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
401 let mut written = to.len();
402 unsafe {
403 cvt(ffi::EVP_PKEY_decrypt(
404 self.pctx,
405 to.as_mut_ptr(),
406 &mut written,
407 from.as_ptr(),
408 from.len(),
409 ))?;
410 }
411
412 Ok(written)
413 }
414
415 #[corresponds(EVP_PKEY_decrypt)]
419 pub fn decrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
420 let mut written = 0;
421 unsafe {
422 cvt(ffi::EVP_PKEY_decrypt(
423 self.pctx,
424 ptr::null_mut(),
425 &mut written,
426 from.as_ptr(),
427 from.len(),
428 ))?;
429 }
430
431 Ok(written)
432 }
433}
434
#[cfg(test)]
mod test {
    use hex::FromHex;

    use crate::encrypt::{Decrypter, Encrypter};
    use crate::hash::MessageDigest;
    use crate::pkey::PKey;
    use crate::rsa::{Padding, Rsa};

    // Hex-encoded plaintext shared by all round-trip tests.
    const INPUT: &str =
        "65794a68624763694f694a53557a49314e694a392e65794a7063334d694f694a71623255694c41304b49434a6c\
         654841694f6a457a4d4441344d546b7a4f44417344516f67496d6830644841364c79396c654746746347786c4c\
         6d4e76625339706331397962323930496a7030636e566c6651";

    /// Queries the required length, encrypts `plaintext`, and returns only the bytes
    /// that were actually produced.
    fn encrypt_all(encrypter: &Encrypter<'_>, plaintext: &[u8]) -> Vec<u8> {
        let capacity = encrypter.encrypt_len(plaintext).unwrap();
        let mut scratch = vec![0u8; capacity];
        let produced = encrypter.encrypt(plaintext, &mut scratch).unwrap();
        scratch[..produced].to_vec()
    }

    /// Queries the required length, decrypts `ciphertext`, and returns only the bytes
    /// that were actually produced.
    fn decrypt_all(decrypter: &Decrypter<'_>, ciphertext: &[u8]) -> Vec<u8> {
        let capacity = decrypter.decrypt_len(ciphertext).unwrap();
        let mut scratch = vec![0u8; capacity];
        let produced = decrypter.decrypt(ciphertext, &mut scratch).unwrap();
        scratch[..produced].to_vec()
    }

    /// PKCS#1 padding: what is encrypted must decrypt back to the same bytes.
    #[test]
    fn rsa_encrypt_decrypt() {
        let rsa = Rsa::private_key_from_pem(include_bytes!("../test/rsa.pem")).unwrap();
        let pkey = PKey::from_rsa(rsa).unwrap();
        let plaintext = Vec::from_hex(INPUT).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let ciphertext = encrypt_all(&encrypter, &plaintext);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let recovered = decrypt_all(&decrypter, &ciphertext);

        assert_eq!(recovered, plaintext);
    }

    /// OAEP padding with SHA-256 for both the OAEP digest and MGF1.
    #[test]
    fn rsa_encrypt_decrypt_with_sha256() {
        let rsa = Rsa::private_key_from_pem(include_bytes!("../test/rsa.pem")).unwrap();
        let pkey = PKey::from_rsa(rsa).unwrap();
        let digest = MessageDigest::sha256();
        let plaintext = Vec::from_hex(INPUT).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_md(digest).unwrap();
        encrypter.set_rsa_mgf1_md(digest).unwrap();
        let ciphertext = encrypt_all(&encrypter, &plaintext);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_md(digest).unwrap();
        decrypter.set_rsa_mgf1_md(digest).unwrap();
        let recovered = decrypt_all(&decrypter, &ciphertext);

        assert_eq!(recovered, plaintext);
    }

    /// OAEP padding with a label: the matching label round-trips, a different label
    /// makes decryption fail.
    #[test]
    fn rsa_encrypt_decrypt_oaep_label() {
        let rsa = Rsa::private_key_from_pem(include_bytes!("../test/rsa.pem")).unwrap();
        let pkey = PKey::from_rsa(rsa).unwrap();
        let plaintext = Vec::from_hex(INPUT).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let ciphertext = encrypt_all(&encrypter, &plaintext);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let recovered = decrypt_all(&decrypter, &ciphertext);

        assert_eq!(recovered, plaintext);

        // Same ciphertext, different label: the length query still succeeds but the
        // decryption itself must be rejected.
        decrypter.set_rsa_oaep_label(b"wrong_oaep_label").unwrap();
        let capacity = decrypter.decrypt_len(&ciphertext).unwrap();
        let mut scratch = vec![0u8; capacity];

        assert!(decrypter.decrypt(&ciphertext, &mut scratch).is_err());
    }
}