diff --git a/src/olm/session/mod.rs b/src/olm/session/mod.rs index 7125bc5d..ad682cde 100644 --- a/src/olm/session/mod.rs +++ b/src/olm/session/mod.rs @@ -512,14 +512,19 @@ impl From for Session { #[cfg(test)] mod test { use anyhow::{bail, Result}; + use assert_matches::assert_matches; use olm_rs::{ account::OlmAccount, session::{OlmMessage, OlmSession}, }; - use super::Session; + use super::{DecryptionError, Session}; use crate::{ - olm::{Account, SessionConfig, SessionPickle}, + olm::{ + messages, + session::receiver_chain::{MAX_MESSAGE_GAP, MAX_MESSAGE_KEYS}, + Account, SessionConfig, SessionPickle, + }, Curve25519PublicKey, }; @@ -560,49 +565,118 @@ mod test { } #[test] - fn out_of_order_decryption() -> Result<()> { - let (_, _, mut alice_session, bob_session) = sessions()?; + fn out_of_order_decryption() { + let (_, _, mut alice_session, bob_session) = sessions().unwrap(); let message_1 = bob_session.encrypt("Message 1").into(); let message_2 = bob_session.encrypt("Message 2").into(); let message_3 = bob_session.encrypt("Message 3").into(); - assert_eq!("Message 3".as_bytes(), alice_session.decrypt(&message_3)?); - assert_eq!("Message 2".as_bytes(), alice_session.decrypt(&message_2)?); - assert_eq!("Message 1".as_bytes(), alice_session.decrypt(&message_1)?); - - Ok(()) + assert_eq!( + "Message 3".as_bytes(), + alice_session.decrypt(&message_3).expect("Should be able to decrypt message 3") + ); + assert_eq!( + "Message 2".as_bytes(), + alice_session.decrypt(&message_2).expect("Should be able to decrypt message 2") + ); + assert_eq!( + "Message 1".as_bytes(), + alice_session.decrypt(&message_1).expect("Should be able to decrypt message 1") + ); } #[test] - fn more_out_of_order_decryption() -> Result<()> { - let (_, _, mut alice_session, bob_session) = sessions()?; + fn more_out_of_order_decryption() { + let (_, _, mut alice_session, bob_session) = sessions().unwrap(); let message_1 = bob_session.encrypt("Message 1").into(); let message_2 = bob_session.encrypt("Message 2").into(); let message_3 = bob_session.encrypt("Message 3").into(); - assert_eq!("Message 1".as_bytes(), alice_session.decrypt(&message_1)?); + assert_eq!( + "Message 1".as_bytes(), + alice_session.decrypt(&message_1).expect("Should be able to decrypt message 1") + ); assert_eq!(alice_session.receiving_chains.len(), 1); let message_4 = alice_session.encrypt("Message 4").into(); - assert_eq!("Message 4", bob_session.decrypt(message_4)?); + assert_eq!( + "Message 4", + bob_session.decrypt(message_4).expect("Should be able to decrypt message 4") + ); let message_5 = bob_session.encrypt("Message 5").into(); - assert_eq!("Message 5".as_bytes(), alice_session.decrypt(&message_5)?); - assert_eq!("Message 3".as_bytes(), alice_session.decrypt(&message_3)?); - assert_eq!("Message 2".as_bytes(), alice_session.decrypt(&message_2)?); + assert_eq!( + "Message 5".as_bytes(), + alice_session.decrypt(&message_5).expect("Should be able to decrypt message 5") + ); + assert_eq!( + "Message 3".as_bytes(), + alice_session.decrypt(&message_3).expect("Should be able to decrypt message 3") + ); + assert_eq!( + "Message 2".as_bytes(), + alice_session.decrypt(&message_2).expect("Should be able to decrypt message 2") + ); assert_eq!(alice_session.receiving_chains.len(), 2); + } + + #[test] + fn max_keys_out_of_order_decryption() { + let (_, _, mut alice_session, bob_session) = sessions().unwrap(); - Ok(()) + let mut messages: Vec = Vec::new(); + for i in 0..(MAX_MESSAGE_KEYS + 2) { + messages.push(bob_session.encrypt(format!("Message {}", i).as_str()).into()); + } + + // Decrypt last message + assert_eq!( + format!("Message {}", MAX_MESSAGE_KEYS + 1).as_bytes(), + alice_session + .decrypt(&messages[MAX_MESSAGE_KEYS + 1]) + .expect("Should be able to decrypt last message") + ); + + // Cannot decrypt first message because it is more than MAX_MESSAGE_KEYS ago + assert_matches!( + alice_session.decrypt(&messages[0]), + Err(DecryptionError::MissingMessageKey(_)) + ); + + // Can decrypt all other messages + for (i, message) in messages.iter().enumerate().skip(1).take(MAX_MESSAGE_KEYS) { + assert_eq!( + format!("Message {}", i).as_bytes(), + alice_session + .decrypt(message) + .expect("Should be able to decrypt remaining messages") + ); + } + } + + #[test] + fn max_gap_out_of_order_decryption() { + let (_, _, mut alice_session, bob_session) = sessions().unwrap(); + + for i in 0..(MAX_MESSAGE_GAP + 1) { + bob_session.encrypt(format!("Message {}", i).as_str()); + } + + let message = bob_session.encrypt("Message").into(); + assert_matches!( + alice_session.decrypt(&message), + Err(DecryptionError::TooBigMessageGap(_, _)) + ); } #[test] #[cfg(feature = "libolm-compat")] - fn libolm_unpickling() -> Result<()> { - let (_, _, mut session, olm) = sessions()?; + fn libolm_unpickling() { + let (_, _, mut session, olm) = sessions().unwrap(); let plaintext = "It's a secret to everybody"; let old_message = session.encrypt(plaintext); @@ -612,42 +686,49 @@ mod test { } let message = session.encrypt("Hello"); - olm.decrypt(message.into())?; + olm.decrypt(message.into()).expect("Should be able to decrypt message"); let key = b"DEFAULT_PICKLE_KEY"; let pickle = olm.pickle(olm_rs::PicklingMode::Encrypted { key: key.to_vec() }); - let mut unpickled = Session::from_libolm_pickle(&pickle, key)?; + let mut unpickled = + Session::from_libolm_pickle(&pickle, key).expect("Should be able to unpickle session"); assert_eq!(olm.session_id(), unpickled.session_id()); - assert_eq!(unpickled.decrypt(&old_message)?, plaintext.as_bytes()); + assert_eq!( + unpickled + .decrypt(&old_message) + .expect("Should be able to decrypt old message with unpickled session"), + plaintext.as_bytes() + ); let message = unpickled.encrypt(plaintext); - assert_eq!(session.decrypt(&message)?, plaintext.as_bytes()); - - Ok(()) + assert_eq!( + session.decrypt(&message).expect("Should be able to decrypt re-encrypted message"), + plaintext.as_bytes() + ); } #[test] - fn session_pickling_roundtrip_is_identity() -> Result<()> { - let (_, _, session, _) = sessions()?; + fn session_pickling_roundtrip_is_identity() { + let (_, _, session, _) = sessions().unwrap(); let pickle = session.pickle().encrypt(&PICKLE_KEY); - let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY)?; + let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY) + .expect("Should be able to decrypt encrypted pickle"); let unpickled_group_session = Session::from_pickle(decrypted_pickle); let repickle = unpickled_group_session.pickle(); assert_eq!(session.session_id(), unpickled_group_session.session_id()); - let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY)?; - let pickle = serde_json::to_value(decrypted_pickle)?; - let repickle = serde_json::to_value(repickle)?; + let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY) + .expect("Should be able to decrypt encrypted pickle"); + let pickle = serde_json::to_value(decrypted_pickle).unwrap(); + let repickle = serde_json::to_value(repickle).unwrap(); assert_eq!(pickle, repickle); - - Ok(()) } } diff --git a/src/olm/session/receiver_chain.rs b/src/olm/session/receiver_chain.rs index 96b2c1f4..61b4429d 100644 --- a/src/olm/session/receiver_chain.rs +++ b/src/olm/session/receiver_chain.rs @@ -23,8 +23,8 @@ use super::{ }; use crate::olm::{messages::Message, session_config::Version, SessionConfig}; -const MAX_MESSAGE_GAP: u64 = 2000; -const MAX_MESSAGE_KEYS: usize = 40; +pub(crate) const MAX_MESSAGE_GAP: u64 = 2000; +pub(crate) const MAX_MESSAGE_KEYS: usize = 40; #[derive(Serialize, Deserialize, Clone)] struct MessageKeyStore { @@ -187,3 +187,49 @@ impl ReceiverChain { &self.ratchet_key == ratchet_key } } + +#[cfg(test)] +mod test { + use assert_matches::assert_matches; + + use super::MessageKeyStore; + use crate::olm::session::message_key::RemoteMessageKey; + + #[test] + fn push_and_remove() { + let mut store = MessageKeyStore::new(); + let key_bytes = *b"11111111111111111111111111111111"; + let chain_index: u64 = 1; + let key = RemoteMessageKey::new(Box::new(key_bytes), chain_index); + assert_matches!(store.get_message_key(chain_index), None); + store.push(key); + assert_matches!(store.get_message_key(chain_index), Some(key) if *(key.key) == key_bytes && key.index == chain_index); + store.remove_message_key(chain_index); + assert_matches!(store.get_message_key(chain_index), None); + } + + #[test] + fn merge() { + let mut store1 = MessageKeyStore::new(); + let key_bytes1 = *b"11111111111111111111111111111111"; + let chain_index1: u64 = 1; + let key1 = RemoteMessageKey::new(Box::new(key_bytes1), chain_index1); + store1.push(key1); + + let mut store2 = MessageKeyStore::new(); + let key_bytes2 = *b"22222222222222222222222222222222"; + let chain_index2: u64 = 2; + let key2 = RemoteMessageKey::new(Box::new(key_bytes2), chain_index2); + store2.push(key2); + + assert_matches!(store1.get_message_key(chain_index1), Some(_)); + assert_matches!(store1.get_message_key(chain_index2), None); + assert_matches!(store2.get_message_key(chain_index1), None); + assert_matches!(store2.get_message_key(chain_index2), Some(_)); + + store1.merge(store2); + + assert_matches!(store1.get_message_key(chain_index1), Some(_)); + assert_matches!(store1.get_message_key(chain_index2), Some(_)); + } +}