Skip to content

Commit

Permalink
feat: Track the number of Diffie-Hellman ratchet advances in the Olm …
Browse files Browse the repository at this point in the history
…Session. This number is useful only for debugging purposes and will be included in the Debug output of the Olm `Session` (#134).
  • Loading branch information
poljar committed Mar 21, 2024
2 parents b768ecf + ab00f4b commit 6101be8
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 19 deletions.
215 changes: 211 additions & 4 deletions src/olm/session/double_ratchet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl DoubleRatchet {

let ratchet = ActiveDoubleRatchet {
parent_ratchet_key: None, // First chain in a session lacks parent ratchet key
ratchet_count: RatchetCount::new(),
active_ratchet: Ratchet::new(root_key),
symmetric_key_ratchet: chain_key,
};
Expand All @@ -91,15 +92,31 @@ impl DoubleRatchet {
Self {
inner: ActiveDoubleRatchet {
parent_ratchet_key: None, // libolm pickle did not record parent ratchet key
ratchet_count: RatchetCount::unknown(), // nor the ratchet count
active_ratchet: ratchet,
symmetric_key_ratchet: chain_key,
}
.into(),
}
}

pub fn inactive(root_key: RemoteRootKey, ratchet_key: RemoteRatchetKey) -> Self {
let ratchet = InactiveDoubleRatchet { root_key, ratchet_key };
pub fn inactive_from_prekey_data(
root_key: RemoteRootKey,
ratchet_key: RemoteRatchetKey,
) -> Self {
let ratchet_count = RatchetCount::new();
let ratchet = InactiveDoubleRatchet { root_key, ratchet_key, ratchet_count };

Self { inner: ratchet.into() }
}

#[cfg(feature = "libolm-compat")]
pub fn inactive_from_libolm_pickle(
root_key: RemoteRootKey,
ratchet_key: RemoteRatchetKey,
) -> Self {
let ratchet_count = RatchetCount::unknown();
let ratchet = InactiveDoubleRatchet { root_key, ratchet_key, ratchet_count };

Self { inner: ratchet.into() }
}
Expand Down Expand Up @@ -166,6 +183,15 @@ impl From<ActiveDoubleRatchet> for DoubleRatchetState {
struct InactiveDoubleRatchet {
root_key: RemoteRootKey,
ratchet_key: RemoteRatchetKey,

/// The number of times the ratchet has been advanced.
///
/// If `root_key` contains root key `R`<sub>`i`</sub>, this is `i`.
///
/// This is not required to implement the algorithm: it is maintained solely
/// for diagnostic output.
#[serde(default = "RatchetCount::unknown")]
ratchet_count: RatchetCount,
}

impl InactiveDoubleRatchet {
Expand All @@ -175,6 +201,7 @@ impl InactiveDoubleRatchet {

ActiveDoubleRatchet {
parent_ratchet_key: Some(self.ratchet_key),
ratchet_count: self.ratchet_count.advance(),
active_ratchet,
symmetric_key_ratchet: chain_key,
}
Expand All @@ -184,6 +211,7 @@ impl InactiveDoubleRatchet {
impl Debug for InactiveDoubleRatchet {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("InactiveDoubleRatchet")
.field("ratchet_count", &self.ratchet_count)
.field("ratchet_key", &self.ratchet_key)
.finish_non_exhaustive()
}
Expand Down Expand Up @@ -212,6 +240,12 @@ struct ActiveDoubleRatchet {
#[serde(default)]
parent_ratchet_key: Option<RemoteRatchetKey>,

/// The number of times the ratchet has been advanced.
///
/// If `active_ratchet` contains root key `R`<sub>`i`</sub>, this is `i`.
#[serde(default = "RatchetCount::unknown")]
ratchet_count: RatchetCount,

active_ratchet: Ratchet,
symmetric_key_ratchet: ChainKey,
}
Expand All @@ -220,8 +254,13 @@ impl ActiveDoubleRatchet {
fn advance(&self, ratchet_key: RemoteRatchetKey) -> (InactiveDoubleRatchet, ReceiverChain) {
let (root_key, remote_chain) = self.active_ratchet.advance(ratchet_key);

let ratchet = InactiveDoubleRatchet { root_key, ratchet_key };
let receiver_chain = ReceiverChain::new(ratchet_key, remote_chain);
let new_ratchet_count = self.ratchet_count.advance();
let ratchet = InactiveDoubleRatchet {
root_key,
ratchet_key,
ratchet_count: new_ratchet_count.clone(),
};
let receiver_chain = ReceiverChain::new(ratchet_key, remote_chain, new_ratchet_count);

(ratchet, receiver_chain)
}
Expand All @@ -239,9 +278,177 @@ impl Debug for ActiveDoubleRatchet {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let active_ratchet_public_key: RatchetPublicKey = self.active_ratchet.ratchet_key().into();
f.debug_struct("ActiveDoubleRatchet")
.field("ratchet_count", &self.ratchet_count)
.field("parent_ratchet_key", &self.parent_ratchet_key)
.field("ratchet_key", &active_ratchet_public_key)
.field("chain_index", &self.symmetric_key_ratchet.index())
.finish_non_exhaustive()
}
}

/// The number of times the ratchet has been advanced, `i`.
///
/// This starts at 0 for the first prekey messages from Alice to Bob,
/// increments to 1 when Bob replies, and then increments each time the
/// conversation changes direction.
///
/// It may be unknown, if the ratchet was restored from a pickle
/// which didn't track it.
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum RatchetCount {
Known(u64),
Unknown(()),
}

impl RatchetCount {
pub fn new() -> RatchetCount {
RatchetCount::Known(0)
}

pub fn unknown() -> RatchetCount {
RatchetCount::Unknown(())
}

pub fn advance(&self) -> RatchetCount {
match self {
RatchetCount::Known(count) => RatchetCount::Known(count + 1),
RatchetCount::Unknown(_) => RatchetCount::Unknown(()),
}
}
}

impl Debug for RatchetCount {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
RatchetCount::Known(count) => write!(f, "{count}"),
RatchetCount::Unknown(_) => write!(f, "<unknown>"),
}
}
}

#[cfg(test)]
mod test {
use assert_matches::assert_matches;

use super::{
ActiveDoubleRatchet, DoubleRatchet, DoubleRatchetState, InactiveDoubleRatchet, RatchetCount,
};
use crate::olm::{
session::test::session_and_libolm_pair, Account, OlmMessage, Session, SessionConfig,
};

fn create_session_pair(alice: &Account, bob: &mut Account) -> (Session, Session) {
let bob_otks = bob.generate_one_time_keys(1);
let bob_otk = bob_otks.created.first().expect("Couldn't get a one-time-key for bob");
let bob_identity_key = bob.identity_keys().curve25519;
let mut alice_session =
alice.create_outbound_session(SessionConfig::version_1(), bob_identity_key, *bob_otk);

let message = "It's a secret to everybody";
let olm_message = alice_session.encrypt(message);
let prekey_message = assert_matches!(olm_message, OlmMessage::PreKey(m) => m);

let alice_identity_key = alice.identity_keys().curve25519;
let bob_session_creation_result = bob
.create_inbound_session(alice_identity_key, &prekey_message)
.expect("Unable to create inbound session");
assert_eq!(bob_session_creation_result.plaintext, message.as_bytes());
(alice_session, bob_session_creation_result.session)
}

fn assert_active_ratchet(sending_ratchet: &DoubleRatchet) -> &ActiveDoubleRatchet {
match &sending_ratchet.inner {
DoubleRatchetState::Inactive(_) => panic!("Not an active ratchet"),
DoubleRatchetState::Active(s) => s,
}
}

fn assert_inactive_ratchet(sending_ratchet: &DoubleRatchet) -> &InactiveDoubleRatchet {
match &sending_ratchet.inner {
DoubleRatchetState::Active(_) => panic!("Not an inactive ratchet"),
DoubleRatchetState::Inactive(s) => s,
}
}

#[test]
fn ratchet_counts() {
let (mut alice_session, mut bob_session) =
create_session_pair(&Account::new(), &mut Account::new());

// Both ratchets should start with count 0.
assert_eq!(
assert_active_ratchet(&alice_session.sending_ratchet).ratchet_count,
RatchetCount::Known(0)
);
assert_eq!(
assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count,
RatchetCount::Known(0)
);

// Once Bob replies, the ratchets should bump to 1.
let olm_message = bob_session.encrypt("sssh");
alice_session.decrypt(&olm_message).expect("Alice could not decrypt message from Bob");
assert_eq!(
assert_inactive_ratchet(&alice_session.sending_ratchet).ratchet_count,
RatchetCount::Known(1)
);
assert_eq!(
assert_active_ratchet(&bob_session.sending_ratchet).ratchet_count,
RatchetCount::Known(1)
);

// Now Alice replies again.
let olm_message = alice_session.encrypt("sssh");
bob_session.decrypt(&olm_message).expect("Bob could not decrypt message from Alice");
assert_eq!(
assert_active_ratchet(&alice_session.sending_ratchet).ratchet_count,
RatchetCount::Known(2)
);
assert_eq!(
assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count,
RatchetCount::Known(2)
);
}

#[test]
fn ratchet_counts_for_imported_session() {
let (_, _, mut alice_session, bob_libolm_session) =
session_and_libolm_pair().expect("unable to create sessions");

// Import the libolm session into a proper Vodozmac session.
let key = b"DEFAULT_PICKLE_KEY";
let pickle =
bob_libolm_session.pickle(olm_rs::PicklingMode::Encrypted { key: key.to_vec() });
let mut bob_session =
Session::from_libolm_pickle(&pickle, key).expect("Should be able to unpickle session");

assert_eq!(
assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count,
RatchetCount::Unknown(())
);

// Once Bob replies, Alice's count bumps to 1, but Bob's remains unknown.
let olm_message = bob_session.encrypt("sssh");
alice_session.decrypt(&olm_message).expect("Alice could not decrypt message from Bob");
assert_eq!(
assert_inactive_ratchet(&alice_session.sending_ratchet).ratchet_count,
RatchetCount::Known(1)
);
assert_eq!(
assert_active_ratchet(&bob_session.sending_ratchet).ratchet_count,
RatchetCount::Unknown(())
);

// Now Alice replies again.
let olm_message = alice_session.encrypt("sssh");
bob_session.decrypt(&olm_message).expect("Bob could not decrypt message from Alice");
assert_eq!(
assert_active_ratchet(&alice_session.sending_ratchet).ratchet_count,
RatchetCount::Known(2)
);
assert_eq!(
assert_inactive_ratchet(&bob_session.sending_ratchet).ratchet_count,
RatchetCount::Unknown(())
);
}
}
31 changes: 19 additions & 12 deletions src/olm/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ use super::{
#[cfg(feature = "low-level-api")]
use crate::hazmat::olm::MessageKey;
use crate::{
olm::messages::{Message, OlmMessage, PreKeyMessage},
olm::{
messages::{Message, OlmMessage, PreKeyMessage},
session::double_ratchet::RatchetCount,
},
utilities::{pickle, unpickle},
Curve25519PublicKey, PickleError,
};
Expand Down Expand Up @@ -193,8 +196,9 @@ impl Session {
let root_key = RemoteRootKey::new(root_key);
let remote_chain_key = RemoteChainKey::new(remote_chain_key);

let local_ratchet = DoubleRatchet::inactive(root_key, remote_ratchet_key);
let remote_ratchet = ReceiverChain::new(remote_ratchet_key, remote_chain_key);
let local_ratchet = DoubleRatchet::inactive_from_prekey_data(root_key, remote_ratchet_key);
let remote_ratchet =
ReceiverChain::new(remote_ratchet_key, remote_chain_key, RatchetCount::new());

let mut ratchet_store = ChainStore::new();
ratchet_store.push(remote_ratchet);
Expand Down Expand Up @@ -360,7 +364,7 @@ impl Session {
chain.chain_key_index,
);

ReceiverChain::new(ratchet_key, chain_key)
ReceiverChain::new(ratchet_key, chain_key, RatchetCount::unknown())
}
}

Expand Down Expand Up @@ -444,7 +448,7 @@ impl Session {
config: SessionConfig::version_1(),
})
} else if let Some(chain) = receiving_chains.get(0) {
let sending_ratchet = DoubleRatchet::inactive(
let sending_ratchet = DoubleRatchet::inactive_from_libolm_pickle(
RemoteRootKey::new(pickle.root_key.clone()),
chain.ratchet_key(),
);
Expand Down Expand Up @@ -530,7 +534,10 @@ mod test {

const PICKLE_KEY: [u8; 32] = [0u8; 32];

fn sessions() -> Result<(Account, OlmAccount, Session, OlmSession)> {
/// Create a pair of accounts, one using vodozemac and one libolm.
///
/// Then, create a pair of sessions between the two.
pub fn session_and_libolm_pair() -> Result<(Account, OlmAccount, Session, OlmSession)> {
let alice = Account::new();
let bob = OlmAccount::new();
bob.generate_one_time_keys(1);
Expand Down Expand Up @@ -566,7 +573,7 @@ mod test {

#[test]
fn out_of_order_decryption() {
let (_, _, mut alice_session, bob_session) = sessions().unwrap();
let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap();

let message_1 = bob_session.encrypt("Message 1").into();
let message_2 = bob_session.encrypt("Message 2").into();
Expand All @@ -588,7 +595,7 @@ mod test {

#[test]
fn more_out_of_order_decryption() {
let (_, _, mut alice_session, bob_session) = sessions().unwrap();
let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap();

let message_1 = bob_session.encrypt("Message 1").into();
let message_2 = bob_session.encrypt("Message 2").into();
Expand Down Expand Up @@ -626,7 +633,7 @@ mod test {

#[test]
fn max_keys_out_of_order_decryption() {
let (_, _, mut alice_session, bob_session) = sessions().unwrap();
let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap();

let mut messages: Vec<messages::OlmMessage> = Vec::new();
for i in 0..(MAX_MESSAGE_KEYS + 2) {
Expand Down Expand Up @@ -660,7 +667,7 @@ mod test {

#[test]
fn max_gap_out_of_order_decryption() {
let (_, _, mut alice_session, bob_session) = sessions().unwrap();
let (_, _, mut alice_session, bob_session) = session_and_libolm_pair().unwrap();

for i in 0..(MAX_MESSAGE_GAP + 1) {
bob_session.encrypt(format!("Message {}", i).as_str());
Expand All @@ -676,7 +683,7 @@ mod test {
#[test]
#[cfg(feature = "libolm-compat")]
fn libolm_unpickling() {
let (_, _, mut session, olm) = sessions().unwrap();
let (_, _, mut session, olm) = session_and_libolm_pair().unwrap();

let plaintext = "It's a secret to everybody";
let old_message = session.encrypt(plaintext);
Expand Down Expand Up @@ -713,7 +720,7 @@ mod test {

#[test]
fn session_pickling_roundtrip_is_identity() {
let (_, _, session, _) = sessions().unwrap();
let (_, _, session, _) = session_and_libolm_pair().unwrap();

let pickle = session.pickle().encrypt(&PICKLE_KEY);

Expand Down
Loading

0 comments on commit 6101be8

Please sign in to comment.