diff --git a/src/olm/session/double_ratchet.rs b/src/olm/session/double_ratchet.rs index 3d6e45fb..0e613144 100644 --- a/src/olm/session/double_ratchet.rs +++ b/src/olm/session/double_ratchet.rs @@ -79,6 +79,7 @@ impl DoubleRatchet { let ratchet = ActiveDoubleRatchet { parent_ratchet_key: None, // First chain in a session lacks parent ratchet key + ratchet_count: RatchetCount::new(), active_ratchet: Ratchet::new(root_key), symmetric_key_ratchet: chain_key, }; @@ -91,6 +92,7 @@ impl DoubleRatchet { Self { inner: ActiveDoubleRatchet { parent_ratchet_key: None, // libolm pickle did not record parent ratchet key + ratchet_count: RatchetCount::unknown(), // nor the ratchet count active_ratchet: ratchet, symmetric_key_ratchet: chain_key, } @@ -98,8 +100,23 @@ impl DoubleRatchet { } } - pub fn inactive(root_key: RemoteRootKey, ratchet_key: RemoteRatchetKey) -> Self { - let ratchet = InactiveDoubleRatchet { root_key, ratchet_key }; + pub fn inactive_from_prekey_data( + root_key: RemoteRootKey, + ratchet_key: RemoteRatchetKey, + ) -> Self { + let ratchet_count = RatchetCount::new(); + let ratchet = InactiveDoubleRatchet { root_key, ratchet_key, ratchet_count }; + + Self { inner: ratchet.into() } + } + + #[cfg(feature = "libolm-compat")] + pub fn inactive_from_libolm_pickle( + root_key: RemoteRootKey, + ratchet_key: RemoteRatchetKey, + ) -> Self { + let ratchet_count = RatchetCount::unknown(); + let ratchet = InactiveDoubleRatchet { root_key, ratchet_key, ratchet_count }; Self { inner: ratchet.into() } } @@ -166,6 +183,15 @@ impl From for DoubleRatchetState { struct InactiveDoubleRatchet { root_key: RemoteRootKey, ratchet_key: RemoteRatchetKey, + + /// The number of times the ratchet has been advanced. + /// + /// If `root_key` contains root key `R``i`, this is `i`. + /// + /// This is not required to implement the algorithm: it is maintained solely + /// for diagnostic output. + #[serde(default = "RatchetCount::unknown")] + ratchet_count: RatchetCount, } impl InactiveDoubleRatchet { @@ -175,6 +201,7 @@ impl InactiveDoubleRatchet { ActiveDoubleRatchet { parent_ratchet_key: Some(self.ratchet_key), + ratchet_count: self.ratchet_count.advance(), active_ratchet, symmetric_key_ratchet: chain_key, } @@ -184,6 +211,7 @@ impl InactiveDoubleRatchet { impl Debug for InactiveDoubleRatchet { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("InactiveDoubleRatchet") + .field("ratchet_count", &self.ratchet_count) .field("ratchet_key", &self.ratchet_key) .finish_non_exhaustive() } @@ -212,6 +240,12 @@ struct ActiveDoubleRatchet { #[serde(default)] parent_ratchet_key: Option, + /// The number of times the ratchet has been advanced. + /// + /// If `active_ratchet` contains root key `R``i`, this is `i`. + #[serde(default = "RatchetCount::unknown")] + ratchet_count: RatchetCount, + active_ratchet: Ratchet, symmetric_key_ratchet: ChainKey, } @@ -220,8 +254,13 @@ impl ActiveDoubleRatchet { fn advance(&self, ratchet_key: RemoteRatchetKey) -> (InactiveDoubleRatchet, ReceiverChain) { let (root_key, remote_chain) = self.active_ratchet.advance(ratchet_key); - let ratchet = InactiveDoubleRatchet { root_key, ratchet_key }; - let receiver_chain = ReceiverChain::new(ratchet_key, remote_chain); + let new_ratchet_count = self.ratchet_count.advance(); + let ratchet = InactiveDoubleRatchet { + root_key, + ratchet_key, + ratchet_count: new_ratchet_count.clone(), + }; + let receiver_chain = ReceiverChain::new(ratchet_key, remote_chain, new_ratchet_count); (ratchet, receiver_chain) } @@ -239,9 +278,177 @@ impl Debug for ActiveDoubleRatchet { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let active_ratchet_public_key: RatchetPublicKey = self.active_ratchet.ratchet_key().into(); f.debug_struct("ActiveDoubleRatchet") + .field("ratchet_count", &self.ratchet_count) .field("parent_ratchet_key", &self.parent_ratchet_key) .field("ratchet_key", &active_ratchet_public_key) .field("chain_index", &self.symmetric_key_ratchet.index()) .finish_non_exhaustive() } } + +/// The number of times the ratchet has been advanced, `i`. +/// +/// This starts at 0 for the first prekey messages from Alice to Bob, +/// increments to 1 when Bob replies, and then increments each time the +/// conversation changes direction. +/// +/// It may be unknown, if the ratchet was restored from a pickle +/// which didn't track it. +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] +pub enum RatchetCount { + Known(u64), + Unknown(()), +} + +impl RatchetCount { + pub fn new() -> RatchetCount { + RatchetCount::Known(0) + } + + pub fn unknown() -> RatchetCount { + RatchetCount::Unknown(()) + } + + pub fn advance(&self) -> RatchetCount { + match self { + RatchetCount::Known(count) => RatchetCount::Known(count + 1), + RatchetCount::Unknown(_) => RatchetCount::Unknown(()), + } + } +} + +impl Debug for RatchetCount { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + RatchetCount::Known(count) => write!(f, "{count}"), + RatchetCount::Unknown(_) => write!(f, ""), + } + } +} + +#[cfg(test)] +mod test { + use assert_matches::assert_matches; + + use super::{ + ActiveDoubleRatchet, DoubleRatchet, DoubleRatchetState, InactiveDoubleRatchet, RatchetCount, + }; + use crate::olm::{ + session::test::session_and_libolm_pair, Account, OlmMessage, Session, SessionConfig, + }; + + fn create_session_pair(alice: &Account, bob: &mut Account) -> (Session, Session) { + let bob_otks = bob.generate_one_time_keys(1); + let bob_otk = bob_otks.created.first().expect("Couldn't get a one-time-key for bob"); + let bob_identity_key = bob.identity_keys().curve25519; + let mut alice_session = + alice.create_outbound_session(SessionConfig::version_1(), bob_identity_key, *bob_otk); + + let message = "It's a secret to everybody"; + let olm_message = alice_session.encrypt(message); + let prekey_message = assert_matches!(olm_message, OlmMessage::PreKey(m) => m); + + let alice_identity_key = alice.identity_keys().curve25519; + let bob_session_creation_result = bob + .create_inbound_session(alice_identity_key, &prekey_message) + .expect("Unable to create inbound session"); + assert_eq!(bob_session_creation_result.plaintext, message.as_bytes()); + (alice_session, bob_session_creation_result.session) + } + + fn assert_active_ratchet(sending_ratchet: &DoubleRatchet) -> &ActiveDoubleRatchet { + match &sending_ratchet.inner { + DoubleRatchetState::Inactive(_) => panic!("Not an active ratchet"), + DoubleRatchetState::Active(s) => s, + } + } + + fn assert_inactive_ratchet(sending_ratchet: &DoubleRatchet) -> &InactiveDoubleRatchet { + match &sending_ratchet.inner { + DoubleRatchetState::Active(_) => panic!("Not an inactive ratchet"), + DoubleRatchetState::Inactive(s) => s, + } + } + + #[test] + fn ratchet_counts() { + let (mut alice_session, mut bob_session) = + create_session_pair(&Account::new(), &mut Account::new()); + + // Both ratchets should start with count 0. + assert_eq!( + assert_active_ratchet(&alice_session.sending_ratchet).ratchet_count, + RatchetCount::Known(0) + ); + assert_eq!( + assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count, + RatchetCount::Known(0) + ); + + // Once Bob replies, the ratchets should bump to 1. + let olm_message = bob_session.encrypt("sssh"); + alice_session.decrypt(&olm_message).expect("Alice could not decrypt message from Bob"); + assert_eq!( + assert_inactive_ratchet(&alice_session.sending_ratchet).ratchet_count, + RatchetCount::Known(1) + ); + assert_eq!( + assert_active_ratchet(&bob_session.sending_ratchet).ratchet_count, + RatchetCount::Known(1) + ); + + // Now Alice replies again. + let olm_message = alice_session.encrypt("sssh"); + bob_session.decrypt(&olm_message).expect("Bob could not decrypt message from Alice"); + assert_eq!( + assert_active_ratchet(&alice_session.sending_ratchet).ratchet_count, + RatchetCount::Known(2) + ); + assert_eq!( + assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count, + RatchetCount::Known(2) + ); + } + + #[test] + fn ratchet_counts_for_imported_session() { + let (_, _, mut alice_session, bob_libolm_session) = + session_and_libolm_pair().expect("unable to create sessions"); + + // Import the libolm session into a proper Vodozmac session. + let key = b"DEFAULT_PICKLE_KEY"; + let pickle = + bob_libolm_session.pickle(olm_rs::PicklingMode::Encrypted { key: key.to_vec() }); + let mut bob_session = + Session::from_libolm_pickle(&pickle, key).expect("Should be able to unpickle session"); + + assert_eq!( + assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count, + RatchetCount::Unknown(()) + ); + + // Once Bob replies, Alice's count bumps to 1, but Bob's remains unknown. + let olm_message = bob_session.encrypt("sssh"); + alice_session.decrypt(&olm_message).expect("Alice could not decrypt message from Bob"); + assert_eq!( + assert_inactive_ratchet(&alice_session.sending_ratchet).ratchet_count, + RatchetCount::Known(1) + ); + assert_eq!( + assert_active_ratchet(&bob_session.sending_ratchet).ratchet_count, + RatchetCount::Unknown(()) + ); + + // Now Alice replies again. + let olm_message = alice_session.encrypt("sssh"); + bob_session.decrypt(&olm_message).expect("Bob could not decrypt message from Alice"); + assert_eq!( + assert_active_ratchet(&alice_session.sending_ratchet).ratchet_count, + RatchetCount::Known(2) + ); + assert_eq!( + assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count, + RatchetCount::Unknown(()) + ); + } +} diff --git a/src/olm/session/mod.rs b/src/olm/session/mod.rs index f47f1199..3d55da0c 100644 --- a/src/olm/session/mod.rs +++ b/src/olm/session/mod.rs @@ -43,7 +43,10 @@ use super::{ #[cfg(feature = "low-level-api")] use crate::hazmat::olm::MessageKey; use crate::{ - olm::messages::{Message, OlmMessage, PreKeyMessage}, + olm::{ + messages::{Message, OlmMessage, PreKeyMessage}, + session::double_ratchet::RatchetCount, + }, utilities::{pickle, unpickle}, Curve25519PublicKey, PickleError, }; @@ -193,8 +196,9 @@ impl Session { let root_key = RemoteRootKey::new(root_key); let remote_chain_key = RemoteChainKey::new(remote_chain_key); - let local_ratchet = DoubleRatchet::inactive(root_key, remote_ratchet_key); - let remote_ratchet = ReceiverChain::new(remote_ratchet_key, remote_chain_key); + let local_ratchet = DoubleRatchet::inactive_from_prekey_data(root_key, remote_ratchet_key); + let remote_ratchet = + ReceiverChain::new(remote_ratchet_key, remote_chain_key, RatchetCount::new()); let mut ratchet_store = ChainStore::new(); ratchet_store.push(remote_ratchet); @@ -360,7 +364,7 @@ impl Session { chain.chain_key_index, ); - ReceiverChain::new(ratchet_key, chain_key) + ReceiverChain::new(ratchet_key, chain_key, RatchetCount::unknown()) } } @@ -444,7 +448,7 @@ impl Session { config: SessionConfig::version_1(), }) } else if let Some(chain) = receiving_chains.get(0) { - let sending_ratchet = DoubleRatchet::inactive( + let sending_ratchet = DoubleRatchet::inactive_from_libolm_pickle( RemoteRootKey::new(pickle.root_key.clone()), chain.ratchet_key(), ); @@ -530,7 +534,10 @@ mod test { const PICKLE_KEY: [u8; 32] = [0u8; 32]; - fn sessions() -> Result<(Account, OlmAccount, Session, OlmSession)> { + /// Create a pair of accounts, one using vodozemac and one libolm. + /// + /// Then, create a pair of sessions between the two. + pub fn session_and_libolm_pair() -> Result<(Account, OlmAccount, Session, OlmSession)> { let alice = Account::new(); let bob = OlmAccount::new(); bob.generate_one_time_keys(1); @@ -566,7 +573,7 @@ mod test { #[test] fn out_of_order_decryption() { - let (_, _, mut alice_session, bob_session) = sessions().unwrap(); + let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap(); let message_1 = bob_session.encrypt("Message 1").into(); let message_2 = bob_session.encrypt("Message 2").into(); @@ -588,7 +595,7 @@ mod test { #[test] fn more_out_of_order_decryption() { - let (_, _, mut alice_session, bob_session) = sessions().unwrap(); + let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap(); let message_1 = bob_session.encrypt("Message 1").into(); let message_2 = bob_session.encrypt("Message 2").into(); @@ -626,7 +633,7 @@ mod test { #[test] fn max_keys_out_of_order_decryption() { - let (_, _, mut alice_session, bob_session) = sessions().unwrap(); + let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap(); let mut messages: Vec = Vec::new(); for i in 0..(MAX_MESSAGE_KEYS + 2) { @@ -660,7 +667,7 @@ mod test { #[test] fn max_gap_out_of_order_decryption() { - let (_, _, mut alice_session, bob_session) = sessions().unwrap(); + let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap(); for i in 0..(MAX_MESSAGE_GAP + 1) { bob_session.encrypt(format!("Message {}", i).as_str()); @@ -676,7 +683,7 @@ mod test { #[test] #[cfg(feature = "libolm-compat")] fn libolm_unpickling() { - let (_, _, mut session, olm) = sessions().unwrap(); + let (_, _, mut session, olm) = session_and_libolm_pair().unwrap(); let plaintext = "It's a secret to everybody"; let old_message = session.encrypt(plaintext); @@ -713,7 +720,7 @@ mod test { #[test] fn session_pickling_roundtrip_is_identity() { - let (_, _, session, _) = sessions().unwrap(); + let (_, _, session, _) = session_and_libolm_pair().unwrap(); let pickle = session.pickle().encrypt(&PICKLE_KEY); diff --git a/src/olm/session/receiver_chain.rs b/src/olm/session/receiver_chain.rs index 58f56aac..7d43928e 100644 --- a/src/olm/session/receiver_chain.rs +++ b/src/olm/session/receiver_chain.rs @@ -21,7 +21,10 @@ use super::{ chain_key::RemoteChainKey, message_key::RemoteMessageKey, ratchet::RemoteRatchetKey, DecryptionError, }; -use crate::olm::{messages::Message, session_config::Version, SessionConfig}; +use crate::olm::{ + messages::Message, session::double_ratchet::RatchetCount, session_config::Version, + SessionConfig, +}; pub(crate) const MAX_MESSAGE_GAP: u64 = 2000; pub(crate) const MAX_MESSAGE_KEYS: usize = 40; @@ -111,13 +114,21 @@ pub(super) struct ReceiverChain { /// /// This allows us to handle out-of-order messages. skipped_message_keys: MessageKeyStore, + + /// The number of times `i` the ratchet was advanced before this chain. + /// + /// This is not required to implement the algorithm: it is maintained solely + /// for diagnostic output. + #[serde(default = "RatchetCount::unknown")] + ratchet_count: RatchetCount, } impl Debug for ReceiverChain { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let Self { ratchet_key, hkdf_ratchet, skipped_message_keys } = self; + let Self { ratchet_count, ratchet_key, hkdf_ratchet, skipped_message_keys } = self; f.debug_struct("ReceiverChain") + .field("ratchet_count", &ratchet_count) .field("ratchet_key", &ratchet_key) .field("chain_index", &hkdf_ratchet.chain_index()) .field("skipped_message_keys", &skipped_message_keys.inner) @@ -126,11 +137,16 @@ impl Debug for ReceiverChain { } impl ReceiverChain { - pub fn new(ratchet_key: RemoteRatchetKey, chain_key: RemoteChainKey) -> Self { + pub fn new( + ratchet_key: RemoteRatchetKey, + chain_key: RemoteChainKey, + ratchet_count: RatchetCount, + ) -> Self { ReceiverChain { ratchet_key, hkdf_ratchet: chain_key, skipped_message_keys: Default::default(), + ratchet_count, } }