Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api)!: Remove base64 expectation for messages in the API #176

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/olm/account/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ impl Account {
}

/// Sign the given message using our Ed25519 fingerprint key.
pub fn sign(&self, message: &str) -> Ed25519Signature {
self.signing_key.sign(message.as_bytes())
pub fn sign(&self, message: impl AsRef<[u8]>) -> Ed25519Signature {
self.signing_key.sign(message.as_ref())
}

/// Get the maximum number of one-time keys the client should keep on the
Expand Down Expand Up @@ -1076,7 +1076,7 @@ mod test {
#[allow(clippy::redundant_clone)]
let signing_key_clone = account_with_expanded_key.signing_key.clone();
signing_key_clone.sign("You met with a terrible fate, haven’t you?".as_bytes());
account_with_expanded_key.sign("You met with a terrible fate, haven’t you?");
account_with_expanded_key.sign("You met with a terrible fate, haven’t you?".as_bytes());

Ok(())
}
Expand Down Expand Up @@ -1146,7 +1146,7 @@ mod test {
let vodozemac_pickle = account.to_libolm_pickle(key).unwrap();
let _ = Account::from_libolm_pickle(&vodozemac_pickle, key).unwrap();

let vodozemac_signature = account.sign(message);
let vodozemac_signature = account.sign(message.as_bytes());
let olm_signature = Ed25519Signature::from_base64(&olm_signature)
.expect("We should be able to parse a signature produced by libolm");
account
Expand Down
43 changes: 24 additions & 19 deletions src/olm/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub use message::Message;
pub use pre_key::PreKeyMessage;
use serde::{Deserialize, Serialize};

use crate::DecodeError;
use crate::{base64_decode, base64_encode, DecodeError};

/// Enum over the different Olm message types.
///
Expand Down Expand Up @@ -67,9 +67,8 @@ impl Serialize for OlmMessage {
where
S: serde::Serializer,
{
let (message_type, ciphertext) = self.clone().to_parts();

let message = MessageSerdeHelper { message_type, ciphertext };
let (message_type, ciphertext) = self.to_parts();
let message = MessageSerdeHelper { message_type, ciphertext: base64_encode(ciphertext) };

message.serialize(serializer)
}
Expand All @@ -78,18 +77,19 @@ impl Serialize for OlmMessage {
impl<'de> Deserialize<'de> for OlmMessage {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let value = MessageSerdeHelper::deserialize(d)?;
let ciphertext_bytes = base64_decode(value.ciphertext).map_err(serde::de::Error::custom)?;

OlmMessage::from_parts(value.message_type, &value.ciphertext)
OlmMessage::from_parts(value.message_type, ciphertext_bytes.as_slice())
.map_err(serde::de::Error::custom)
}
}

impl OlmMessage {
/// Create a `OlmMessage` from a message type and a ciphertext.
pub fn from_parts(message_type: usize, ciphertext: &str) -> Result<Self, DecodeError> {
pub fn from_parts(message_type: usize, ciphertext: &[u8]) -> Result<Self, DecodeError> {
match message_type {
0 => Ok(Self::PreKey(PreKeyMessage::try_from(ciphertext)?)),
1 => Ok(Self::Normal(Message::try_from(ciphertext)?)),
0 => Ok(Self::PreKey(PreKeyMessage::from_bytes(ciphertext)?)),
1 => Ok(Self::Normal(Message::from_bytes(ciphertext)?)),
m => Err(DecodeError::MessageType(m)),
}
}
Expand All @@ -110,14 +110,13 @@ impl OlmMessage {
}
}

/// Convert the `OlmMessage` into a message type, and base64 encoded message
/// tuple.
pub fn to_parts(self) -> (usize, String) {
/// Convert the `OlmMessage` into a message type, and message bytes tuple.
pub fn to_parts(&self) -> (usize, Vec<u8>) {
let message_type = self.message_type();

match self {
OlmMessage::Normal(m) => (message_type.into(), m.to_base64()),
OlmMessage::PreKey(m) => (message_type.into(), m.to_base64()),
OlmMessage::Normal(m) => (message_type.into(), m.to_bytes()),
OlmMessage::PreKey(m) => (message_type.into(), m.to_bytes()),
}
}
}
Expand Down Expand Up @@ -156,8 +155,10 @@ use olm_rs::session::OlmMessage as LibolmMessage;
impl From<LibolmMessage> for OlmMessage {
fn from(other: LibolmMessage) -> Self {
let (message_type, ciphertext) = other.to_tuple();
let ciphertext_bytes = base64_decode(ciphertext).expect("Can't decode base64");

Self::from_parts(message_type.into(), &ciphertext).expect("Can't decode a libolm message")
Self::from_parts(message_type.into(), ciphertext_bytes.as_slice())
.expect("Can't decode a libolm message")
}
}

Expand Down Expand Up @@ -247,27 +248,31 @@ mod tests {

#[test]
fn from_parts() -> Result<()> {
let message = OlmMessage::from_parts(0, PRE_KEY_MESSAGE)?;
let message = OlmMessage::from_parts(0, base64_decode(PRE_KEY_MESSAGE)?.as_slice())?;
assert_matches!(message, OlmMessage::PreKey(_));
assert_eq!(
message.message_type(),
MessageType::PreKey,
"Expected message to be recognized as a pre-key Olm message."
);
assert_eq!(message.message(), PRE_KEY_MESSAGE_CIPHERTEXT);
assert_eq!(message.to_parts(), (0, PRE_KEY_MESSAGE.to_string()), "Roundtrip not identity.");
assert_eq!(
message.to_parts(),
(0, base64_decode(PRE_KEY_MESSAGE)?),
"Roundtrip not identity."
);

let message = OlmMessage::from_parts(1, MESSAGE)?;
let message = OlmMessage::from_parts(1, base64_decode(MESSAGE)?.as_slice())?;
assert_matches!(message, OlmMessage::Normal(_));
assert_eq!(
message.message_type(),
MessageType::Normal,
"Expected message to be recognized as a normal Olm message."
);
assert_eq!(message.message(), MESSAGE_CIPHERTEXT);
assert_eq!(message.to_parts(), (1, MESSAGE.to_string()), "Roundtrip not identity.");
assert_eq!(message.to_parts(), (1, base64_decode(MESSAGE)?), "Roundtrip not identity.");

OlmMessage::from_parts(3, PRE_KEY_MESSAGE)
OlmMessage::from_parts(3, base64_decode(PRE_KEY_MESSAGE)?.as_slice())
.expect_err("Unknown message types can't be parsed");

Ok(())
Expand Down