Skip to content

Commit

Permalink
Record the number of times the Olm ratchet is advanced
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Mar 19, 2024
1 parent 7a1887f commit 8662b17
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 12 deletions.
88 changes: 84 additions & 4 deletions src/olm/session/double_ratchet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl DoubleRatchet {

let ratchet = ActiveDoubleRatchet {
parent_ratchet_key: None, // First chain in a session lacks parent ratchet key
ratchet_count: RatchetCount::new(),
active_ratchet: Ratchet::new(root_key),
symmetric_key_ratchet: chain_key,
};
Expand All @@ -91,15 +92,31 @@ impl DoubleRatchet {
Self {
inner: ActiveDoubleRatchet {
parent_ratchet_key: None, // libolm pickle did not record parent ratchet key
ratchet_count: RatchetCount::unknown(), // nor the ratchet count
active_ratchet: ratchet,
symmetric_key_ratchet: chain_key,
}
.into(),
}
}

pub fn inactive(root_key: RemoteRootKey, ratchet_key: RemoteRatchetKey) -> Self {
let ratchet = InactiveDoubleRatchet { root_key, ratchet_key };
pub fn inactive_from_prekey_data(
root_key: RemoteRootKey,
ratchet_key: RemoteRatchetKey,
) -> Self {
let ratchet_count = RatchetCount::new();
let ratchet = InactiveDoubleRatchet { root_key, ratchet_key, ratchet_count };

Self { inner: ratchet.into() }
}

#[cfg(feature = "libolm-compat")]
pub fn inactive_from_libolm_pickle(
root_key: RemoteRootKey,
ratchet_key: RemoteRatchetKey,
) -> Self {
let ratchet_count = RatchetCount::unknown();
let ratchet = InactiveDoubleRatchet { root_key, ratchet_key, ratchet_count };

Self { inner: ratchet.into() }
}
Expand Down Expand Up @@ -166,6 +183,15 @@ impl From<ActiveDoubleRatchet> for DoubleRatchetState {
struct InactiveDoubleRatchet {
root_key: RemoteRootKey,
ratchet_key: RemoteRatchetKey,

/// The number of times the ratchet has been advanced.
///
/// If `root_key` contains root key `R`<sub>`i`</sub>, this is `i`.
///
/// This is not required to implement the algorithm: it is maintained solely
/// for diagnostic output.
#[serde(default = "RatchetCount::unknown")]
ratchet_count: RatchetCount,
}

impl InactiveDoubleRatchet {
Expand All @@ -175,6 +201,7 @@ impl InactiveDoubleRatchet {

ActiveDoubleRatchet {
parent_ratchet_key: Some(self.ratchet_key),
ratchet_count: self.ratchet_count.advance(),
active_ratchet,
symmetric_key_ratchet: chain_key,
}
Expand All @@ -184,6 +211,7 @@ impl InactiveDoubleRatchet {
impl Debug for InactiveDoubleRatchet {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("InactiveDoubleRatchet")
.field("ratchet_count", &self.ratchet_count)
.field("ratchet_key", &self.ratchet_key)
.finish_non_exhaustive()
}
Expand Down Expand Up @@ -212,6 +240,12 @@ struct ActiveDoubleRatchet {
#[serde(default)]
parent_ratchet_key: Option<RemoteRatchetKey>,

/// The number of times the ratchet has been advanced.
///
/// If `active_ratchet` contains root key `R`<sub>`i`</sub>, this is `i`.
#[serde(default = "RatchetCount::unknown")]
ratchet_count: RatchetCount,

active_ratchet: Ratchet,
symmetric_key_ratchet: ChainKey,
}
Expand All @@ -220,8 +254,13 @@ impl ActiveDoubleRatchet {
fn advance(&self, ratchet_key: RemoteRatchetKey) -> (InactiveDoubleRatchet, ReceiverChain) {
let (root_key, remote_chain) = self.active_ratchet.advance(ratchet_key);

let ratchet = InactiveDoubleRatchet { root_key, ratchet_key };
let receiver_chain = ReceiverChain::new(ratchet_key, remote_chain);
let new_ratchet_count = self.ratchet_count.advance();
let ratchet = InactiveDoubleRatchet {
root_key,
ratchet_key,
ratchet_count: new_ratchet_count.clone(),
};
let receiver_chain = ReceiverChain::new(ratchet_key, remote_chain, new_ratchet_count);

(ratchet, receiver_chain)
}
Expand All @@ -239,9 +278,50 @@ impl Debug for ActiveDoubleRatchet {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let active_ratchet_public_key: RatchetPublicKey = self.active_ratchet.ratchet_key().into();
f.debug_struct("ActiveDoubleRatchet")
.field("ratchet_count", &self.ratchet_count)
.field("parent_ratchet_key", &self.parent_ratchet_key)
.field("ratchet_key", &active_ratchet_public_key)
.field("chain_index", &self.symmetric_key_ratchet.index())
.finish_non_exhaustive()
}
}

/// The number of times the ratchet has been advanced, `i`.
///
/// This starts at 0 for the first prekey messages from Alice to Bob,
/// increments to 1 when Bob replies, and then increments each time the
/// conversation changes direction.
///
/// It may be unknown, if the ratchet was restored from a pickle
/// which didn't track it.
#[derive(Serialize, Deserialize, Clone)]
pub enum RatchetCount {
Known(u64),
Unknown(()),
}

impl RatchetCount {
pub fn new() -> RatchetCount {
RatchetCount::Known(0)
}

pub fn unknown() -> RatchetCount {
RatchetCount::Unknown(())
}

pub fn advance(&self) -> RatchetCount {
match self {
RatchetCount::Known(count) => RatchetCount::Known(count + 1),
RatchetCount::Unknown(_) => RatchetCount::Unknown(()),
}
}
}

impl Debug for RatchetCount {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
RatchetCount::Known(count) => f.write_fmt(format_args!("{}", count)),
RatchetCount::Unknown(_) => f.write_str("<unknown>"),
}
}
}
14 changes: 9 additions & 5 deletions src/olm/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ use super::{
#[cfg(feature = "low-level-api")]
use crate::hazmat::olm::MessageKey;
use crate::{
olm::messages::{Message, OlmMessage, PreKeyMessage},
olm::{
messages::{Message, OlmMessage, PreKeyMessage},
session::double_ratchet::RatchetCount,
},
utilities::{pickle, unpickle},
Curve25519PublicKey, PickleError,
};
Expand Down Expand Up @@ -193,8 +196,9 @@ impl Session {
let root_key = RemoteRootKey::new(root_key);
let remote_chain_key = RemoteChainKey::new(remote_chain_key);

let local_ratchet = DoubleRatchet::inactive(root_key, remote_ratchet_key);
let remote_ratchet = ReceiverChain::new(remote_ratchet_key, remote_chain_key);
let local_ratchet = DoubleRatchet::inactive_from_prekey_data(root_key, remote_ratchet_key);
let remote_ratchet =
ReceiverChain::new(remote_ratchet_key, remote_chain_key, RatchetCount::new());

let mut ratchet_store = ChainStore::new();
ratchet_store.push(remote_ratchet);
Expand Down Expand Up @@ -360,7 +364,7 @@ impl Session {
chain.chain_key_index,
);

ReceiverChain::new(ratchet_key, chain_key)
ReceiverChain::new(ratchet_key, chain_key, RatchetCount::unknown())
}
}

Expand Down Expand Up @@ -444,7 +448,7 @@ impl Session {
config: SessionConfig::version_1(),
})
} else if let Some(chain) = receiving_chains.get(0) {
let sending_ratchet = DoubleRatchet::inactive(
let sending_ratchet = DoubleRatchet::inactive_from_libolm_pickle(
RemoteRootKey::new(pickle.root_key.clone()),
chain.ratchet_key(),
);
Expand Down
22 changes: 19 additions & 3 deletions src/olm/session/receiver_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use super::{
chain_key::RemoteChainKey, message_key::RemoteMessageKey, ratchet::RemoteRatchetKey,
DecryptionError,
};
use crate::olm::{messages::Message, session_config::Version, SessionConfig};
use crate::olm::{
messages::Message, session::double_ratchet::RatchetCount, session_config::Version,
SessionConfig,
};

pub(crate) const MAX_MESSAGE_GAP: u64 = 2000;
pub(crate) const MAX_MESSAGE_KEYS: usize = 40;
Expand Down Expand Up @@ -111,13 +114,21 @@ pub(super) struct ReceiverChain {
///
/// This allows us to handle out-of-order messages.
skipped_message_keys: MessageKeyStore,

/// The number of times `i` the ratchet was advanced before this chain.
///
/// This is not required to implement the algorithm: it is maintained solely
/// for diagnostic output.
#[serde(default = "RatchetCount::unknown")]
ratchet_count: RatchetCount,
}

impl Debug for ReceiverChain {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { ratchet_key, hkdf_ratchet, skipped_message_keys } = self;
let Self { ratchet_count, ratchet_key, hkdf_ratchet, skipped_message_keys } = self;

f.debug_struct("ReceiverChain")
.field("ratchet_count", &self.ratchet_count)
.field("ratchet_key", &ratchet_key)
.field("chain_index", &hkdf_ratchet.chain_index())
.field("skipped_message_keys", &skipped_message_keys.inner)
Expand All @@ -126,11 +137,16 @@ impl Debug for ReceiverChain {
}

impl ReceiverChain {
pub fn new(ratchet_key: RemoteRatchetKey, chain_key: RemoteChainKey) -> Self {
pub fn new(
ratchet_key: RemoteRatchetKey,
chain_key: RemoteChainKey,
ratchet_count: RatchetCount,
) -> Self {
ReceiverChain {
ratchet_key,
hkdf_ratchet: chain_key,
skipped_message_keys: Default::default(),
ratchet_count,
}
}

Expand Down

0 comments on commit 8662b17

Please sign in to comment.