From 0a1edc2f597299b99fdea2cb39ff5fec63f313fe Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 4 Jan 2024 22:05:05 +0100 Subject: [PATCH 01/36] Rework provisioning to async --- Cargo.toml | 19 +- rust-toolchain.toml | 4 +- src/lib.rs | 8 +- src/provisioning/data_types.rs | 101 ++++++- src/provisioning/error.rs | 12 +- src/provisioning/mod.rs | 492 ++++++++++++++++++++++----------- src/provisioning/topics.rs | 128 ++++----- src/test/mod.rs | 80 +++--- 8 files changed, 545 insertions(+), 299 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 671dcd3..046a4ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,16 +23,16 @@ maintenance = { status = "actively-developed" } [dependencies] bitmaps = { version = "^3.1", default-features = false } -heapless = { version = "0.7.0", features = ["serde"] } -mqttrust = { version = "0.6" } -nb = "1" +heapless = { version = "0.8", features = ["serde"] } serde = { version = "1.0.126", default-features = false, features = ["derive"] } serde_cbor = { version = "^0.11", default-features = false, optional = true } -serde-json-core = { version = "0.4.0" } -smlang = "0.5.0" -fugit-timer = "0.1.2" +serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } -embedded-storage = "0.3.0" +embedded-storage-async = "0.4" +embedded-mqtt = { path = "../embedded-mqtt" } +futures = { version = "0.3.28", default-features = false } + +embassy-sync = "0.5" log = { version = "^0.4", default-features = false, optional = true } defmt = { version = "^0.3", optional = true } @@ -42,7 +42,6 @@ native-tls = { version = "^0.2" } embedded-nal = "0.6.0" no-std-net = { version = "^0.5", features = ["serde"] } dns-lookup = "1.0.3" -mqttrust_core = { version = "0.6", features = ["log"] } env_logger = "0.9.0" sha2 = "0.10.1" ecdsa = { version = "0.13.4", features = ["pkcs8"] } @@ -61,6 +60,4 @@ ota_http_data = [] std = ["serde/std", "serde_cbor?/std"] -defmt = ["dep:defmt", "mqttrust/defmt-impl", "heapless/defmt-impl"] - -graphviz = ["smlang/graphviz"] +defmt = ["dep:defmt", "heapless/defmt-03"] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 3cd5460..b79d547 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,6 +1,6 @@ [toolchain] -channel = "nightly-2023-06-28" -components = [ "rust-src", "rustfmt", "llvm-tools-preview", "clippy" ] +channel = "nightly-2023-12-24" +components = [ "rust-src", "rustfmt", "llvm-tools", "clippy" ] targets = [ "x86_64-unknown-linux-gnu", "thumbv7em-none-eabihf" diff --git a/src/lib.rs b/src/lib.rs index 23917e8..4b9d248 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,11 +5,11 @@ // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; -pub mod jobs; -#[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] -pub mod ota; +// pub mod jobs; +// #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] +// pub mod ota; pub mod provisioning; -pub mod shadows; +// pub mod shadows; pub use serde_cbor; diff --git a/src/provisioning/data_types.rs b/src/provisioning/data_types.rs index 7349425..3d929eb 100644 --- a/src/provisioning/data_types.rs +++ b/src/provisioning/data_types.rs @@ -1,4 +1,3 @@ -use heapless::LinearMap; use serde::{Deserialize, Serialize}; /// To receive error responses, subscribe to @@ -94,7 +93,7 @@ pub struct CreateKeysAndCertificateResponse<'a> { /// **:** The provisioning template name. 
#[derive(Debug, PartialEq, Serialize)] // #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct RegisterThingRequest<'a, const P: usize> { +pub struct RegisterThingRequest<'a, P: Serialize> { /// The token to prove ownership of the certificate. The token is generated /// by AWS IoT when you create a certificate over MQTT. #[serde(rename = "certificateOwnershipToken")] @@ -102,8 +101,8 @@ pub struct RegisterThingRequest<'a, const P: usize> { /// Optional. Key-value pairs from the device that are used by the /// pre-provisioning hooks to evaluate the registration request. - #[serde(rename = "parameters")] - pub parameters: Option>, + #[serde(rename = "parameters", skip_serializing_if = "Option::is_none")] + pub parameters: Option
<P>
, } /// Subscribe to @@ -113,12 +112,102 @@ pub struct RegisterThingRequest<'a, const P: usize> { /// **:** The provisioning template name. #[derive(Debug, PartialEq, Deserialize)] // #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct RegisterThingResponse<'a, const P: usize> { +pub struct RegisterThingResponse<'a, C> { /// The device configuration defined in the template. #[serde(rename = "deviceConfiguration")] - pub device_configuration: LinearMap<&'a str, &'a str, P>, + pub device_configuration: Option, /// The name of the IoT thing created during provisioning. #[serde(rename = "thingName")] pub thing_name: &'a str, } + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Serialize)] + struct Parameters<'a> { + some_key: &'a str, + } + + #[derive(Debug, Deserialize, PartialEq)] + struct DeviceConfiguration { + some_key: heapless::String<64>, + } + + #[test] + fn serialize_optional_parameters() { + let register_request = RegisterThingRequest { + certificate_ownership_token: "my_ownership_token", + parameters: Some(Parameters { + some_key: "optional_key", + }), + }; + + let json = serde_json_core::to_string::<_, 128>(®ister_request).unwrap(); + assert_eq!( + json.as_str(), + r#"{"certificateOwnershipToken":"my_ownership_token","parameters":{"some_key":"optional_key"}}"# + ); + + let register_request_none: RegisterThingRequest<'_, Parameters> = RegisterThingRequest { + certificate_ownership_token: "my_ownership_token", + parameters: None, + }; + + let json = serde_json_core::to_string::<_, 128>(®ister_request_none).unwrap(); + assert_eq!( + json.as_str(), + r#"{"certificateOwnershipToken":"my_ownership_token"}"# + ); + } + + #[test] + fn deserialize_optional_device_configuration() { + let register_response = + r#"{"thingName":"my_thing","deviceConfiguration":{"some_key":"optional_key"}}"#; + + let (response, _) = + serde_json_core::from_str::>( + register_response, + ) + .unwrap(); + assert_eq!( + response, + RegisterThingResponse { + thing_name: "my_thing", + device_configuration: Some(DeviceConfiguration { + some_key: heapless::String::try_from("optional_key").unwrap() + }), + } + ); + + let register_response_none = r#"{"thingName":"my_thing"}"#; + + let (response, _) = + serde_json_core::from_str::>(®ister_response_none) + .unwrap(); + assert_eq!( + response, + RegisterThingResponse { + thing_name: "my_thing", + device_configuration: None, + } + ); + + // FIXME + let register_response_none = r#"{"thingName":"my_thing","deviceConfiguration":{}}"#; + + let (response, _) = + serde_json_core::from_str::>(®ister_response_none) + .unwrap(); + assert_eq!( + response, + RegisterThingResponse { + thing_name: "my_thing", + device_configuration: None, + } + ); + } +} diff --git a/src/provisioning/error.rs b/src/provisioning/error.rs index e43c01c..d0a128c 100644 --- a/src/provisioning/error.rs +++ b/src/provisioning/error.rs @@ -3,17 +3,17 @@ pub enum Error { Overflow, InvalidPayload, InvalidState, - Mqtt(mqttrust::MqttError), + Mqtt, DeserializeJson(serde_json_core::de::Error), DeserializeCbor, Response(u16), } -impl From for Error { - fn from(e: mqttrust::MqttError) -> Self { - Self::Mqtt(e) - } -} +// impl From for Error { +// fn from(e: MqttError) -> Self { +// Self::Mqtt(e) +// } +// } impl From for Error { fn from(_: serde_json_core::ser::Error) -> Self { diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 1c29c56..b2bd9c1 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -2,8 +2,12 @@ pub mod data_types; mod error; pub mod topics; -use 
heapless::LinearMap; -use mqttrust::Mqtt; +use core::future::Future; + +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{Publish, QoS, RetainHandling, Subscribe, SubscribeTopic}; +use futures::StreamExt; +use serde::de::DeserializeOwned; #[cfg(feature = "provision_cbor")] use serde::Serialize; @@ -13,7 +17,7 @@ use self::{ RegisterThingRequest, RegisterThingResponse, }, error::Error, - topics::{PayloadFormat, Subscribe, Topic, Unsubscribe}, + topics::{PayloadFormat, Topic}, }; #[derive(Debug)] @@ -23,236 +27,410 @@ pub struct Credentials<'a> { pub private_key: Option<&'a str>, } -#[derive(Debug)] -pub enum Response<'a, const P: usize> { - Credentials(Credentials<'a>), - DeviceConfiguration(LinearMap<&'a str, &'a str, P>), -} - -pub struct FleetProvisioner<'a, M> -where - M: Mqtt, -{ - mqtt: &'a M, - template_name: &'a str, - ownership_token: Option>, - payload_format: PayloadFormat, -} +pub struct FleetProvisioner; -impl<'a, M> FleetProvisioner<'a, M> -where - M: Mqtt, -{ +impl FleetProvisioner { /// Instantiate a new `FleetProvisioner`, using `template_name` for the provisioning - pub fn new(mqtt: &'a M, template_name: &'a str) -> Self { - Self { - mqtt, - template_name, - ownership_token: None, - payload_format: PayloadFormat::Json, - } + pub async fn provision<'a, F, Fut, P, C>( + mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, + template_name: &'a str, + parameters: Option
<P>
, + f: F, + ) -> Result, Error> + where + F: FnOnce(Credentials<'_>) -> Fut, + Fut: Future>, + P: Serialize, + C: DeserializeOwned, + { + Self::provision_inner(mqtt, template_name, parameters, f, PayloadFormat::Json).await } #[cfg(feature = "provision_cbor")] - pub fn new_cbor(mqtt: &'a M, template_name: &'a str) -> Self { - Self { + pub async fn provision_cbor<'a, F, Fut, P, C>( + mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, + template_name: &'a str, + parameters: Option
<P>
, + f: F, + ) -> Result, Error> + where + F: FnOnce(Credentials<'_>) -> Fut, + Fut: Future>, + P: Serialize, + C: DeserializeOwned, + { + Self::provision_inner(mqtt, template_name, parameters, f, PayloadFormat::Cbor).await + } + + async fn provision_inner<'a, F, Fut, P, C>( + mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, + template_name: &'a str, + parameters: Option
<P>
, + f: F, + payload_format: PayloadFormat, + ) -> Result, Error> + where + F: FnOnce(Credentials<'_>) -> Fut, + Fut: Future>, + P: Serialize, + C: DeserializeOwned, + { + let certificate_ownership_token = + Self::create_keys_and_certificates(mqtt, payload_format, f).await?; + + Self::register_thing( mqtt, template_name, - ownership_token: None, - payload_format: PayloadFormat::Cbor, - } + payload_format, + certificate_ownership_token.as_str(), + parameters, + ) + .await } - pub fn initialize(&self) -> Result<(), Error> { - Subscribe::<4>::new() - .topic( - Topic::CreateKeysAndCertificateAccepted(self.payload_format), - mqttrust::QoS::AtLeastOnce, - ) - .topic( - Topic::CreateKeysAndCertificateRejected(self.payload_format), - mqttrust::QoS::AtLeastOnce, - ) + pub async fn create_keys_and_certificates( + mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, + payload_format: PayloadFormat, + f: F, + ) -> Result, Error> + where + F: FnOnce(Credentials<'_>) -> Fut, + Fut: Future>, + { + let topic_paths = topics::Subscribe::<2>::new() .topic( - Topic::RegisterThingAccepted(self.template_name, self.payload_format), - mqttrust::QoS::AtLeastOnce, + Topic::CreateKeysAndCertificateAccepted(payload_format), + QoS::AtLeastOnce, ) .topic( - Topic::RegisterThingRejected(self.template_name, self.payload_format), - mqttrust::QoS::AtLeastOnce, + Topic::CreateKeysAndCertificateRejected(payload_format), + QoS::AtLeastOnce, ) - .send(self.mqtt)?; - - Ok(()) - } - - // TODO: Can we handle this better? If sent from `initialize` it causes a - // race condition with the subscription ack. - pub fn begin(&mut self) -> Result<(), Error> { - self.mqtt.publish( - Topic::CreateKeysAndCertificate(self.payload_format) + .topics::<38>()?; + + let subscribe_topics = topic_paths + .iter() + .map(|(s, qos)| SubscribeTopic { + topic_path: s.as_str(), + maximum_qos: *qos, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }) + .collect::>(); + + let mut subscription = mqtt + .subscribe::<2>(Subscribe { + pid: None, + properties: embedded_mqtt::Properties::Slice(&[]), + topics: subscribe_topics.as_slice(), + }) + .await + .map_err(|_| Error::Mqtt)?; + + mqtt.publish(Publish { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + pid: None, + topic_name: Topic::CreateKeysAndCertificate(payload_format) .format::<29>()? .as_str(), - b"", - mqttrust::QoS::AtLeastOnce, - )?; + payload: b"", + properties: embedded_mqtt::Properties::Slice(&[]), + }) + .await + .map_err(|_| Error::Mqtt)?; - Ok(()) - } + let mut message = subscription.next().await.ok_or(Error::InvalidState)?; - pub fn register_thing<'b, const P: usize>( - &mut self, - parameters: Option>, - ) -> Result<(), Error> { - let certificate_ownership_token = self.ownership_token.take().ok_or(Error::InvalidState)?; + match Topic::from_str(message.topic_name()) { + Some(Topic::CreateKeysAndCertificateAccepted(format)) => { + trace!( + "Topic::CreateKeysAndCertificateAccepted {:?}. Payload len: {:?}", + format, + message.payload().len() + ); - let register_request = RegisterThingRequest { - certificate_ownership_token: &certificate_ownership_token, - parameters, - }; + let response = match format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::< + CreateKeysAndCertificateResponse, + >(message.payload_mut())?, + PayloadFormat::Json => { + serde_json_core::from_slice::( + message.payload(), + )? 
+ .0 + } + }; - let payload = &mut [0u8; 1024]; + f(Credentials { + certificate_id: response.certificate_id, + certificate_pem: response.certificate_pem, + private_key: Some(response.private_key), + }) + .await?; - let payload_len = match self.payload_format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - let mut serializer = - serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload)); - register_request.serialize(&mut serializer)?; - serializer.into_inner().bytes_written() + Ok(heapless::String::try_from(response.certificate_ownership_token).unwrap()) } - PayloadFormat::Json => serde_json_core::to_slice(®ister_request, payload)?, - }; - self.mqtt.publish( - Topic::RegisterThing(self.template_name, self.payload_format) - .format::<69>()? - .as_str(), - &payload[..payload_len], - mqttrust::QoS::AtLeastOnce, - )?; + // Error happened! + Some(Topic::CreateKeysAndCertificateRejected(format)) => { + error!(">> {:?}", message.topic_name()); - Ok(()) + let response = match format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => { + serde_cbor::de::from_mut_slice::(message.payload_mut())? + } + PayloadFormat::Json => { + serde_json_core::from_slice::(message.payload())?.0 + } + }; + + error!("{:?}", response); + + Err(Error::Response(response.status_code)) + } + + t => { + trace!("{:?}", t); + + Err(Error::InvalidState) + } + } } - pub fn handle_message<'b, const P: usize>( - &mut self, - topic_name: &'b str, - payload: &'b mut [u8], - ) -> Result>, Error> { - match Topic::from_str(topic_name) { - Some(Topic::CreateKeysAndCertificateAccepted(format)) => { + pub async fn create_certificate_from_csr( + mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, + payload_format: PayloadFormat, + f: F, + ) -> Result, Error> + where + F: FnOnce(Credentials<'_>) -> Fut, + Fut: Future>, + { + let topic_paths = topics::Subscribe::<2>::new() + .topic( + Topic::CreateCertificateFromCsrAccepted(payload_format), + QoS::AtLeastOnce, + ) + .topic( + Topic::CreateCertificateFromCsrRejected(payload_format), + QoS::AtLeastOnce, + ) + .topics::<47>()?; + + let subscribe_topics = topic_paths + .iter() + .map(|(s, qos)| SubscribeTopic { + topic_path: s.as_str(), + maximum_qos: *qos, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }) + .collect::>(); + + let mut subscription = mqtt + .subscribe::<2>(Subscribe { + pid: None, + properties: embedded_mqtt::Properties::Slice(&[]), + topics: subscribe_topics.as_slice(), + }) + .await + .map_err(|_| Error::Mqtt)?; + + mqtt.publish(Publish { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + pid: None, + topic_name: Topic::CreateCertificateFromCsr(payload_format) + .format::<38>()? + .as_str(), + payload: b"", + properties: embedded_mqtt::Properties::Slice(&[]), + }) + .await + .map_err(|_| Error::Mqtt)?; + + let mut message = subscription.next().await.ok_or(Error::InvalidState)?; + + match Topic::from_str(message.topic_name()) { + Some(Topic::CreateCertificateFromCsrAccepted(format)) => { trace!( - "Topic::CreateKeysAndCertificateAccepted {:?}. Payload len: {:?}", + "Topic::CreateCertificateFromCsrAccepted {:?}. Payload len: {:?}", format, - payload.len() + message.payload().len() ); let response = match format { #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(payload)? 
- } + PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::< + CreateCertificateFromCsrResponse, + >(message.payload_mut())?, PayloadFormat::Json => { - serde_json_core::from_slice::(payload)?.0 + serde_json_core::from_slice::( + message.payload(), + )? + .0 } }; - self.ownership_token - .replace(heapless::String::from(response.certificate_ownership_token)); - - Ok(Some(Response::Credentials(Credentials { + f(Credentials { certificate_id: response.certificate_id, certificate_pem: response.certificate_pem, - private_key: Some(response.private_key), - }))) + private_key: None, + }) + .await?; + + // FIXME: It should be possible to re-arrange stuff to get rid of the need for this 512 byte stack alloc + Ok(heapless::String::try_from(response.certificate_ownership_token).unwrap()) } - Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - trace!("Topic::CreateCertificateFromCsrAccepted {:?}", format); + + // Error happened! + Some(Topic::CreateCertificateFromCsrRejected(format)) => { + error!(">> {:?}", message.topic_name()); let response = match format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(payload)? + serde_cbor::de::from_mut_slice::(message.payload_mut())? } PayloadFormat::Json => { - serde_json_core::from_slice::(payload)?.0 + serde_json_core::from_slice::(message.payload())?.0 } }; - self.ownership_token - .replace(heapless::String::from(response.certificate_ownership_token)); + error!("{:?}", response); - Ok(Some(Response::Credentials(Credentials { - certificate_id: response.certificate_id, - certificate_pem: response.certificate_pem, - private_key: None, - }))) + Err(Error::Response(response.status_code)) + } + + t => { + trace!("{:?}", t); + + Err(Error::InvalidState) } + } + } + + pub async fn register_thing( + mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, + template_name: &str, + payload_format: PayloadFormat, + certificate_ownership_token: &str, + parameters: Option
<P>
, + ) -> Result, Error> { + let topic_paths = topics::Subscribe::<2>::new() + .topic( + Topic::RegisterThingAccepted(template_name, payload_format), + QoS::AtLeastOnce, + ) + .topic( + Topic::RegisterThingRejected(template_name, payload_format), + QoS::AtLeastOnce, + ) + .topics::<128>()?; + + let subscribe_topics = topic_paths + .iter() + .map(|(s, qos)| SubscribeTopic { + topic_path: s.as_str(), + maximum_qos: *qos, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }) + .collect::>(); + + let mut subscription = mqtt + .subscribe::<2>(Subscribe { + pid: None, + properties: embedded_mqtt::Properties::Slice(&[]), + topics: subscribe_topics.as_slice(), + }) + .await + .map_err(|_| Error::Mqtt)?; + + let register_request = RegisterThingRequest { + certificate_ownership_token: &certificate_ownership_token, + parameters, + }; + + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let payload = &mut [0u8; 1024]; + + let payload_len = match payload_format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => { + let mut serializer = + serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload)); + register_request.serialize(&mut serializer)?; + serializer.into_inner().bytes_written() + } + PayloadFormat::Json => serde_json_core::to_slice(®ister_request, payload)?, + }; + + mqtt.publish(Publish { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + pid: None, + topic_name: Topic::RegisterThing(template_name, payload_format) + .format::<69>()? + .as_str(), + payload: &payload[..payload_len], + properties: embedded_mqtt::Properties::Slice(&[]), + }) + .await + .map_err(|_| Error::Mqtt)?; + + let mut message = subscription.next().await.ok_or(Error::InvalidState)?; + + match Topic::from_str(message.topic_name()) { Some(Topic::RegisterThingAccepted(_, format)) => { trace!("Topic::RegisterThingAccepted {:?}", format); let response = match format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::>(payload)? + serde_cbor::de::from_mut_slice::>(payload)? } PayloadFormat::Json => { - serde_json_core::from_slice::>(payload)?.0 + serde_json_core::from_slice::>(payload)?.0 } }; - assert_eq!(response.thing_name, self.mqtt.client_id()); - - Ok(Some(Response::DeviceConfiguration( - response.device_configuration, - ))) + Ok(response.device_configuration) } // Error happened! - Some( - Topic::CreateKeysAndCertificateRejected(format) - | Topic::CreateCertificateFromCsrRejected(format) - | Topic::RegisterThingRejected(_, format), - ) => { + Some(Topic::RegisterThingRejected(_, format)) => { + error!(">> {:?}", message.topic_name()); + let response = match format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(payload)? + serde_cbor::de::from_mut_slice::(message.payload_mut())? 
+ } + PayloadFormat::Json => { + serde_json_core::from_slice::(message.payload())?.0 } - PayloadFormat::Json => serde_json_core::from_slice::(payload)?.0, }; - error!("{:?}: {:?}", topic_name, response); + error!("{:?}", response); Err(Error::Response(response.status_code)) } t => { trace!("{:?}", t); - Ok(None) + + Err(Error::InvalidState) } } } } - -impl<'a, M> Drop for FleetProvisioner<'a, M> -where - M: Mqtt, -{ - fn drop(&mut self) { - Unsubscribe::<4>::new() - .topic(Topic::CreateKeysAndCertificateAccepted(self.payload_format)) - .topic(Topic::CreateKeysAndCertificateRejected(self.payload_format)) - .topic(Topic::RegisterThingAccepted( - self.template_name, - self.payload_format, - )) - .topic(Topic::RegisterThingRejected( - self.template_name, - self.payload_format, - )) - .send(self.mqtt) - .ok(); - } -} diff --git a/src/provisioning/topics.rs b/src/provisioning/topics.rs index 9ced90a..c94a5d9 100644 --- a/src/provisioning/topics.rs +++ b/src/provisioning/topics.rs @@ -2,8 +2,8 @@ use core::fmt::Display; use core::fmt::Write; use core::str::FromStr; +use embedded_mqtt::QoS; use heapless::String; -use mqttrust::{Mqtt, QoS, SubscribeTopic}; use super::Error; @@ -243,81 +243,63 @@ impl<'a, const N: usize> Subscribe<'a, N> { Self { topics } } - pub fn topics(self) -> Result, QoS), N>, Error> { - self.topics - .iter() - .map(|(topic, qos)| Ok((topic.clone().format()?, *qos))) + pub fn topics( + self, + ) -> Result, QoS), N>, Error> { + self.iter() + .map(|(topic, qos)| Ok((topic.format()?, *qos))) .collect() } - pub fn send(self, mqtt: &M) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics()?; - - debug!("Subscribing! {:?}", topic_paths); - - let topics: heapless::Vec<_, N> = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - qos: *qos, - }) - .collect(); - - for t in topics.chunks(5) { - mqtt.subscribe(t)?; - } - Ok(()) + pub fn iter(&self) -> impl Iterator, QoS)> { + self.topics.iter() } } -#[derive(Default)] -pub struct Unsubscribe<'a, const N: usize> { - topics: heapless::Vec, N>, -} - -impl<'a, const N: usize> Unsubscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>) -> Self { - // Ignore attempts to subscribe to outgoing topics - if topic.direction() != Direction::Incoming { - return self; - } - - if self.topics.iter().any(|t| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push(topic).ok(); - Self { topics } - } - - pub fn topics(self) -> Result, N>, Error> { - self.topics - .iter() - .map(|topic| topic.clone().format()) - .collect() - } - - pub fn send(self, mqtt: &M) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics()?; - let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - - for t in topics.chunks(5) { - mqtt.unsubscribe(t)?; - } - - Ok(()) - } -} +// #[derive(Default)] +// pub struct Unsubscribe<'a, const N: usize> { +// topics: heapless::Vec, N>, +// } + +// impl<'a, const N: usize> Unsubscribe<'a, N> { +// pub fn new() -> Self { +// Self::default() +// } + +// pub fn topic(self, topic: Topic<'a>) -> Self { +// // Ignore attempts to subscribe to outgoing topics +// if topic.direction() != Direction::Incoming { +// return self; +// } + +// if self.topics.iter().any(|t| t == &topic) { +// return self; +// } + +// let mut topics = self.topics; +// topics.push(topic).ok(); +// Self { topics } +// } + +// pub fn 
topics(self) -> Result, N>, Error> { +// self.topics +// .iter() +// .map(|topic| topic.clone().format()) +// .collect() +// } + +// // pub fn send(self, mqtt: &M) -> Result<(), Error> { +// // if self.topics.is_empty() { +// // return Ok(()); +// // } + +// // let topic_paths = self.topics()?; +// // let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); + +// // for t in topics.chunks(5) { +// // // mqtt.unsubscribe(t)?; +// // } + +// // Ok(()) +// // } +// } diff --git a/src/test/mod.rs b/src/test/mod.rs index e27c028..d601e8d 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,40 +1,40 @@ -use std::{cell::RefCell, collections::VecDeque}; - -use mqttrust::{encoding::v4::encode_slice, Mqtt, MqttError, Packet}; - -/// -/// Mock Mqtt client used for unit tests. Implements `mqttrust::Mqtt` trait. -/// -pub struct MockMqtt { - pub tx: RefCell>>, - publish_fail: bool, -} - -impl MockMqtt { - pub fn new() -> Self { - Self { - tx: RefCell::new(VecDeque::new()), - publish_fail: false, - } - } - - pub fn publish_fail(&mut self) { - self.publish_fail = true; - } -} - -impl Mqtt for MockMqtt { - fn send(&self, packet: Packet<'_>) -> Result<(), MqttError> { - let v = &mut [0u8; 1024]; - - let len = encode_slice(&packet, v).map_err(|_| MqttError::Full)?; - let packet = v[..len].iter().cloned().collect(); - self.tx.borrow_mut().push_back(packet); - - Ok(()) - } - - fn client_id(&self) -> &str { - "test_client" - } -} +// use std::{cell::RefCell, collections::VecDeque}; + +// use mqttrust::{encoding::v4::encode_slice, Mqtt, MqttError, Packet}; + +// /// +// /// Mock Mqtt client used for unit tests. Implements `mqttrust::Mqtt` trait. +// /// +// pub struct MockMqtt { +// pub tx: RefCell>>, +// publish_fail: bool, +// } + +// impl MockMqtt { +// pub fn new() -> Self { +// Self { +// tx: RefCell::new(VecDeque::new()), +// publish_fail: false, +// } +// } + +// pub fn publish_fail(&mut self) { +// self.publish_fail = true; +// } +// } + +// impl Mqtt for MockMqtt { +// fn send(&self, packet: Packet<'_>) -> Result<(), MqttError> { +// let v = &mut [0u8; 1024]; + +// let len = encode_slice(&packet, v).map_err(|_| MqttError::Full)?; +// let packet = v[..len].iter().cloned().collect(); +// self.tx.borrow_mut().push_back(packet); + +// Ok(()) +// } + +// fn client_id(&self) -> &str { +// "test_client" +// } +// } From f64903a23c39682250af25322145993c37b13682 Mon Sep 17 00:00:00 2001 From: Mathias Date: Sat, 6 Jan 2024 13:41:14 +0100 Subject: [PATCH 02/36] Working async provisioning based on embedded-mqtt with integration test --- Cargo.toml | 24 +- src/lib.rs | 3 - src/provisioning/data_types.rs | 26 +- src/provisioning/mod.rs | 127 ++-- src/test/mod.rs | 40 -- tests/common/clock.rs | 55 -- tests/common/mod.rs | 3 +- tests/common/network.rs | 331 ++++------- tests/ota.rs | 470 +++++++-------- tests/provisioning.rs | 208 +++---- tests/shadows.rs | 1004 ++++++++++++++++---------------- 11 files changed, 1050 insertions(+), 1241 deletions(-) delete mode 100644 src/test/mod.rs delete mode 100644 tests/common/clock.rs diff --git a/Cargo.toml b/Cargo.toml index 046a4ab..4a4ce18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,16 +39,21 @@ defmt = { version = "^0.3", optional = true } [dev-dependencies] native-tls = { version = "^0.2" } -embedded-nal = "0.6.0" -no-std-net = { version = "^0.5", features = ["serde"] } -dns-lookup = "1.0.3" -env_logger = "0.9.0" +embedded-nal-async = "0.7" +env_logger = "0.10" sha2 = "0.10.1" -ecdsa = { version = "0.13.4", features = ["pkcs8"] } -p256 = 
"0.10.1" -pkcs8 = { version = "0.8", features = ["encryption", "pem"] } +static_cell = { version = "2", features = ["nightly"]} +tokio = { version = "1.33", default-features = false, features = ["macros", "rt", "net", "time", "io-std"] } +tokio-native-tls = { version = "0.3.1" } +embassy-futures = { version = "0.1.0" } +embassy-time = { version = "0.2", features = ["log", "std", "generic-queue"] } +embedded-io-adapters = { version = "0.6.0", features = ["tokio-1"] } + +ecdsa = { version = "0.16", features = ["pkcs8", "pem"] } +p256 = "0.13" +pkcs8 = { version = "0.10", features = ["encryption", "pem"] } timebomb = "0.1.2" -hex = "0.4.3" +hex = { version = "0.4.3", features = ["alloc"] } [features] default = ["ota_mqtt_data", "provision_cbor"] @@ -60,4 +65,5 @@ ota_http_data = [] std = ["serde/std", "serde_cbor?/std"] -defmt = ["dep:defmt", "heapless/defmt-03"] +defmt = ["dep:defmt", "heapless/defmt-03", "embedded-mqtt/defmt"] +log = ["dep:log", "embedded-mqtt/log", ] diff --git a/src/lib.rs b/src/lib.rs index 4b9d248..1b9acd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,3 @@ pub mod provisioning; // pub mod shadows; pub use serde_cbor; - -#[cfg(test)] -pub mod test; diff --git a/src/provisioning/data_types.rs b/src/provisioning/data_types.rs index 3d929eb..0b17241 100644 --- a/src/provisioning/data_types.rs +++ b/src/provisioning/data_types.rs @@ -196,18 +196,18 @@ mod tests { } ); - // FIXME - let register_response_none = r#"{"thingName":"my_thing","deviceConfiguration":{}}"#; - - let (response, _) = - serde_json_core::from_str::>(®ister_response_none) - .unwrap(); - assert_eq!( - response, - RegisterThingResponse { - thing_name: "my_thing", - device_configuration: None, - } - ); + // // FIXME + // let register_response_none = r#"{"thingName":"my_thing","deviceConfiguration":{}}"#; + + // let (response, _) = + // serde_json_core::from_str::>(®ister_response_none) + // .unwrap(); + // assert_eq!( + // response, + // RegisterThingResponse { + // thing_name: "my_thing", + // device_configuration: None, + // } + // ); } } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index b2bd9c1..df79353 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -8,18 +8,25 @@ use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embedded_mqtt::{Publish, QoS, RetainHandling, Subscribe, SubscribeTopic}; use futures::StreamExt; use serde::de::DeserializeOwned; -#[cfg(feature = "provision_cbor")] use serde::Serialize; +pub use error::Error; + use self::{ data_types::{ CreateCertificateFromCsrResponse, CreateKeysAndCertificateResponse, ErrorResponse, RegisterThingRequest, RegisterThingResponse, }, - error::Error, topics::{PayloadFormat, Topic}, }; +pub trait CredentialHandler { + fn store_credentials( + &mut self, + credentials: Credentials<'_>, + ) -> impl Future> + Send; +} + #[derive(Debug)] pub struct Credentials<'a> { pub certificate_id: &'a str, @@ -30,53 +37,59 @@ pub struct Credentials<'a> { pub struct FleetProvisioner; impl FleetProvisioner { - /// Instantiate a new `FleetProvisioner`, using `template_name` for the provisioning - pub async fn provision<'a, F, Fut, P, C>( + pub async fn provision<'a, C>( mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, - template_name: &'a str, - parameters: Option
<P>
, - f: F, + template_name: &str, + parameters: Option, + credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where - F: FnOnce(Credentials<'_>) -> Fut, - Fut: Future>, - P: Serialize, C: DeserializeOwned, { - Self::provision_inner(mqtt, template_name, parameters, f, PayloadFormat::Json).await + Self::provision_inner( + mqtt, + template_name, + parameters, + credential_handler, + PayloadFormat::Json, + ) + .await } #[cfg(feature = "provision_cbor")] - pub async fn provision_cbor<'a, F, Fut, P, C>( + pub async fn provision_cbor<'a, C>( mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, - template_name: &'a str, - parameters: Option
<P>
, - f: F, + template_name: &str, + parameters: Option, + credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where - F: FnOnce(Credentials<'_>) -> Fut, - Fut: Future>, - P: Serialize, C: DeserializeOwned, { - Self::provision_inner(mqtt, template_name, parameters, f, PayloadFormat::Cbor).await + Self::provision_inner( + mqtt, + template_name, + parameters, + credential_handler, + PayloadFormat::Cbor, + ) + .await } - async fn provision_inner<'a, F, Fut, P, C>( + async fn provision_inner<'a, C, P, CH>( mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, - template_name: &'a str, + template_name: &str, parameters: Option
<P>
, - f: F, + credential_handler: &mut CH, payload_format: PayloadFormat, ) -> Result, Error> where - F: FnOnce(Credentials<'_>) -> Fut, - Fut: Future>, - P: Serialize, C: DeserializeOwned, + P: Serialize, + CH: CredentialHandler, { let certificate_ownership_token = - Self::create_keys_and_certificates(mqtt, payload_format, f).await?; + Self::create_keys_and_certificates(mqtt, payload_format, credential_handler).await?; Self::register_thing( mqtt, @@ -88,15 +101,17 @@ impl FleetProvisioner { .await } - pub async fn create_keys_and_certificates( + pub async fn create_keys_and_certificates( mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, payload_format: PayloadFormat, - f: F, + credential_handler: &mut CH, ) -> Result, Error> where - F: FnOnce(Credentials<'_>) -> Fut, - Fut: Future>, + CH: CredentialHandler, { + // FIXME: Changing these to a single topic filter of + // `$aws/certificates/create//+` could be beneficial to + // stack usage let topic_paths = topics::Subscribe::<2>::new() .topic( Topic::CreateKeysAndCertificateAccepted(payload_format), @@ -165,12 +180,13 @@ impl FleetProvisioner { } }; - f(Credentials { - certificate_id: response.certificate_id, - certificate_pem: response.certificate_pem, - private_key: Some(response.private_key), - }) - .await?; + credential_handler + .store_credentials(Credentials { + certificate_id: response.certificate_id, + certificate_pem: response.certificate_pem, + private_key: Some(response.private_key), + }) + .await?; Ok(heapless::String::try_from(response.certificate_ownership_token).unwrap()) } @@ -202,15 +218,17 @@ impl FleetProvisioner { } } - pub async fn create_certificate_from_csr( + pub async fn create_certificate_from_csr( mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, payload_format: PayloadFormat, - f: F, + credential_handler: &mut CH, ) -> Result, Error> where - F: FnOnce(Credentials<'_>) -> Fut, - Fut: Future>, + CH: CredentialHandler, { + // FIXME: Changing these to a single topic filter of + // `$aws/certificates/create-from-csr//+` could be beneficial to + // stack usage let topic_paths = topics::Subscribe::<2>::new() .topic( Topic::CreateCertificateFromCsrAccepted(payload_format), @@ -279,12 +297,13 @@ impl FleetProvisioner { } }; - f(Credentials { - certificate_id: response.certificate_id, - certificate_pem: response.certificate_pem, - private_key: None, - }) - .await?; + credential_handler + .store_credentials(Credentials { + certificate_id: response.certificate_id, + certificate_pem: response.certificate_pem, + private_key: None, + }) + .await?; // FIXME: It should be possible to re-arrange stuff to get rid of the need for this 512 byte stack alloc Ok(heapless::String::try_from(response.certificate_ownership_token).unwrap()) @@ -324,6 +343,9 @@ impl FleetProvisioner { certificate_ownership_token: &str, parameters: Option
<P>
, ) -> Result, Error> { + // FIXME: Changing these to a single topic filter of + // `$aws/provisioning-templates//provision//+` + // could be beneficial to stack usage let topic_paths = topics::Subscribe::<2>::new() .topic( Topic::RegisterThingAccepted(template_name, payload_format), @@ -374,6 +396,8 @@ impl FleetProvisioner { PayloadFormat::Json => serde_json_core::to_slice(®ister_request, payload)?, }; + info!("Starting RegisterThing {:?}", payload_len); + mqtt.publish(Publish { dup: false, qos: QoS::AtLeastOnce, @@ -396,11 +420,14 @@ impl FleetProvisioner { let response = match format { #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::>(payload)? - } + PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::< + RegisterThingResponse<'_, C>, + >(message.payload_mut())?, PayloadFormat::Json => { - serde_json_core::from_slice::>(payload)?.0 + serde_json_core::from_slice::>( + message.payload(), + )? + .0 } }; diff --git a/src/test/mod.rs b/src/test/mod.rs deleted file mode 100644 index d601e8d..0000000 --- a/src/test/mod.rs +++ /dev/null @@ -1,40 +0,0 @@ -// use std::{cell::RefCell, collections::VecDeque}; - -// use mqttrust::{encoding::v4::encode_slice, Mqtt, MqttError, Packet}; - -// /// -// /// Mock Mqtt client used for unit tests. Implements `mqttrust::Mqtt` trait. -// /// -// pub struct MockMqtt { -// pub tx: RefCell>>, -// publish_fail: bool, -// } - -// impl MockMqtt { -// pub fn new() -> Self { -// Self { -// tx: RefCell::new(VecDeque::new()), -// publish_fail: false, -// } -// } - -// pub fn publish_fail(&mut self) { -// self.publish_fail = true; -// } -// } - -// impl Mqtt for MockMqtt { -// fn send(&self, packet: Packet<'_>) -> Result<(), MqttError> { -// let v = &mut [0u8; 1024]; - -// let len = encode_slice(&packet, v).map_err(|_| MqttError::Full)?; -// let packet = v[..len].iter().cloned().collect(); -// self.tx.borrow_mut().push_back(packet); - -// Ok(()) -// } - -// fn client_id(&self) -> &str { -// "test_client" -// } -// } diff --git a/tests/common/clock.rs b/tests/common/clock.rs deleted file mode 100644 index c486cc1..0000000 --- a/tests/common/clock.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::time::{SystemTime, UNIX_EPOCH}; - -pub struct SysClock { - start_time: u32, - end_time: Option>, -} - -impl SysClock { - pub fn new() -> Self { - Self { - start_time: Self::epoch(), - end_time: None, - } - } - - pub fn epoch() -> u32 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis() as u32 - } - - pub fn now(&self) -> u32 { - Self::epoch() - self.start_time - } -} - -impl fugit_timer::Timer<1000> for SysClock { - type Error = std::convert::Infallible; - - fn now(&mut self) -> fugit_timer::TimerInstantU32<1000> { - fugit_timer::TimerInstantU32::from_ticks(SysClock::now(self)) - } - - fn start(&mut self, duration: fugit_timer::TimerDurationU32<1000>) -> Result<(), Self::Error> { - let now = self.now(); - self.end_time.replace(now + duration); - Ok(()) - } - - fn cancel(&mut self) -> Result<(), Self::Error> { - self.end_time.take(); - Ok(()) - } - - fn wait(&mut self) -> nb::Result<(), Self::Error> { - match self.end_time.map(|end| end <= self.now()) { - Some(true) => { - self.end_time.take(); - Ok(()) - } - _ => Err(nb::Error::WouldBlock), - } - } -} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index cd087b0..594f1b0 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,4 +1,3 @@ -pub mod clock; pub mod credentials; -pub mod file_handler; +// pub mod file_handler; pub 
mod network; diff --git a/tests/common/network.rs b/tests/common/network.rs index 968a093..dfbe27c 100644 --- a/tests/common/network.rs +++ b/tests/common/network.rs @@ -1,252 +1,143 @@ -use embedded_nal::{AddrType, Dns, IpAddr, SocketAddr, TcpClientStack}; -use native_tls::{MidHandshakeTlsStream, TlsConnector, TlsStream}; -use std::io::{Read, Write}; -use std::marker::PhantomData; -use std::net::TcpStream; - -use dns_lookup::{lookup_addr, lookup_host}; - -/// An std::io::Error compatible error type returned when an operation is requested in the wrong -/// sequence (where the "right" is create a socket, connect, any receive/send, and possibly close). -#[derive(Debug)] -struct OutOfOrder; - -impl std::fmt::Display for OutOfOrder { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Out of order operations requested") - } -} +use std::net::SocketAddr; -impl std::error::Error for OutOfOrder {} - -impl Into> for OutOfOrder { - fn into(self) -> std::io::Result { - Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - OutOfOrder, - )) - } -} +use ::native_tls::Identity; +use embedded_io_adapters::tokio_1::FromTokio; +use embedded_nal_async::{AddrType, Dns, IpAddr, Ipv4Addr, Ipv6Addr, TcpConnect}; +use tokio_native_tls::native_tls; -pub struct Network { - tls_connector: Option<(TlsConnector, String)>, - _sec: PhantomData, -} +use super::credentials; -impl Network> { - pub fn new_tls(tls_connector: TlsConnector, hostname: String) -> Self { - Self { - tls_connector: Some((tls_connector, hostname)), - _sec: PhantomData, - } - } -} +#[derive(Debug, Clone, Copy)] +pub struct Network; -impl Network { - pub fn new() -> Self { - Self { - tls_connector: None, - _sec: PhantomData, - } +impl Network { + pub const fn new() -> Self { + Self } } -pub(crate) fn to_nb(e: std::io::Error) -> nb::Error { - use std::io::ErrorKind::{TimedOut, WouldBlock}; - match e.kind() { - WouldBlock | TimedOut => nb::Error::WouldBlock, - _ => e.into(), - } -} - -pub enum TlsState { - MidHandshake(MidHandshakeTlsStream), - Connected(T), -} - -pub struct TcpSocket { - pub stream: Option>, -} - -impl TcpSocket { - pub fn new() -> Self { - TcpSocket { stream: None } - } - - pub fn get_running(&mut self) -> std::io::Result<&mut T> { - match self.stream { - Some(TlsState::Connected(ref mut s)) => Ok(s), - _ => OutOfOrder.into(), - } - } -} +impl TcpConnect for Network { + type Error = std::io::Error; -impl Dns for Network { - type Error = (); + type Connection<'a> = FromTokio + where + Self: 'a; - fn get_host_by_address( - &mut self, - ip_addr: IpAddr, - ) -> nb::Result, Self::Error> { - let ip: std::net::IpAddr = format!("{}", ip_addr).parse().unwrap(); - let host = lookup_addr(&ip).unwrap(); - Ok(heapless::String::from(host.as_str())) - } - fn get_host_by_name( - &mut self, - hostname: &str, - _addr_type: AddrType, - ) -> nb::Result { - let ips: Vec = lookup_host(hostname).unwrap(); - let ip = ips - .iter() - .find(|s| matches!(s, std::net::IpAddr::V4(_))) - .unwrap(); - format!("{}", ip).parse().map_err(|_| nb::Error::Other(())) + async fn connect<'a>( + &'a self, + remote: embedded_nal_async::SocketAddr, + ) -> Result, Self::Error> { + let stream = tokio::net::TcpStream::connect(format!("{}", remote)).await?; + Ok(FromTokio::new(stream)) } } -impl TcpClientStack for Network> { +impl Dns for Network { type Error = std::io::Error; - type TcpSocket = TcpSocket>; - fn socket(&mut self) -> Result { - Ok(TcpSocket::new()) - } - - fn receive( - &mut self, - network: &mut Self::TcpSocket, - buf: &mut 
[u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.read(buf).map_err(to_nb) - } - - fn send( - &mut self, - network: &mut Self::TcpSocket, - buf: &[u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.write(buf).map_err(|e| { - if !matches!( - e.kind(), - std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut - ) { - log::error!("{:?}", e); + async fn get_host_by_name( + &self, + host: &str, + addr_type: AddrType, + ) -> Result { + for ip in tokio::net::lookup_host(host).await? { + match (&addr_type, ip) { + (AddrType::IPv4 | AddrType::Either, SocketAddr::V4(ip)) => { + return Ok(IpAddr::V4(Ipv4Addr::from(ip.ip().octets()))) + } + (AddrType::IPv6 | AddrType::Either, SocketAddr::V6(ip)) => { + return Ok(IpAddr::V6(Ipv6Addr::from(ip.ip().octets()))) + } + (_, _) => {} } - to_nb(e) - }) + } + Err(std::io::Error::new( + std::io::ErrorKind::AddrNotAvailable, + "", + )) } - fn is_connected(&mut self, network: &Self::TcpSocket) -> Result { - Ok(matches!(network.stream, Some(TlsState::Connected(_)))) + async fn get_host_by_address( + &self, + _addr: IpAddr, + _result: &mut [u8], + ) -> Result { + unimplemented!() } +} - fn connect( - &mut self, - network: &mut Self::TcpSocket, - remote: SocketAddr, - ) -> nb::Result<(), Self::Error> { - let tls_stream = match network.stream.take() { - None => { - let soc = TcpStream::connect(remote.to_string())?; - soc.set_nonblocking(true)?; - - let (connector, hostname) = self.tls_connector.as_ref().unwrap(); - - let mut tls_stream = connector.connect(hostname, soc).map_err(|e| match e { - native_tls::HandshakeError::Failure(_) => nb::Error::Other( - std::io::Error::new(std::io::ErrorKind::Other, "Failed TLS handshake"), - ), - native_tls::HandshakeError::WouldBlock(h) => { - network.stream.replace(TlsState::MidHandshake(h)); - nb::Error::WouldBlock - } - })?; - tls_stream.get_mut().set_nonblocking(true)?; - tls_stream - } - Some(TlsState::MidHandshake(h)) => { - let mut tls_stream = h.handshake().map_err(|e| match e { - native_tls::HandshakeError::Failure(_) => nb::Error::Other( - std::io::Error::new(std::io::ErrorKind::Other, "Failed TLS handshake"), - ), - native_tls::HandshakeError::WouldBlock(h) => { - network.stream.replace(TlsState::MidHandshake(h)); - nb::Error::WouldBlock - } - })?; - tls_stream.get_mut().set_nonblocking(true)?; - tls_stream - } - Some(TlsState::Connected(_)) => return Ok(()), - }; - - network.stream.replace(TlsState::Connected(tls_stream)); - - Ok(()) - } +pub struct TlsNetwork { + identity: Identity, + domain: String, +} - fn close(&mut self, _network: Self::TcpSocket) -> Result<(), Self::Error> { - // No-op: Socket gets closed when it is freed - // - // Could wrap it in an Option, but really that'll only make things messier; users will - // probably drop the socket anyway after closing, and can't expect it to be usable with - // this API. 
- Ok(()) +impl TlsNetwork { + pub const fn new(domain: String, identity: Identity) -> Self { + Self { identity, domain } } } -impl TcpClientStack for Network { +impl TcpConnect for TlsNetwork { type Error = std::io::Error; - type TcpSocket = TcpSocket; - fn socket(&mut self) -> Result { - Ok(TcpSocket::new()) - } - - fn receive( - &mut self, - network: &mut Self::TcpSocket, - buf: &mut [u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.read(buf).map_err(to_nb) - } - - fn send( - &mut self, - network: &mut Self::TcpSocket, - buf: &[u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.write(buf).map_err(to_nb) - } - - fn is_connected(&mut self, network: &Self::TcpSocket) -> Result { - Ok(matches!(network.stream, Some(TlsState::Connected(_)))) + type Connection<'a> = FromTokio> + where + Self: 'a; + + async fn connect<'a>( + &'a self, + remote: embedded_nal_async::SocketAddr, + ) -> Result, Self::Error> { + log::info!("Connecting to {:?}", remote); + let connector = tokio_native_tls::TlsConnector::from( + native_tls::TlsConnector::builder() + .identity(self.identity.clone()) + .add_root_certificate(credentials::root_ca()) + .build() + .unwrap(), + ); + let stream = tokio::net::TcpStream::connect(format!("{}", remote)).await?; + let tls_stream = connector + .connect(self.domain.as_str(), stream) + .await + .unwrap(); + Ok(FromTokio::new(tls_stream)) } +} - fn connect( - &mut self, - network: &mut Self::TcpSocket, - remote: SocketAddr, - ) -> nb::Result<(), Self::Error> { - let soc = TcpStream::connect(format!("{}", remote))?; - soc.set_nonblocking(true)?; - network.stream.replace(TlsState::Connected(soc)); +impl Dns for TlsNetwork { + type Error = std::io::Error; - Ok(()) + async fn get_host_by_name( + &self, + host: &str, + addr_type: AddrType, + ) -> Result { + log::info!("Looking up {}", host); + for ip in tokio::net::lookup_host(host).await? { + log::info!("Found IP {}", ip); + + match (&addr_type, ip) { + (AddrType::IPv4 | AddrType::Either, SocketAddr::V4(ip)) => { + return Ok(IpAddr::V4(Ipv4Addr::from(ip.ip().octets()))) + } + (AddrType::IPv6 | AddrType::Either, SocketAddr::V6(ip)) => { + return Ok(IpAddr::V6(Ipv6Addr::from(ip.ip().octets()))) + } + (_, _) => {} + } + } + Err(std::io::Error::new( + std::io::ErrorKind::AddrNotAvailable, + "", + )) } - fn close(&mut self, _network: Self::TcpSocket) -> Result<(), Self::Error> { - // No-op: Socket gets closed when it is freed - // - // Could wrap it in an Option, but really that'll only make things messier; users will - // probably drop the socket anyway after closing, and can't expect it to be usable with - // this API. 
- Ok(()) + async fn get_host_by_address( + &self, + _addr: IpAddr, + _result: &mut [u8], + ) -> Result { + unimplemented!() } } diff --git a/tests/ota.rs b/tests/ota.rs index 9e459e6..e645b99 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -1,235 +1,235 @@ -mod common; - -use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification, PublishNotification}; -use native_tls::TlsConnector; -use rustot::ota::state::States; -use serde::Deserialize; -use sha2::{Digest, Sha256}; -use std::{fs::File, io::Read, ops::Deref}; - -use common::{clock::SysClock, credentials, file_handler::FileHandler, network::Network}; -use rustot::{ - jobs::{ - self, - data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, - StatusDetails, - }, - ota::{self, agent::OtaAgent, encoding::json::OtaJob}, -}; - -static mut Q: BBBuffer<{ 1024 * 10 }> = BBBuffer::new(); - -#[derive(Debug, Deserialize)] -pub enum Jobs<'a> { - #[serde(rename = "afr_ota")] - #[serde(borrow)] - Ota(OtaJob<'a>), -} - -impl<'a> Jobs<'a> { - pub fn ota_job(self) -> Option> { - match self { - Jobs::Ota(ota_job) => Some(ota_job), - } - } -} - -enum OtaUpdate<'a> { - JobUpdate(&'a str, OtaJob<'a>, Option>), - Data(&'a mut [u8]), -} - -fn handle_ota<'a>(publish: &'a mut PublishNotification) -> Result, ()> { - match jobs::Topic::from_str(publish.topic_name.as_str()) { - Some(jobs::Topic::NotifyNext) => { - let (execution_changed, _) = - serde_json_core::from_slice::>(&publish.payload) - .map_err(drop)?; - let job = execution_changed.execution.ok_or(())?; - let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; - return Ok(OtaUpdate::JobUpdate( - job.job_id, - ota_job, - job.status_details, - )); - } - Some(jobs::Topic::DescribeAccepted(_)) => { - let (execution_changed, _) = - serde_json_core::from_slice::>(&publish.payload) - .map_err(drop)?; - let job = execution_changed.execution.ok_or(())?; - let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; - return Ok(OtaUpdate::JobUpdate( - job.job_id, - ota_job, - job.status_details, - )); - } - _ => {} - } - - match ota::Topic::from_str(publish.topic_name.as_str()) { - Some(ota::Topic::Data(_, _)) => { - return Ok(OtaUpdate::Data(&mut publish.payload)); - } - _ => {} - } - Err(()) -} - -pub struct FileInfo { - pub file_path: String, - pub filesize: usize, - pub signature: ota::encoding::json::Signature, -} - -#[test] -fn test_mqtt_ota() { - // Make sure this times out in case something went wrong setting up the OTA - // job in AWS IoT before starting. 
- timebomb::timeout_ms(test_mqtt_ota_inner, 100_000) -} - -fn test_mqtt_ota_inner() { - env_logger::init(); - - let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - - log::info!("Starting OTA test..."); - - let hostname = credentials::HOSTNAME.unwrap(); - let (thing_name, identity) = credentials::identity(); - - let connector = TlsConnector::builder() - .identity(identity) - .add_root_certificate(credentials::root_ca()) - .build() - .unwrap(); - - let mut network = Network::new_tls(connector, String::from(hostname)); - - let mut mqtt_eventloop = EventLoop::new( - c, - SysClock::new(), - MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), - ); - - let mqtt_client = mqttrust_core::Client::new(p, thing_name); - - let file_handler = FileHandler::new(); - - let mut ota_agent = - OtaAgent::builder(&mqtt_client, &mqtt_client, SysClock::new(), file_handler) - .request_wait_ms(3000) - .block_size(256) - .build(); - - let mut file_info = None; - - loop { - match mqtt_eventloop.connect(&mut network) { - Ok(true) => { - log::info!("Successfully connected to broker"); - ota_agent.init(); - } - Ok(false) => {} - Err(nb::Error::WouldBlock) => continue, - Err(e) => panic!("{:?}", e), - } - - match mqtt_eventloop.yield_event(&mut network) { - Ok(Notification::Publish(mut publish)) => { - // Check if the received file is a jobs topic, that we - // want to react to. - match handle_ota(&mut publish) { - Ok(OtaUpdate::JobUpdate(job_id, job_doc, status_details)) => { - log::debug!("Received job! Starting OTA! {:?}", job_doc.streamname); - - let file = &job_doc.files[0]; - file_info.replace(FileInfo { - file_path: file.filepath.to_string(), - filesize: file.filesize, - signature: file.signature(), - }); - ota_agent - .job_update(job_id, &job_doc, status_details.as_ref()) - .expect("Failed to start OTA job"); - } - Ok(OtaUpdate::Data(payload)) => { - if ota_agent.handle_message(payload).is_err() { - match ota_agent.state() { - States::CreatingFile => log::info!("State: CreatingFile"), - States::Ready => log::info!("State: Ready"), - States::RequestingFileBlock => { - log::info!("State: RequestingFileBlock") - } - States::RequestingJob => log::info!("State: RequestingJob"), - States::Restarting => log::info!("State: Restarting"), - States::Suspended => log::info!("State: Suspended"), - States::WaitingForFileBlock => { - log::info!("State: WaitingForFileBlock") - } - States::WaitingForJob => log::info!("State: WaitingForJob"), - } - } - } - Err(_) => {} - } - } - Ok(n) => { - log::trace!("{:?}", n); - } - _ => {} - } - - ota_agent.timer_callback().expect("Failed timer callback!"); - - match ota_agent.process_event() { - // Use the restarting state to indicate finished - Ok(States::Restarting) => break, - _ => {} - } - } - - let mut expected_file = File::open("tests/assets/ota_file").unwrap(); - let mut expected_data = Vec::new(); - expected_file.read_to_end(&mut expected_data).unwrap(); - let mut expected_hasher = Sha256::new(); - expected_hasher.update(&expected_data); - let expected_hash = expected_hasher.finalize(); - - let file_info = file_info.unwrap(); - - log::info!( - "Comparing {:?} with {:?}", - "tests/assets/ota_file", - file_info.file_path - ); - let mut file = File::open(file_info.file_path.clone()).unwrap(); - let mut data = Vec::new(); - file.read_to_end(&mut data).unwrap(); - drop(file); - std::fs::remove_file(file_info.file_path).unwrap(); - - assert_eq!(data.len(), file_info.filesize); - - let mut hasher = Sha256::new(); - hasher.update(&data); - 
assert_eq!(hasher.finalize().deref(), expected_hash.deref()); - - // Check file signature - match file_info.signature { - ota::encoding::json::Signature::Sha1Rsa(_) => { - panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Rsa(_) => { - panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha1Ecdsa(_) => { - panic!("Unexpected signature format: Sha1Ecdsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Ecdsa(sig) => { - assert_eq!(&sig, "This is my custom signature\\n") - } - } -} +// mod common; + +// use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification, PublishNotification}; +// use native_tls::TlsConnector; +// use rustot::ota::state::States; +// use serde::Deserialize; +// use sha2::{Digest, Sha256}; +// use std::{fs::File, io::Read, ops::Deref}; + +// use common::{clock::SysClock, credentials, file_handler::FileHandler, network::Network}; +// use rustot::{ +// jobs::{ +// self, +// data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, +// StatusDetails, +// }, +// ota::{self, agent::OtaAgent, encoding::json::OtaJob}, +// }; + +// static mut Q: BBBuffer<{ 1024 * 10 }> = BBBuffer::new(); + +// #[derive(Debug, Deserialize)] +// pub enum Jobs<'a> { +// #[serde(rename = "afr_ota")] +// #[serde(borrow)] +// Ota(OtaJob<'a>), +// } + +// impl<'a> Jobs<'a> { +// pub fn ota_job(self) -> Option> { +// match self { +// Jobs::Ota(ota_job) => Some(ota_job), +// } +// } +// } + +// enum OtaUpdate<'a> { +// JobUpdate(&'a str, OtaJob<'a>, Option>), +// Data(&'a mut [u8]), +// } + +// fn handle_ota<'a>(publish: &'a mut PublishNotification) -> Result, ()> { +// match jobs::Topic::from_str(publish.topic_name.as_str()) { +// Some(jobs::Topic::NotifyNext) => { +// let (execution_changed, _) = +// serde_json_core::from_slice::>(&publish.payload) +// .map_err(drop)?; +// let job = execution_changed.execution.ok_or(())?; +// let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; +// return Ok(OtaUpdate::JobUpdate( +// job.job_id, +// ota_job, +// job.status_details, +// )); +// } +// Some(jobs::Topic::DescribeAccepted(_)) => { +// let (execution_changed, _) = +// serde_json_core::from_slice::>(&publish.payload) +// .map_err(drop)?; +// let job = execution_changed.execution.ok_or(())?; +// let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; +// return Ok(OtaUpdate::JobUpdate( +// job.job_id, +// ota_job, +// job.status_details, +// )); +// } +// _ => {} +// } + +// match ota::Topic::from_str(publish.topic_name.as_str()) { +// Some(ota::Topic::Data(_, _)) => { +// return Ok(OtaUpdate::Data(&mut publish.payload)); +// } +// _ => {} +// } +// Err(()) +// } + +// pub struct FileInfo { +// pub file_path: String, +// pub filesize: usize, +// pub signature: ota::encoding::json::Signature, +// } + +// #[test] +// fn test_mqtt_ota() { +// // Make sure this times out in case something went wrong setting up the OTA +// // job in AWS IoT before starting. 
+// timebomb::timeout_ms(test_mqtt_ota_inner, 100_000) +// } + +// fn test_mqtt_ota_inner() { +// env_logger::init(); + +// let (p, c) = unsafe { Q.try_split_framed().unwrap() }; + +// log::info!("Starting OTA test..."); + +// let hostname = credentials::HOSTNAME.unwrap(); +// let (thing_name, identity) = credentials::identity(); + +// let connector = TlsConnector::builder() +// .identity(identity) +// .add_root_certificate(credentials::root_ca()) +// .build() +// .unwrap(); + +// let mut network = Network::new_tls(connector, String::from(hostname)); + +// let mut mqtt_eventloop = EventLoop::new( +// c, +// SysClock::new(), +// MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), +// ); + +// let mqtt_client = mqttrust_core::Client::new(p, thing_name); + +// let file_handler = FileHandler::new(); + +// let mut ota_agent = +// OtaAgent::builder(&mqtt_client, &mqtt_client, SysClock::new(), file_handler) +// .request_wait_ms(3000) +// .block_size(256) +// .build(); + +// let mut file_info = None; + +// loop { +// match mqtt_eventloop.connect(&mut network) { +// Ok(true) => { +// log::info!("Successfully connected to broker"); +// ota_agent.init(); +// } +// Ok(false) => {} +// Err(nb::Error::WouldBlock) => continue, +// Err(e) => panic!("{:?}", e), +// } + +// match mqtt_eventloop.yield_event(&mut network) { +// Ok(Notification::Publish(mut publish)) => { +// // Check if the received file is a jobs topic, that we +// // want to react to. +// match handle_ota(&mut publish) { +// Ok(OtaUpdate::JobUpdate(job_id, job_doc, status_details)) => { +// log::debug!("Received job! Starting OTA! {:?}", job_doc.streamname); + +// let file = &job_doc.files[0]; +// file_info.replace(FileInfo { +// file_path: file.filepath.to_string(), +// filesize: file.filesize, +// signature: file.signature(), +// }); +// ota_agent +// .job_update(job_id, &job_doc, status_details.as_ref()) +// .expect("Failed to start OTA job"); +// } +// Ok(OtaUpdate::Data(payload)) => { +// if ota_agent.handle_message(payload).is_err() { +// match ota_agent.state() { +// States::CreatingFile => log::info!("State: CreatingFile"), +// States::Ready => log::info!("State: Ready"), +// States::RequestingFileBlock => { +// log::info!("State: RequestingFileBlock") +// } +// States::RequestingJob => log::info!("State: RequestingJob"), +// States::Restarting => log::info!("State: Restarting"), +// States::Suspended => log::info!("State: Suspended"), +// States::WaitingForFileBlock => { +// log::info!("State: WaitingForFileBlock") +// } +// States::WaitingForJob => log::info!("State: WaitingForJob"), +// } +// } +// } +// Err(_) => {} +// } +// } +// Ok(n) => { +// log::trace!("{:?}", n); +// } +// _ => {} +// } + +// ota_agent.timer_callback().expect("Failed timer callback!"); + +// match ota_agent.process_event() { +// // Use the restarting state to indicate finished +// Ok(States::Restarting) => break, +// _ => {} +// } +// } + +// let mut expected_file = File::open("tests/assets/ota_file").unwrap(); +// let mut expected_data = Vec::new(); +// expected_file.read_to_end(&mut expected_data).unwrap(); +// let mut expected_hasher = Sha256::new(); +// expected_hasher.update(&expected_data); +// let expected_hash = expected_hasher.finalize(); + +// let file_info = file_info.unwrap(); + +// log::info!( +// "Comparing {:?} with {:?}", +// "tests/assets/ota_file", +// file_info.file_path +// ); +// let mut file = File::open(file_info.file_path.clone()).unwrap(); +// let mut data = Vec::new(); +// file.read_to_end(&mut 
data).unwrap(); +// drop(file); +// std::fs::remove_file(file_info.file_path).unwrap(); + +// assert_eq!(data.len(), file_info.filesize); + +// let mut hasher = Sha256::new(); +// hasher.update(&data); +// assert_eq!(hasher.finalize().deref(), expected_hash.deref()); + +// // Check file signature +// match file_info.signature { +// ota::encoding::json::Signature::Sha1Rsa(_) => { +// panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") +// } +// ota::encoding::json::Signature::Sha256Rsa(_) => { +// panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") +// } +// ota::encoding::json::Signature::Sha1Ecdsa(_) => { +// panic!("Unexpected signature format: Sha1Ecdsa. Expected Sha256Ecdsa") +// } +// ota::encoding::json::Signature::Sha256Ecdsa(sig) => { +// assert_eq!(&sig, "This is my custom signature\\n") +// } +// } +// } diff --git a/tests/provisioning.rs b/tests/provisioning.rs index 5fa317d..949ba7f 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -1,24 +1,27 @@ -mod common; +#![allow(async_fn_in_trait)] +#![feature(type_alias_impl_trait)] -use mqttrust::Mqtt; -use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification, PublishNotification}; +mod common; -use common::clock::SysClock; -use common::network::{Network, TcpSocket}; -use native_tls::{Identity, TlsConnector, TlsStream}; -use p256::ecdsa::signature::Signer; -use rustot::provisioning::{topics::Topic, Credentials, FleetProvisioner, Response}; -use std::net::TcpStream; -use std::ops::DerefMut; +use std::{net::ToSocketAddrs, process}; use common::credentials; - -static mut Q: BBBuffer<{ 1024 * 10 }> = BBBuffer::new(); +use common::network::TlsNetwork; +use ecdsa::Signature; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{Config, DomainBroker, IpBroker, Publish, State, Subscribe, SubscribeTopic}; +use p256::{ecdsa::signature::Signer, NistP256}; +use rustot::provisioning::{ + topics::Topic, CredentialHandler, Credentials, Error, FleetProvisioner, +}; +use serde::{Deserialize, Serialize}; +use static_cell::make_static; pub struct OwnedCredentials { - certificate_id: String, - certificate_pem: String, - private_key: Option, + pub certificate_id: String, + pub certificate_pem: String, + pub private_key: Option, } impl<'a> From> for OwnedCredentials { @@ -31,120 +34,101 @@ impl<'a> From> for OwnedCredentials { } } -fn provision_credentials<'a, const L: usize>( - hostname: &'a str, - identity: Identity, - mqtt_eventloop: &mut EventLoop<'a, 'a, TcpSocket>, SysClock, 1000, L>, - mqtt_client: &mqttrust_core::Client, -) -> Result { - let template_name = - std::env::var("TEMPLATE_NAME").unwrap_or_else(|_| "duoProvisioningTemplate".to_string()); - - let connector = TlsConnector::builder() - .identity(identity) - .add_root_certificate(credentials::root_ca()) - .build() - .unwrap(); - - let mut network = Network::new_tls(connector, String::from(hostname)); - - nb::block!(mqtt_eventloop.connect(&mut network)) - .expect("To connect to MQTT with claim credentials"); - - log::info!("Successfully connected to broker with claim credentials"); - - #[cfg(not(feature = "provision_cbor"))] - let mut provisioner = FleetProvisioner::new(mqtt_client, &template_name); - #[cfg(feature = "provision_cbor")] - let mut provisioner = FleetProvisioner::new_cbor(mqtt_client, &template_name); +pub struct CredentialDAO { + pub creds: Option, +} - provisioner - .initialize() - .expect("Failed to initialize FleetProvisioner"); +impl CredentialHandler for 
CredentialDAO { + async fn store_credentials(&mut self, credentials: Credentials<'_>) -> Result<(), Error> { + log::info!("Provisioned credentials: {:#?}", credentials); - let mut provisioned_credentials: Option = None; + self.creds.replace(credentials.into()); - let signing_key = credentials::signing_key(); - let signature = hex::encode(signing_key.sign(mqtt_client.client_id().as_bytes())); - - let result = loop { - match mqtt_eventloop.yield_event(&mut network) { - Ok(Notification::Publish(mut publish)) if Topic::check(publish.topic_name.as_str()) => { - let PublishNotification { - topic_name, - payload, - .. - } = publish.deref_mut(); - - match provisioner.handle_message::<4>(topic_name.as_str(), payload) { - Ok(Some(Response::Credentials(credentials))) => { - log::info!("Got credentials! {:?}", credentials); - provisioned_credentials = Some(credentials.into()); - - let mut parameters = heapless::LinearMap::new(); - parameters.insert("uuid", mqtt_client.client_id()).unwrap(); - parameters.insert("signature", &signature).unwrap(); - - provisioner - .register_thing::<2>(Some(parameters)) - .expect("To successfully publish to RegisterThing"); - } - Ok(Some(Response::DeviceConfiguration(config))) => { - // Store Device configuration parameters, if any. - - log::info!("Got device config! {:?}", config); - - break Ok(()); - } - Ok(None) => {} - Err(e) => { - log::error!("Got provision error! {:?}", e); - provisioned_credentials = None; - - break Err(()); - } - } - } - Ok(Notification::Suback(_)) => { - log::info!("Starting provisioning"); - provisioner.begin().expect("To begin provisioning"); - } - Ok(n) => { - log::trace!("{:?}", n); - } - _ => {} - } - }; + Ok(()) + } +} - // Disconnect from AWS IoT Core - mqtt_eventloop.disconnect(&mut network); +#[derive(Debug, Serialize)] +struct Parameters<'a> { + uuid: &'a str, + signature: &'a str, +} - result.and_then(|_| provisioned_credentials.ok_or(())) +#[derive(Debug, Deserialize, PartialEq)] +struct DeviceConfig { + #[serde(rename = "SoftwareId")] + software_id: heapless::String<64>, } -#[test] -fn test_provisioning() { +#[tokio::test(flavor = "current_thread")] +async fn test_provisioning() { env_logger::init(); - let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - log::info!("Starting provisioning test..."); let (thing_name, claim_identity) = credentials::claim_identity(); // Connect to AWS IoT Core with provisioning claim credentials let hostname = credentials::HOSTNAME.unwrap(); + let template_name = + std::env::var("TEMPLATE_NAME").unwrap_or_else(|_| "duoProvisioningTemplate".to_string()); - let mut mqtt_eventloop = EventLoop::new( - c, - SysClock::new(), - MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), - ); + let network = make_static!(TlsNetwork::new(hostname.to_owned(), claim_identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + let config = + Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); + + let state = make_static!(State::::new()); + let (mut stack, client) = embedded_mqtt::new(state, config, network); + + let client = make_static!(client); + + let signing_key = credentials::signing_key(); + let signature: Signature = signing_key.sign(thing_name.as_bytes()); + let hex_signature: String = hex::encode(signature.to_bytes()); - let mqtt_client = mqttrust_core::Client::new(p, thing_name); + let parameters = Parameters { + uuid: thing_name, + signature: &hex_signature, + 
}; - let credentials = - provision_credentials(hostname, claim_identity, &mut mqtt_eventloop, &mqtt_client).unwrap(); + let mut credential_handler = CredentialDAO { creds: None }; - assert!(credentials.certificate_id.len() > 0); + #[cfg(not(feature = "provision_cbor"))] + let provision_fut = FleetProvisioner::provision::( + client, + &template_name, + Some(parameters), + &mut credential_handler, + ); + #[cfg(feature = "provision_cbor")] + let provision_fut = FleetProvisioner::provision_cbor::( + client, + &template_name, + Some(parameters), + &mut credential_handler, + ); + + let device_config = match embassy_time::with_timeout( + embassy_time::Duration::from_secs(15), + select::select(stack.run(), provision_fut), + ) + .await + .unwrap() + { + select::Either::First(_) => { + unreachable!() + } + select::Either::Second(result) => result.unwrap(), + }; + assert_eq!( + device_config, + Some(DeviceConfig { + software_id: heapless::String::try_from("82b3509e0e924e06ab1bdb1cf1625dcb").unwrap() + }) + ); + assert!(credential_handler.creds.unwrap().certificate_id.len() > 0); } diff --git a/tests/shadows.rs b/tests/shadows.rs index bb83e9c..cbd979e 100644 --- a/tests/shadows.rs +++ b/tests/shadows.rs @@ -1,502 +1,502 @@ -//! -//! ## Integration test of `AWS IoT Shadows` -//! -//! -//! This test simulates updates of the shadow state from both device side & -//! cloud side. Cloud side updates are done by publishing directly to the shadow -//! topics, and ignoring the resulting update accepted response. Device side -//! updates are done through the shadow API provided by this crate. -//! -//! The test runs through the following update sequence: -//! 1. Setup clean starting point (`desired = null, reported = null`) -//! 2. Do a `GetShadow` request to sync empty state -//! 3. Update to initial shadow state from the device -//! 4. Assert on the initial state -//! 5. Update state from device -//! 6. Assert on shadow state -//! 7. Update state from cloud -//! 8. Assert on shadow state -//! 9. Update state from device -//! 10. Assert on shadow state -//! 11. Update state from cloud -//! 12. Assert on shadow state -//! - -mod common; - -use core::fmt::Write; - -use common::{clock::SysClock, credentials, network::Network}; -use embedded_nal::Ipv4Addr; -use mqttrust::Mqtt; -use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification}; -use native_tls::TlsConnector; -use rustot::shadows::{ - derive::ShadowState, topics::Topic, Patch, Shadow, ShadowPatch, ShadowState, -}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; - -use smlang::statemachine; - -const Q_SIZE: usize = 1024 * 6; -static mut Q: BBBuffer = BBBuffer::new(); - -#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] -pub struct ConfigId(pub u8); - -impl Serialize for ConfigId { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let mut str = heapless::String::<3>::new(); - write!(str, "{}", self.0).map_err(serde::ser::Error::custom)?; - serializer.serialize_str(&str) - } -} - -impl<'de> Deserialize<'de> for ConfigId { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - heapless::String::<3>::deserialize(deserializer)? 
- .parse() - .map(ConfigId) - .map_err(serde::de::Error::custom) - } -} - -impl From for ConfigId { - fn from(v: u8) -> Self { - Self(v) - } -} - -#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] -pub struct NetworkMap(heapless::LinearMap>, N>); - -impl NetworkMap -where - K: Eq, -{ - pub fn insert(&mut self, k: impl Into, v: V) -> Result<(), ()> { - self.0.insert(k.into(), Some(Patch::Set(v))).map_err(drop)?; - Ok(()) - } - - pub fn remove(&mut self, k: impl Into) -> Result<(), ()> { - self.0.insert(k.into(), None).map_err(drop)?; - Ok(()) - } -} - -impl ShadowPatch for NetworkMap -where - K: Clone + Default + Eq + Serialize + DeserializeOwned, - V: Clone + Default + Serialize + DeserializeOwned, -{ - type PatchState = NetworkMap; - - fn apply_patch(&mut self, opt: Self::PatchState) { - for (id, network) in opt.0.into_iter() { - match network { - Some(Patch::Set(v)) => { - self.insert(id.clone(), v.clone()).ok(); - } - None | Some(Patch::Unset) => { - self.remove(id.clone()).ok(); - } - } - } - } -} - -const MAX_NETWORKS: usize = 5; -type KnownNetworks = NetworkMap; - -#[derive(Debug, Clone, Default, Serialize, Deserialize, ShadowState)] -#[shadow("wifi")] -pub struct WifiConfig { - pub enabled: bool, - - pub known_networks: KnownNetworks, -} - -#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] -pub struct ConnectionOptions { - pub ssid: heapless::String<64>, - pub password: Option>, - - pub ip: Option, - pub subnet: Option, - pub gateway: Option, -} - -#[derive(Debug, Clone)] -pub enum UpdateAction { - Insert(u8, ConnectionOptions), - Remove(u8), - Enabled(bool), -} - -statemachine! { - transitions: { - *Begin + Delete = DeleteShadow, - DeleteShadow + Get = GetShadow, - GetShadow + Load / load_initial = LoadShadow(Option), - LoadShadow(Option) + CheckInitial / check_initial = Check(Option), - UpdateFromDevice(UpdateAction) + CheckState / check = Check(Option), - UpdateFromCloud(UpdateAction) + Ack = AckUpdate, - AckUpdate + CheckState / check_cloud = Check(Option), - Check(Option) + UpdateStateFromDevice / get_next_device = UpdateFromDevice(UpdateAction), - Check(Option) + UpdateStateFromCloud / get_next_cloud = UpdateFromCloud(UpdateAction), - } -} - -impl core::fmt::Debug for States { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Begin => write!(f, "Self::Begin"), - Self::DeleteShadow => write!(f, "Self::DeleteShadow"), - Self::GetShadow => write!(f, "Self::GetShadow"), - Self::AckUpdate => write!(f, "Self::AckUpdate"), - Self::LoadShadow(t) => write!(f, "Self::LoadShadow({:?})", t), - Self::UpdateFromDevice(t) => write!(f, "Self::UpdateFromDevice({:?})", t), - Self::UpdateFromCloud(t) => write!(f, "Self::UpdateFromCloud({:?})", t), - Self::Check(t) => write!(f, "Self::Check({:?})", t), - } - } -} - -fn asserts(id: usize) -> ConnectionOptions { - match id { - 0 => ConnectionOptions { - ssid: heapless::String::from("MySSID"), - password: None, - ip: None, - subnet: None, - gateway: None, - }, - 1 => ConnectionOptions { - ssid: heapless::String::from("MyProtectedSSID"), - password: Some(heapless::String::from("SecretPass")), - ip: None, - subnet: None, - gateway: None, - }, - 2 => ConnectionOptions { - ssid: heapless::String::from("CloudSSID"), - password: Some(heapless::String::from("SecretCloudPass")), - ip: Some(Ipv4Addr::new(1, 2, 3, 4)), - subnet: None, - gateway: None, - }, - _ => panic!("Unknown assert ID"), - } -} - -pub struct TestContext<'a> { - shadow: Shadow<'a, WifiConfig, 
mqttrust_core::Client<'static, 'static, Q_SIZE>>, - update_cnt: u8, -} - -impl<'a> StateMachineContext for TestContext<'a> { - fn check_initial( - &mut self, - _last_update_action: &Option, - ) -> Option { - self.check(&UpdateAction::Remove(0)) - } - - fn check_cloud(&mut self) -> Option { - self.check(&UpdateAction::Remove(0)) - } - - fn check(&mut self, _last_update_action: &UpdateAction) -> Option { - let mut known_networks = KnownNetworks::default(); - - match self.update_cnt { - 0 => { - // After load_initial - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - } - 1 => { - // After get_next_device - known_networks.remove(0).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - } - 2 => { - // After get_next_cloud - known_networks.remove(0).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - known_networks.insert(2, asserts(2)).unwrap(); - } - 3 => { - // After get_next_device - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - known_networks.insert(2, asserts(2)).unwrap(); - } - 4 => { - // After get_next_cloud - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - known_networks.remove(2).unwrap(); - } - 5 => return None, - _ => {} - } - - Some(known_networks) - } - - fn get_next_device(&mut self, _: &Option) -> UpdateAction { - self.update_cnt += 1; - match self.update_cnt { - 1 => UpdateAction::Remove(0), - 3 => UpdateAction::Insert(0, asserts(0)), - 5 => UpdateAction::Remove(0), - _ => panic!("Unexpected update counter in `get_next_device`"), - } - } - - fn get_next_cloud(&mut self, _: &Option) -> UpdateAction { - self.update_cnt += 1; - - match self.update_cnt { - 2 => UpdateAction::Insert(2, asserts(2)), - 4 => UpdateAction::Remove(2), - _ => panic!("Unexpected update counter in `get_next_cloud`"), - } - } - - fn load_initial(&mut self) -> Option { - let mut known_networks = KnownNetworks::default(); - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - Some(known_networks) - } -} - -impl<'a> StateMachine> { - pub fn spin( - &mut self, - notification: Notification, - mqtt_client: &mqttrust_core::Client<'static, 'static, Q_SIZE>, - ) -> bool { - log::info!("State: {:?}", self.state()); - match (self.state(), notification) { - (&States::Begin, Notification::Suback(_)) => { - self.process_event(Events::Delete).unwrap(); - } - (&States::DeleteShadow, Notification::Suback(_)) => { - mqtt_client - .publish( - &Topic::Update - .format::<128>( - mqtt_client.client_id(), - ::NAME, - ) - .unwrap(), - b"{\"state\":{\"desired\":null,\"reported\":null}}", - mqttrust::QoS::AtLeastOnce, - ) - .unwrap(); - - self.process_event(Events::Get).unwrap(); - } - (&States::GetShadow, Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/update/accepted" - ) => - { - self.context_mut().shadow.get_shadow().unwrap(); - self.process_event(Events::Load).unwrap(); - } - (&States::LoadShadow(ref initial_map), Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/get/accepted" - ) => - { - let initial_map = initial_map.clone(); - - self.context_mut() - .shadow - .update(|_current, desired| { - desired.known_networks = Some(initial_map.unwrap()); - }) - .unwrap(); - self.process_event(Events::CheckInitial).unwrap(); - } - (&States::UpdateFromDevice(ref update_action), 
Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/get/accepted" - ) => - { - let action = update_action.clone(); - self.context_mut() - .shadow - .update(|current, desired| match action { - UpdateAction::Insert(id, options) => { - let mut desired_map = current.known_networks.clone(); - desired_map.insert(id, options).unwrap(); - desired.known_networks = Some(desired_map); - } - UpdateAction::Remove(id) => { - let mut desired_map = current.known_networks.clone(); - desired_map.remove(id).unwrap(); - desired.known_networks = Some(desired_map); - } - UpdateAction::Enabled(en) => { - desired.enabled = Some(en); - } - }) - .unwrap(); - self.process_event(Events::CheckState).unwrap(); - } - (&States::UpdateFromCloud(ref update_action), Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/get/accepted" - ) => - { - let desired_known_networks = match update_action { - UpdateAction::Insert(id, options) => format!( - "\"known_networks\": {{\"{}\":{{\"set\":{}}}}}", - id, - serde_json_core::to_string::<_, 256>(options).unwrap() - ), - UpdateAction::Remove(id) => { - format!("\"known_networks\": {{\"{}\":\"unset\"}}", id) - } - &UpdateAction::Enabled(en) => format!("\"enabled\": {}", en), - }; - - let payload = format!( - "{{\"state\":{{\"desired\":{{{}}}, \"reported\":{}}}}}", - desired_known_networks, - serde_json_core::to_string::<_, 512>(self.context().shadow.get()).unwrap() - ); - - log::debug!("Update from cloud: {:?}", payload); - - mqtt_client - .publish( - &Topic::Update - .format::<128>( - mqtt_client.client_id(), - ::NAME, - ) - .unwrap(), - payload.as_bytes(), - mqttrust::QoS::AtLeastOnce, - ) - .unwrap(); - self.process_event(Events::Ack).unwrap(); - } - (&States::AckUpdate, Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/update/delta" - ) => - { - self.context_mut() - .shadow - .handle_message(&publish.topic_name, &publish.payload) - .unwrap(); - - self.process_event(Events::CheckState).unwrap(); - } - (&States::Check(ref expected_map), Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/update/accepted" - | "$aws/things/rustot-test/shadow/name/wifi/update/delta" - ) => - { - let expected = expected_map.clone(); - self.context_mut() - .shadow - .handle_message(&publish.topic_name, &publish.payload) - .unwrap(); - - match expected { - Some(expected_map) => { - assert_eq!(self.context().shadow.get().known_networks, expected_map); - self.context_mut().shadow.get_shadow().unwrap(); - let event = if self.context().update_cnt % 2 == 0 { - Events::UpdateStateFromDevice - } else { - Events::UpdateStateFromCloud - }; - self.process_event(event).unwrap(); - } - None => return true, - } - } - (_, Notification::Publish(publish)) => { - log::warn!("TOPIC: {}", publish.topic_name); - self.context_mut() - .shadow - .handle_message(&publish.topic_name, &publish.payload) - .unwrap(); - } - _ => {} - } - - false - } -} - -#[test] -fn test_shadows() { - env_logger::init(); - - let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - - log::info!("Starting shadows test..."); - - let hostname = credentials::HOSTNAME.unwrap(); - let (thing_name, identity) = credentials::identity(); - - let connector = TlsConnector::builder() - .identity(identity) - .add_root_certificate(credentials::root_ca()) - .build() - .unwrap(); - - let mut 
network = Network::new_tls(connector, std::string::String::from(hostname)); - - let mut mqtt_eventloop = EventLoop::new( - c, - SysClock::new(), - MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), - ); - - let mqtt_client = mqttrust_core::Client::new(p, thing_name); - - let mut test_state = StateMachine::new(TestContext { - shadow: Shadow::new(WifiConfig::default(), &mqtt_client, true).unwrap(), - update_cnt: 0, - }); - - loop { - if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { - log::info!("Successfully connected to broker"); - } - - match mqtt_eventloop.yield_event(&mut network) { - Ok(notification) => { - if test_state.spin(notification, &mqtt_client) { - break; - } - } - Err(_) => {} - } - } -} +// //! +// //! ## Integration test of `AWS IoT Shadows` +// //! +// //! +// //! This test simulates updates of the shadow state from both device side & +// //! cloud side. Cloud side updates are done by publishing directly to the shadow +// //! topics, and ignoring the resulting update accepted response. Device side +// //! updates are done through the shadow API provided by this crate. +// //! +// //! The test runs through the following update sequence: +// //! 1. Setup clean starting point (`desired = null, reported = null`) +// //! 2. Do a `GetShadow` request to sync empty state +// //! 3. Update to initial shadow state from the device +// //! 4. Assert on the initial state +// //! 5. Update state from device +// //! 6. Assert on shadow state +// //! 7. Update state from cloud +// //! 8. Assert on shadow state +// //! 9. Update state from device +// //! 10. Assert on shadow state +// //! 11. Update state from cloud +// //! 12. Assert on shadow state +// //! + +// mod common; + +// use core::fmt::Write; + +// use common::{clock::SysClock, credentials, network::Network}; +// use embedded_nal::Ipv4Addr; +// use mqttrust::Mqtt; +// use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification}; +// use native_tls::TlsConnector; +// use rustot::shadows::{ +// derive::ShadowState, topics::Topic, Patch, Shadow, ShadowPatch, ShadowState, +// }; +// use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +// use smlang::statemachine; + +// const Q_SIZE: usize = 1024 * 6; +// static mut Q: BBBuffer = BBBuffer::new(); + +// #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] +// pub struct ConfigId(pub u8); + +// impl Serialize for ConfigId { +// fn serialize(&self, serializer: S) -> Result +// where +// S: serde::Serializer, +// { +// let mut str = heapless::String::<3>::new(); +// write!(str, "{}", self.0).map_err(serde::ser::Error::custom)?; +// serializer.serialize_str(&str) +// } +// } + +// impl<'de> Deserialize<'de> for ConfigId { +// fn deserialize(deserializer: D) -> Result +// where +// D: serde::Deserializer<'de>, +// { +// heapless::String::<3>::deserialize(deserializer)? 
+// .parse() +// .map(ConfigId) +// .map_err(serde::de::Error::custom) +// } +// } + +// impl From for ConfigId { +// fn from(v: u8) -> Self { +// Self(v) +// } +// } + +// #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +// pub struct NetworkMap(heapless::LinearMap>, N>); + +// impl NetworkMap +// where +// K: Eq, +// { +// pub fn insert(&mut self, k: impl Into, v: V) -> Result<(), ()> { +// self.0.insert(k.into(), Some(Patch::Set(v))).map_err(drop)?; +// Ok(()) +// } + +// pub fn remove(&mut self, k: impl Into) -> Result<(), ()> { +// self.0.insert(k.into(), None).map_err(drop)?; +// Ok(()) +// } +// } + +// impl ShadowPatch for NetworkMap +// where +// K: Clone + Default + Eq + Serialize + DeserializeOwned, +// V: Clone + Default + Serialize + DeserializeOwned, +// { +// type PatchState = NetworkMap; + +// fn apply_patch(&mut self, opt: Self::PatchState) { +// for (id, network) in opt.0.into_iter() { +// match network { +// Some(Patch::Set(v)) => { +// self.insert(id.clone(), v.clone()).ok(); +// } +// None | Some(Patch::Unset) => { +// self.remove(id.clone()).ok(); +// } +// } +// } +// } +// } + +// const MAX_NETWORKS: usize = 5; +// type KnownNetworks = NetworkMap; + +// #[derive(Debug, Clone, Default, Serialize, Deserialize, ShadowState)] +// #[shadow("wifi")] +// pub struct WifiConfig { +// pub enabled: bool, + +// pub known_networks: KnownNetworks, +// } + +// #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +// pub struct ConnectionOptions { +// pub ssid: heapless::String<64>, +// pub password: Option>, + +// pub ip: Option, +// pub subnet: Option, +// pub gateway: Option, +// } + +// #[derive(Debug, Clone)] +// pub enum UpdateAction { +// Insert(u8, ConnectionOptions), +// Remove(u8), +// Enabled(bool), +// } + +// statemachine! 
{ +// transitions: { +// *Begin + Delete = DeleteShadow, +// DeleteShadow + Get = GetShadow, +// GetShadow + Load / load_initial = LoadShadow(Option), +// LoadShadow(Option) + CheckInitial / check_initial = Check(Option), +// UpdateFromDevice(UpdateAction) + CheckState / check = Check(Option), +// UpdateFromCloud(UpdateAction) + Ack = AckUpdate, +// AckUpdate + CheckState / check_cloud = Check(Option), +// Check(Option) + UpdateStateFromDevice / get_next_device = UpdateFromDevice(UpdateAction), +// Check(Option) + UpdateStateFromCloud / get_next_cloud = UpdateFromCloud(UpdateAction), +// } +// } + +// impl core::fmt::Debug for States { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// match self { +// Self::Begin => write!(f, "Self::Begin"), +// Self::DeleteShadow => write!(f, "Self::DeleteShadow"), +// Self::GetShadow => write!(f, "Self::GetShadow"), +// Self::AckUpdate => write!(f, "Self::AckUpdate"), +// Self::LoadShadow(t) => write!(f, "Self::LoadShadow({:?})", t), +// Self::UpdateFromDevice(t) => write!(f, "Self::UpdateFromDevice({:?})", t), +// Self::UpdateFromCloud(t) => write!(f, "Self::UpdateFromCloud({:?})", t), +// Self::Check(t) => write!(f, "Self::Check({:?})", t), +// } +// } +// } + +// fn asserts(id: usize) -> ConnectionOptions { +// match id { +// 0 => ConnectionOptions { +// ssid: heapless::String::from("MySSID"), +// password: None, +// ip: None, +// subnet: None, +// gateway: None, +// }, +// 1 => ConnectionOptions { +// ssid: heapless::String::from("MyProtectedSSID"), +// password: Some(heapless::String::from("SecretPass")), +// ip: None, +// subnet: None, +// gateway: None, +// }, +// 2 => ConnectionOptions { +// ssid: heapless::String::from("CloudSSID"), +// password: Some(heapless::String::from("SecretCloudPass")), +// ip: Some(Ipv4Addr::new(1, 2, 3, 4)), +// subnet: None, +// gateway: None, +// }, +// _ => panic!("Unknown assert ID"), +// } +// } + +// pub struct TestContext<'a> { +// shadow: Shadow<'a, WifiConfig, mqttrust_core::Client<'static, 'static, Q_SIZE>>, +// update_cnt: u8, +// } + +// impl<'a> StateMachineContext for TestContext<'a> { +// fn check_initial( +// &mut self, +// _last_update_action: &Option, +// ) -> Option { +// self.check(&UpdateAction::Remove(0)) +// } + +// fn check_cloud(&mut self) -> Option { +// self.check(&UpdateAction::Remove(0)) +// } + +// fn check(&mut self, _last_update_action: &UpdateAction) -> Option { +// let mut known_networks = KnownNetworks::default(); + +// match self.update_cnt { +// 0 => { +// // After load_initial +// known_networks.insert(0, asserts(0)).unwrap(); +// known_networks.insert(1, asserts(1)).unwrap(); +// } +// 1 => { +// // After get_next_device +// known_networks.remove(0).unwrap(); +// known_networks.insert(1, asserts(1)).unwrap(); +// } +// 2 => { +// // After get_next_cloud +// known_networks.remove(0).unwrap(); +// known_networks.insert(1, asserts(1)).unwrap(); +// known_networks.insert(2, asserts(2)).unwrap(); +// } +// 3 => { +// // After get_next_device +// known_networks.insert(0, asserts(0)).unwrap(); +// known_networks.insert(1, asserts(1)).unwrap(); +// known_networks.insert(2, asserts(2)).unwrap(); +// } +// 4 => { +// // After get_next_cloud +// known_networks.insert(0, asserts(0)).unwrap(); +// known_networks.insert(1, asserts(1)).unwrap(); +// known_networks.remove(2).unwrap(); +// } +// 5 => return None, +// _ => {} +// } + +// Some(known_networks) +// } + +// fn get_next_device(&mut self, _: &Option) -> UpdateAction { +// self.update_cnt += 1; +// match 
self.update_cnt { +// 1 => UpdateAction::Remove(0), +// 3 => UpdateAction::Insert(0, asserts(0)), +// 5 => UpdateAction::Remove(0), +// _ => panic!("Unexpected update counter in `get_next_device`"), +// } +// } + +// fn get_next_cloud(&mut self, _: &Option) -> UpdateAction { +// self.update_cnt += 1; + +// match self.update_cnt { +// 2 => UpdateAction::Insert(2, asserts(2)), +// 4 => UpdateAction::Remove(2), +// _ => panic!("Unexpected update counter in `get_next_cloud`"), +// } +// } + +// fn load_initial(&mut self) -> Option { +// let mut known_networks = KnownNetworks::default(); +// known_networks.insert(0, asserts(0)).unwrap(); +// known_networks.insert(1, asserts(1)).unwrap(); +// Some(known_networks) +// } +// } + +// impl<'a> StateMachine> { +// pub fn spin( +// &mut self, +// notification: Notification, +// mqtt_client: &mqttrust_core::Client<'static, 'static, Q_SIZE>, +// ) -> bool { +// log::info!("State: {:?}", self.state()); +// match (self.state(), notification) { +// (&States::Begin, Notification::Suback(_)) => { +// self.process_event(Events::Delete).unwrap(); +// } +// (&States::DeleteShadow, Notification::Suback(_)) => { +// mqtt_client +// .publish( +// &Topic::Update +// .format::<128>( +// mqtt_client.client_id(), +// ::NAME, +// ) +// .unwrap(), +// b"{\"state\":{\"desired\":null,\"reported\":null}}", +// mqttrust::QoS::AtLeastOnce, +// ) +// .unwrap(); + +// self.process_event(Events::Get).unwrap(); +// } +// (&States::GetShadow, Notification::Publish(publish)) +// if matches!( +// publish.topic_name.as_str(), +// "$aws/things/rustot-test/shadow/name/wifi/update/accepted" +// ) => +// { +// self.context_mut().shadow.get_shadow().unwrap(); +// self.process_event(Events::Load).unwrap(); +// } +// (&States::LoadShadow(ref initial_map), Notification::Publish(publish)) +// if matches!( +// publish.topic_name.as_str(), +// "$aws/things/rustot-test/shadow/name/wifi/get/accepted" +// ) => +// { +// let initial_map = initial_map.clone(); + +// self.context_mut() +// .shadow +// .update(|_current, desired| { +// desired.known_networks = Some(initial_map.unwrap()); +// }) +// .unwrap(); +// self.process_event(Events::CheckInitial).unwrap(); +// } +// (&States::UpdateFromDevice(ref update_action), Notification::Publish(publish)) +// if matches!( +// publish.topic_name.as_str(), +// "$aws/things/rustot-test/shadow/name/wifi/get/accepted" +// ) => +// { +// let action = update_action.clone(); +// self.context_mut() +// .shadow +// .update(|current, desired| match action { +// UpdateAction::Insert(id, options) => { +// let mut desired_map = current.known_networks.clone(); +// desired_map.insert(id, options).unwrap(); +// desired.known_networks = Some(desired_map); +// } +// UpdateAction::Remove(id) => { +// let mut desired_map = current.known_networks.clone(); +// desired_map.remove(id).unwrap(); +// desired.known_networks = Some(desired_map); +// } +// UpdateAction::Enabled(en) => { +// desired.enabled = Some(en); +// } +// }) +// .unwrap(); +// self.process_event(Events::CheckState).unwrap(); +// } +// (&States::UpdateFromCloud(ref update_action), Notification::Publish(publish)) +// if matches!( +// publish.topic_name.as_str(), +// "$aws/things/rustot-test/shadow/name/wifi/get/accepted" +// ) => +// { +// let desired_known_networks = match update_action { +// UpdateAction::Insert(id, options) => format!( +// "\"known_networks\": {{\"{}\":{{\"set\":{}}}}}", +// id, +// serde_json_core::to_string::<_, 256>(options).unwrap() +// ), +// UpdateAction::Remove(id) => { +// 
format!("\"known_networks\": {{\"{}\":\"unset\"}}", id) +// } +// &UpdateAction::Enabled(en) => format!("\"enabled\": {}", en), +// }; + +// let payload = format!( +// "{{\"state\":{{\"desired\":{{{}}}, \"reported\":{}}}}}", +// desired_known_networks, +// serde_json_core::to_string::<_, 512>(self.context().shadow.get()).unwrap() +// ); + +// log::debug!("Update from cloud: {:?}", payload); + +// mqtt_client +// .publish( +// &Topic::Update +// .format::<128>( +// mqtt_client.client_id(), +// ::NAME, +// ) +// .unwrap(), +// payload.as_bytes(), +// mqttrust::QoS::AtLeastOnce, +// ) +// .unwrap(); +// self.process_event(Events::Ack).unwrap(); +// } +// (&States::AckUpdate, Notification::Publish(publish)) +// if matches!( +// publish.topic_name.as_str(), +// "$aws/things/rustot-test/shadow/name/wifi/update/delta" +// ) => +// { +// self.context_mut() +// .shadow +// .handle_message(&publish.topic_name, &publish.payload) +// .unwrap(); + +// self.process_event(Events::CheckState).unwrap(); +// } +// (&States::Check(ref expected_map), Notification::Publish(publish)) +// if matches!( +// publish.topic_name.as_str(), +// "$aws/things/rustot-test/shadow/name/wifi/update/accepted" +// | "$aws/things/rustot-test/shadow/name/wifi/update/delta" +// ) => +// { +// let expected = expected_map.clone(); +// self.context_mut() +// .shadow +// .handle_message(&publish.topic_name, &publish.payload) +// .unwrap(); + +// match expected { +// Some(expected_map) => { +// assert_eq!(self.context().shadow.get().known_networks, expected_map); +// self.context_mut().shadow.get_shadow().unwrap(); +// let event = if self.context().update_cnt % 2 == 0 { +// Events::UpdateStateFromDevice +// } else { +// Events::UpdateStateFromCloud +// }; +// self.process_event(event).unwrap(); +// } +// None => return true, +// } +// } +// (_, Notification::Publish(publish)) => { +// log::warn!("TOPIC: {}", publish.topic_name); +// self.context_mut() +// .shadow +// .handle_message(&publish.topic_name, &publish.payload) +// .unwrap(); +// } +// _ => {} +// } + +// false +// } +// } + +// #[test] +// fn test_shadows() { +// env_logger::init(); + +// let (p, c) = unsafe { Q.try_split_framed().unwrap() }; + +// log::info!("Starting shadows test..."); + +// let hostname = credentials::HOSTNAME.unwrap(); +// let (thing_name, identity) = credentials::identity(); + +// let connector = TlsConnector::builder() +// .identity(identity) +// .add_root_certificate(credentials::root_ca()) +// .build() +// .unwrap(); + +// let mut network = Network::new_tls(connector, std::string::String::from(hostname)); + +// let mut mqtt_eventloop = EventLoop::new( +// c, +// SysClock::new(), +// MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), +// ); + +// let mqtt_client = mqttrust_core::Client::new(p, thing_name); + +// let mut test_state = StateMachine::new(TestContext { +// shadow: Shadow::new(WifiConfig::default(), &mqtt_client, true).unwrap(), +// update_cnt: 0, +// }); + +// loop { +// if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { +// log::info!("Successfully connected to broker"); +// } + +// match mqtt_eventloop.yield_event(&mut network) { +// Ok(notification) => { +// if test_state.spin(notification, &mqtt_client) { +// break; +// } +// } +// Err(_) => {} +// } +// } +// } From cadf3d3997f807a6a315c905c18c050dded9a83b Mon Sep 17 00:00:00 2001 From: Mathias Date: Sat, 6 Jan 2024 14:05:08 +0100 Subject: [PATCH 03/36] Rework topic formatting slightly --- src/provisioning/mod.rs | 105 
++++++------------------------ src/provisioning/topics.rs | 129 +++++++++++-------------------------- 2 files changed, 57 insertions(+), 177 deletions(-) diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index df79353..538db86 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -109,37 +109,16 @@ impl FleetProvisioner { where CH: CredentialHandler, { - // FIXME: Changing these to a single topic filter of - // `$aws/certificates/create//+` could be beneficial to - // stack usage - let topic_paths = topics::Subscribe::<2>::new() - .topic( - Topic::CreateKeysAndCertificateAccepted(payload_format), - QoS::AtLeastOnce, - ) - .topic( - Topic::CreateKeysAndCertificateRejected(payload_format), - QoS::AtLeastOnce, - ) - .topics::<38>()?; - - let subscribe_topics = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - maximum_qos: *qos, + let mut subscription = mqtt + .subscribe::<1>(Subscribe::new(&[SubscribeTopic { + topic_path: Topic::CreateKeysAndCertificateAny(payload_format) + .format::<31>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, no_local: false, retain_as_published: false, retain_handling: RetainHandling::SendAtSubscribeTime, - }) - .collect::>(); - - let mut subscription = mqtt - .subscribe::<2>(Subscribe { - pid: None, - properties: embedded_mqtt::Properties::Slice(&[]), - topics: subscribe_topics.as_slice(), - }) + }])) .await .map_err(|_| Error::Mqtt)?; @@ -226,37 +205,16 @@ impl FleetProvisioner { where CH: CredentialHandler, { - // FIXME: Changing these to a single topic filter of - // `$aws/certificates/create-from-csr//+` could be beneficial to - // stack usage - let topic_paths = topics::Subscribe::<2>::new() - .topic( - Topic::CreateCertificateFromCsrAccepted(payload_format), - QoS::AtLeastOnce, - ) - .topic( - Topic::CreateCertificateFromCsrRejected(payload_format), - QoS::AtLeastOnce, - ) - .topics::<47>()?; - - let subscribe_topics = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - maximum_qos: *qos, + let mut subscription = mqtt + .subscribe::<1>(Subscribe::new(&[SubscribeTopic { + topic_path: Topic::CreateCertificateFromCsrAny(payload_format) + .format::<40>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, no_local: false, retain_as_published: false, retain_handling: RetainHandling::SendAtSubscribeTime, - }) - .collect::>(); - - let mut subscription = mqtt - .subscribe::<2>(Subscribe { - pid: None, - properties: embedded_mqtt::Properties::Slice(&[]), - topics: subscribe_topics.as_slice(), - }) + }])) .await .map_err(|_| Error::Mqtt)?; @@ -343,37 +301,16 @@ impl FleetProvisioner { certificate_ownership_token: &str, parameters: Option

, ) -> Result, Error> { - // FIXME: Changing these to a single topic filter of - // `$aws/provisioning-templates//provision//+` - // could be beneficial to stack usage - let topic_paths = topics::Subscribe::<2>::new() - .topic( - Topic::RegisterThingAccepted(template_name, payload_format), - QoS::AtLeastOnce, - ) - .topic( - Topic::RegisterThingRejected(template_name, payload_format), - QoS::AtLeastOnce, - ) - .topics::<128>()?; - - let subscribe_topics = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - maximum_qos: *qos, + let mut subscription = mqtt + .subscribe::<1>(Subscribe::new(&[SubscribeTopic { + topic_path: Topic::RegisterThingAny(template_name, payload_format) + .format::<128>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, no_local: false, retain_as_published: false, retain_handling: RetainHandling::SendAtSubscribeTime, - }) - .collect::>(); - - let mut subscription = mqtt - .subscribe::<2>(Subscribe { - pid: None, - properties: embedded_mqtt::Properties::Slice(&[]), - topics: subscribe_topics.as_slice(), - }) + }])) .await .map_err(|_| Error::Mqtt)?; diff --git a/src/provisioning/topics.rs b/src/provisioning/topics.rs index c94a5d9..bd3e7af 100644 --- a/src/provisioning/topics.rs +++ b/src/provisioning/topics.rs @@ -2,7 +2,6 @@ use core::fmt::Display; use core::fmt::Write; use core::str::FromStr; -use embedded_mqtt::QoS; use heapless::String; use super::Error; @@ -59,18 +58,27 @@ pub enum Topic<'a> { CreateCertificateFromCsr(PayloadFormat), // ---- Incoming Topics + /// `$aws/provisioning-templates//provision//+` + RegisterThingAny(&'a str, PayloadFormat), + /// `$aws/provisioning-templates//provision//accepted` RegisterThingAccepted(&'a str, PayloadFormat), /// `$aws/provisioning-templates//provision//rejected` RegisterThingRejected(&'a str, PayloadFormat), + /// `$aws/certificates/create//+` + CreateKeysAndCertificateAny(PayloadFormat), + /// `$aws/certificates/create//accepted` CreateKeysAndCertificateAccepted(PayloadFormat), /// `$aws/certificates/create//rejected` CreateKeysAndCertificateRejected(PayloadFormat), + /// `$aws/certificates/create-from-csr//+` + CreateCertificateFromCsrAny(PayloadFormat), + /// `$aws/certificates/create-from-csr//accepted` CreateCertificateFromCsrAccepted(PayloadFormat), @@ -169,6 +177,14 @@ impl<'a> Topic<'a> { payload_format, )) } + Topic::RegisterThingAny(template_name, payload_format) => { + topic_path.write_fmt(format_args!( + "{}/{}/provision/{}/+", + Self::PROVISIONING_PREFIX, + template_name, + payload_format, + )) + } Topic::RegisterThingAccepted(template_name, payload_format) => { topic_path.write_fmt(format_args!( "{}/{}/provision/{}/accepted", @@ -192,6 +208,9 @@ impl<'a> Topic<'a> { payload_format, )), + Topic::CreateKeysAndCertificateAny(payload_format) => topic_path.write_fmt( + format_args!("{}/create/{}/+", Self::CERT_PREFIX, payload_format), + ), Topic::CreateKeysAndCertificateAccepted(payload_format) => topic_path.write_fmt( format_args!("{}/create/{}/accepted", Self::CERT_PREFIX, payload_format), ), @@ -204,102 +223,26 @@ impl<'a> Topic<'a> { Self::CERT_PREFIX, payload_format, )), - Topic::CreateCertificateFromCsrAccepted(payload_format) => topic_path.write_fmt( - format_args!("{}/create-from-csr/{}", Self::CERT_PREFIX, payload_format), - ), - Topic::CreateCertificateFromCsrRejected(payload_format) => topic_path.write_fmt( - format_args!("{}/create-from-csr/{}", Self::CERT_PREFIX, payload_format), + Topic::CreateCertificateFromCsrAny(payload_format) => topic_path.write_fmt( + 
format_args!("{}/create-from-csr/{}/+", Self::CERT_PREFIX, payload_format), ), + Topic::CreateCertificateFromCsrAccepted(payload_format) => { + topic_path.write_fmt(format_args!( + "{}/create-from-csr/{}/accepted", + Self::CERT_PREFIX, + payload_format + )) + } + Topic::CreateCertificateFromCsrRejected(payload_format) => { + topic_path.write_fmt(format_args!( + "{}/create-from-csr/{}/rejected", + Self::CERT_PREFIX, + payload_format + )) + } } .map_err(|_| Error::Overflow)?; Ok(topic_path) } } - -#[derive(Default)] -pub struct Subscribe<'a, const N: usize> { - topics: heapless::Vec<(Topic<'a>, QoS), N>, -} - -impl<'a, const N: usize> Subscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>, qos: QoS) -> Self { - // Ignore attempts to subscribe to outgoing topics - if topic.direction() != Direction::Incoming { - return self; - } - - if self.topics.iter().any(|(t, _)| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push((topic, qos)).ok(); - - Self { topics } - } - - pub fn topics( - self, - ) -> Result, QoS), N>, Error> { - self.iter() - .map(|(topic, qos)| Ok((topic.format()?, *qos))) - .collect() - } - - pub fn iter(&self) -> impl Iterator, QoS)> { - self.topics.iter() - } -} - -// #[derive(Default)] -// pub struct Unsubscribe<'a, const N: usize> { -// topics: heapless::Vec, N>, -// } - -// impl<'a, const N: usize> Unsubscribe<'a, N> { -// pub fn new() -> Self { -// Self::default() -// } - -// pub fn topic(self, topic: Topic<'a>) -> Self { -// // Ignore attempts to subscribe to outgoing topics -// if topic.direction() != Direction::Incoming { -// return self; -// } - -// if self.topics.iter().any(|t| t == &topic) { -// return self; -// } - -// let mut topics = self.topics; -// topics.push(topic).ok(); -// Self { topics } -// } - -// pub fn topics(self) -> Result, N>, Error> { -// self.topics -// .iter() -// .map(|topic| topic.clone().format()) -// .collect() -// } - -// // pub fn send(self, mqtt: &M) -> Result<(), Error> { -// // if self.topics.is_empty() { -// // return Ok(()); -// // } - -// // let topic_paths = self.topics()?; -// // let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - -// // for t in topics.chunks(5) { -// // // mqtt.unsubscribe(t)?; -// // } - -// // Ok(()) -// // } -// } From 548be762ba633db53ae28a1710a59e097d3d5072 Mon Sep 17 00:00:00 2001 From: Mathias Date: Sat, 6 Jan 2024 21:02:47 +0100 Subject: [PATCH 04/36] Get rid of additional stack allocation for ownership token --- src/provisioning/mod.rs | 341 +++++++++++++--------------------------- tests/provisioning.rs | 4 +- 2 files changed, 112 insertions(+), 233 deletions(-) diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 538db86..507fc02 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -4,18 +4,20 @@ pub mod topics; use core::future::Future; -use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embedded_mqtt::{Publish, QoS, RetainHandling, Subscribe, SubscribeTopic}; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{ + Message, Publish, QoS, RetainHandling, Subscribe, SubscribeTopic, Subscription, +}; use futures::StreamExt; -use serde::de::DeserializeOwned; use serde::Serialize; +use serde::{de::DeserializeOwned, Deserialize}; pub use error::Error; use self::{ data_types::{ - CreateCertificateFromCsrResponse, CreateKeysAndCertificateResponse, ErrorResponse, - RegisterThingRequest, RegisterThingResponse, + CreateKeysAndCertificateResponse, 
ErrorResponse, RegisterThingRequest, + RegisterThingResponse, }, topics::{PayloadFormat, Topic}, }; @@ -37,8 +39,8 @@ pub struct Credentials<'a> { pub struct FleetProvisioner; impl FleetProvisioner { - pub async fn provision<'a, C>( - mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, + pub async fn provision<'a, C, M: RawMutex, const SUBS: usize>( + mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, @@ -57,8 +59,8 @@ impl FleetProvisioner { } #[cfg(feature = "provision_cbor")] - pub async fn provision_cbor<'a, C>( - mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, + pub async fn provision_cbor<'a, C, M: RawMutex, const SUBS: usize>( + mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, @@ -76,88 +78,30 @@ impl FleetProvisioner { .await } - async fn provision_inner<'a, C, P, CH>( - mqtt: &'a embedded_mqtt::MqttClient<'a, NoopRawMutex, 2>, + #[cfg(feature = "provision_cbor")] + async fn provision_inner<'a, C, M: RawMutex, const SUBS: usize>( + mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, - parameters: Option

, - credential_handler: &mut CH, + parameters: Option, + credential_handler: &mut impl CredentialHandler, payload_format: PayloadFormat, ) -> Result, Error> where C: DeserializeOwned, - P: Serialize, - CH: CredentialHandler, { - let certificate_ownership_token = - Self::create_keys_and_certificates(mqtt, payload_format, credential_handler).await?; - - Self::register_thing( - mqtt, - template_name, - payload_format, - certificate_ownership_token.as_str(), - parameters, - ) - .await - } + let mut create_subscription = Self::begin(mqtt, payload_format).await?; - pub async fn create_keys_and_certificates( - mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, - payload_format: PayloadFormat, - credential_handler: &mut CH, - ) -> Result, Error> - where - CH: CredentialHandler, - { - let mut subscription = mqtt - .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::CreateKeysAndCertificateAny(payload_format) - .format::<31>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) + let mut message = create_subscription + .next() .await - .map_err(|_| Error::Mqtt)?; + .ok_or(Error::InvalidState)?; - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::CreateKeysAndCertificate(payload_format) - .format::<29>()? - .as_str(), - payload: b"", - properties: embedded_mqtt::Properties::Slice(&[]), - }) - .await - .map_err(|_| Error::Mqtt)?; - - let mut message = subscription.next().await.ok_or(Error::InvalidState)?; - - match Topic::from_str(message.topic_name()) { + let ownership_token = match Topic::from_str(message.topic_name()) { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { - trace!( - "Topic::CreateKeysAndCertificateAccepted {:?}. Payload len: {:?}", + let response = Self::deserialize::( format, - message.payload().len() - ); - - let response = match format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::< - CreateKeysAndCertificateResponse, - >(message.payload_mut())?, - PayloadFormat::Json => { - serde_json_core::from_slice::( - message.payload(), - )? - .0 - } - }; + &mut message, + )?; credential_handler .store_credentials(Credentials { @@ -167,48 +111,49 @@ impl FleetProvisioner { }) .await?; - Ok(heapless::String::try_from(response.certificate_ownership_token).unwrap()) + response.certificate_ownership_token } // Error happened! Some(Topic::CreateKeysAndCertificateRejected(format)) => { - error!(">> {:?}", message.topic_name()); - - let response = match format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(message.payload_mut())? 
- } - PayloadFormat::Json => { - serde_json_core::from_slice::(message.payload())?.0 - } - }; - - error!("{:?}", response); - - Err(Error::Response(response.status_code)) + return Err(Self::handle_error(format, message).unwrap_err()); } t => { trace!("{:?}", t); - Err(Error::InvalidState) + return Err(Error::InvalidState); } - } - } + }; - pub async fn create_certificate_from_csr( - mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, - payload_format: PayloadFormat, - credential_handler: &mut CH, - ) -> Result, Error> - where - CH: CredentialHandler, - { - let mut subscription = mqtt + let register_request = RegisterThingRequest { + certificate_ownership_token: &ownership_token, + parameters, + }; + + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let payload = &mut [0u8; 1024]; + + let payload_len = match payload_format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => { + let mut serializer = + serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload)); + register_request.serialize(&mut serializer)?; + serializer.into_inner().bytes_written() + } + PayloadFormat::Json => serde_json_core::to_slice(®ister_request, payload)?, + }; + + drop(message); + drop(create_subscription); + + debug!("Starting RegisterThing {:?}", payload_len); + + let mut register_subscription = mqtt .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::CreateCertificateFromCsrAny(payload_format) - .format::<40>()? + topic_path: Topic::RegisterThingAny(template_name, payload_format) + .format::<128>()? .as_str(), maximum_qos: QoS::AtLeastOnce, no_local: false, @@ -223,67 +168,33 @@ impl FleetProvisioner { qos: QoS::AtLeastOnce, retain: false, pid: None, - topic_name: Topic::CreateCertificateFromCsr(payload_format) - .format::<38>()? + topic_name: Topic::RegisterThing(template_name, payload_format) + .format::<69>()? .as_str(), - payload: b"", + payload: &payload[..payload_len], properties: embedded_mqtt::Properties::Slice(&[]), }) .await .map_err(|_| Error::Mqtt)?; - let mut message = subscription.next().await.ok_or(Error::InvalidState)?; + let mut message = register_subscription + .next() + .await + .ok_or(Error::InvalidState)?; match Topic::from_str(message.topic_name()) { - Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - trace!( - "Topic::CreateCertificateFromCsrAccepted {:?}. Payload len: {:?}", + Some(Topic::RegisterThingAccepted(_, format)) => { + let response = Self::deserialize::, M, SUBS>( format, - message.payload().len() - ); - - let response = match format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::< - CreateCertificateFromCsrResponse, - >(message.payload_mut())?, - PayloadFormat::Json => { - serde_json_core::from_slice::( - message.payload(), - )? - .0 - } - }; - - credential_handler - .store_credentials(Credentials { - certificate_id: response.certificate_id, - certificate_pem: response.certificate_pem, - private_key: None, - }) - .await?; + &mut message, + )?; - // FIXME: It should be possible to re-arrange stuff to get rid of the need for this 512 byte stack alloc - Ok(heapless::String::try_from(response.certificate_ownership_token).unwrap()) + Ok(response.device_configuration) } // Error happened! 
- Some(Topic::CreateCertificateFromCsrRejected(format)) => { - error!(">> {:?}", message.topic_name()); - - let response = match format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(message.payload_mut())? - } - PayloadFormat::Json => { - serde_json_core::from_slice::(message.payload())?.0 - } - }; - - error!("{:?}", response); - - Err(Error::Response(response.status_code)) + Some(Topic::RegisterThingRejected(_, format)) => { + Err(Self::handle_error(format, message).unwrap_err()) } t => { @@ -294,17 +205,14 @@ impl FleetProvisioner { } } - pub async fn register_thing( - mqtt: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>, - template_name: &str, + async fn begin<'a, M: RawMutex, const SUBS: usize>( + mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, payload_format: PayloadFormat, - certificate_ownership_token: &str, - parameters: Option

,
-    ) -> Result<Option<C>, Error> {
-        let mut subscription = mqtt
+    ) -> Result<Subscription<'a, 'a, M, SUBS, 1>, Error> {
+        let subscription = mqtt
             .subscribe::<1>(Subscribe::new(&[SubscribeTopic {
-                topic_path: Topic::RegisterThingAny(template_name, payload_format)
-                    .format::<128>()?
+                topic_path: Topic::CreateKeysAndCertificateAny(payload_format)
+                    .format::<31>()?
                     .as_str(),
                 maximum_qos: QoS::AtLeastOnce,
                 no_local: false,
@@ -314,87 +222,58 @@ impl FleetProvisioner {
         .await
         .map_err(|_| Error::Mqtt)?;
 
-        let register_request = RegisterThingRequest {
-            certificate_ownership_token: &certificate_ownership_token,
-            parameters,
-        };
-
-        // FIXME: Serialize directly into the publish payload through `DeferredPublish` API
-        let payload = &mut [0u8; 1024];
-
-        let payload_len = match payload_format {
-            #[cfg(feature = "provision_cbor")]
-            PayloadFormat::Cbor => {
-                let mut serializer =
-                    serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload));
-                register_request.serialize(&mut serializer)?;
-                serializer.into_inner().bytes_written()
-            }
-            PayloadFormat::Json => serde_json_core::to_slice(&register_request, payload)?,
-        };
-
-        info!("Starting RegisterThing {:?}", payload_len);
-
         mqtt.publish(Publish {
             dup: false,
             qos: QoS::AtLeastOnce,
             retain: false,
             pid: None,
-            topic_name: Topic::RegisterThing(template_name, payload_format)
-                .format::<69>()?
+            topic_name: Topic::CreateKeysAndCertificate(payload_format)
+                .format::<29>()?
                 .as_str(),
-            payload: &payload[..payload_len],
+            payload: b"",
             properties: embedded_mqtt::Properties::Slice(&[]),
         })
         .await
         .map_err(|_| Error::Mqtt)?;
 
-        let mut message = subscription.next().await.ok_or(Error::InvalidState)?;
-
-        match Topic::from_str(message.topic_name()) {
-            Some(Topic::RegisterThingAccepted(_, format)) => {
-                trace!("Topic::RegisterThingAccepted {:?}", format);
-
-                let response = match format {
-                    #[cfg(feature = "provision_cbor")]
-                    PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::<
-                        RegisterThingResponse<'_, C>,
-                    >(message.payload_mut())?,
-                    PayloadFormat::Json => {
-                        serde_json_core::from_slice::<RegisterThingResponse<C>>(
-                            message.payload(),
-                        )?
-                        .0
-                    }
-                };
-
-                Ok(response.device_configuration)
-            }
+        Ok(subscription)
+    }
 
-            // Error happened!
-            Some(Topic::RegisterThingRejected(_, format)) => {
-                error!(">> {:?}", message.topic_name());
+    fn deserialize<'a, R: Deserialize<'a>, M: RawMutex, const SUBS: usize>(
+        payload_format: PayloadFormat,
+        message: &'a mut Message<'_, M, SUBS>,
+    ) -> Result<R, Error> {
+        trace!(
+            "Topic::CreateKeysAndCertificateAccepted {:?}. Payload len: {:?}",
+            payload_format,
+            message.payload().len()
+        );
 
-                let response = match format {
-                    #[cfg(feature = "provision_cbor")]
-                    PayloadFormat::Cbor => {
-                        serde_cbor::de::from_mut_slice::<ErrorResponse>(message.payload_mut())?
-                    }
-                    PayloadFormat::Json => {
-                        serde_json_core::from_slice::<ErrorResponse>(message.payload())?.0
-                    }
-                };
+        Ok(match payload_format {
+            #[cfg(feature = "provision_cbor")]
+            PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::<R>(message.payload_mut())?,
+            PayloadFormat::Json => serde_json_core::from_slice::<R>(message.payload())?.0,
+        })
+    }
 
-                error!("{:?}", response);
+    fn handle_error<M: RawMutex, const SUBS: usize>(
+        format: PayloadFormat,
+        mut message: Message<'_, M, SUBS>,
+    ) -> Result<(), Error> {
+        error!(">> {:?}", message.topic_name());
 
-                Err(Error::Response(response.status_code))
+        let response = match format {
+            #[cfg(feature = "provision_cbor")]
+            PayloadFormat::Cbor => {
+                serde_cbor::de::from_mut_slice::<ErrorResponse>(message.payload_mut())?
+ } + PayloadFormat::Json => { + serde_json_core::from_slice::(message.payload())?.0 } + }; - t => { - trace!("{:?}", t); + error!("{:?}", response); - Err(Error::InvalidState) - } - } + Err(Error::Response(response.status_code)) } } diff --git a/tests/provisioning.rs b/tests/provisioning.rs index 949ba7f..bdc6d86 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -98,14 +98,14 @@ async fn test_provisioning() { let mut credential_handler = CredentialDAO { creds: None }; #[cfg(not(feature = "provision_cbor"))] - let provision_fut = FleetProvisioner::provision::( + let provision_fut = FleetProvisioner::provision::( client, &template_name, Some(parameters), &mut credential_handler, ); #[cfg(feature = "provision_cbor")] - let provision_fut = FleetProvisioner::provision_cbor::( + let provision_fut = FleetProvisioner::provision_cbor::( client, &template_name, Some(parameters), From e81fd6f143db7a160a260083dcdd77c22e4ad278 Mon Sep 17 00:00:00 2001 From: Mathias Date: Sat, 6 Jan 2024 21:15:54 +0100 Subject: [PATCH 05/36] Reenable jobs --- src/jobs/describe.rs | 41 +++++----- src/jobs/get_pending.rs | 39 ++++------ src/jobs/mod.rs | 19 +---- src/jobs/start_next.rs | 38 ++++----- src/jobs/subscribe.rs | 166 +--------------------------------------- src/jobs/unsubscribe.rs | 124 ------------------------------ src/jobs/update.rs | 47 ++++++------ src/lib.rs | 2 +- 8 files changed, 76 insertions(+), 400 deletions(-) delete mode 100644 src/jobs/unsubscribe.rs diff --git a/src/jobs/describe.rs b/src/jobs/describe.rs index 846181f..5c59d5c 100644 --- a/src/jobs/describe.rs +++ b/src/jobs/describe.rs @@ -1,4 +1,3 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; use crate::jobs::JobTopic; @@ -79,18 +78,22 @@ impl<'a> Describe<'a> { pub fn topic_payload( self, client_id: &str, + buf: &mut [u8], ) -> Result< ( heapless::String<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 22 }>, - heapless::Vec, + usize, ), JobError, > { - let payload = serde_json_core::to_vec(&DescribeJobExecutionRequest { - execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), - client_token: self.client_token, - }) + let payload_len = serde_json_core::to_slice( + &DescribeJobExecutionRequest { + execution_number: self.execution_number, + include_job_document: self.include_job_document.then(|| true), + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; Ok(( @@ -98,17 +101,9 @@ impl<'a> Describe<'a> { .map(JobTopic::Get) .unwrap_or(JobTopic::GetNext) .format(client_id)?, - payload, + payload_len, )) } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = self.topic_payload(mqtt.client_id())?; - - mqtt.publish(topic.as_str(), &payload, qos)?; - - Ok(()) - } } #[cfg(test)] @@ -131,15 +126,16 @@ mod test { #[test] fn topic_payload() { - let (topic, payload) = Describe::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = Describe::new() .include_job_document() .execution_number(1) .client_token("test_client:token") - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); assert_eq!( - payload, + &buf[..payload_len], br#"{"executionNumber":1,"includeJobDocument":true,"clientToken":"test_client:token"}"# ); @@ -148,16 +144,17 @@ mod test { #[test] fn topic_job_id() { - let (topic, payload) = Describe::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = Describe::new() .include_job_document() .execution_number(1) .job_id("test_job_id") 
.client_token("test_client:token") - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); assert_eq!( - payload, + &buf[..payload_len], br#"{"executionNumber":1,"includeJobDocument":true,"clientToken":"test_client:token"}"# ); diff --git a/src/jobs/get_pending.rs b/src/jobs/get_pending.rs index d44f2f0..4c185ca 100644 --- a/src/jobs/get_pending.rs +++ b/src/jobs/get_pending.rs @@ -1,4 +1,3 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; use crate::jobs::JobTopic; @@ -38,27 +37,17 @@ impl<'a> GetPending<'a> { pub fn topic_payload( self, client_id: &str, - ) -> Result< - ( - heapless::String<{ MAX_THING_NAME_LEN + 21 }>, - heapless::Vec, - ), - JobError, - > { - let payload = serde_json_core::to_vec(&&GetPendingJobExecutionsRequest { - client_token: self.client_token, - }) + buf: &mut [u8], + ) -> Result<(heapless::String<{ MAX_THING_NAME_LEN + 21 }>, usize), JobError> { + let payload_len = serde_json_core::to_slice( + &&GetPendingJobExecutionsRequest { + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; - Ok((JobTopic::GetPending.format(client_id)?, payload)) - } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = self.topic_payload(mqtt.client_id())?; - - mqtt.publish(topic.as_str(), &payload, qos)?; - - Ok(()) + Ok((JobTopic::GetPending.format(client_id)?, payload_len)) } } @@ -80,12 +69,16 @@ mod test { #[test] fn topic_payload() { - let (topic, payload) = GetPending::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = GetPending::new() .client_token("test_client:token_pending") - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); - assert_eq!(payload, br#"{"clientToken":"test_client:token_pending"}"#); + assert_eq!( + &buf[..payload_len], + br#"{"clientToken":"test_client:token_pending"}"# + ); assert_eq!(topic.as_str(), "$aws/things/test_client/jobs/get"); } diff --git a/src/jobs/mod.rs b/src/jobs/mod.rs index 1481087..2f63158 100644 --- a/src/jobs/mod.rs +++ b/src/jobs/mod.rs @@ -102,14 +102,13 @@ pub mod describe; pub mod get_pending; pub mod start_next; pub mod subscribe; -pub mod unsubscribe; pub mod update; use core::fmt::Write; use self::{ data_types::JobStatus, describe::Describe, get_pending::GetPending, start_next::StartNext, - subscribe::Subscribe, unsubscribe::Unsubscribe, update::Update, + update::Update, }; pub use subscribe::Topic; @@ -128,13 +127,7 @@ pub type StatusDetailsOwned = heapless::LinearMap, heapless pub enum JobError { Overflow, Encoding, - Mqtt(mqttrust::MqttError), -} - -impl From for JobError { - fn from(e: mqttrust::MqttError) -> Self { - Self::Mqtt(e) - } + Mqtt, } #[derive(Debug, Clone, PartialEq)] @@ -272,12 +265,4 @@ impl Jobs { pub fn update(job_id: &str, status: JobStatus) -> Update { Update::new(job_id, status) } - - pub fn subscribe<'a, const N: usize>() -> Subscribe<'a, N> { - Subscribe::new() - } - - pub fn unsubscribe<'a, const N: usize>() -> Unsubscribe<'a, N> { - Unsubscribe::new() - } } diff --git a/src/jobs/start_next.rs b/src/jobs/start_next.rs index 05e568a..9a4f83f 100644 --- a/src/jobs/start_next.rs +++ b/src/jobs/start_next.rs @@ -1,4 +1,3 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; use crate::jobs::JobTopic; @@ -84,28 +83,18 @@ impl<'a> StartNext<'a> { pub fn topic_payload( self, client_id: &str, - ) -> Result< - ( - heapless::String<{ MAX_THING_NAME_LEN + 28 }>, - heapless::Vec, - ), - JobError, - > { - let payload = 
serde_json_core::to_vec(&StartNextPendingJobExecutionRequest { - step_timeout_in_minutes: self.step_timeout_in_minutes, - client_token: self.client_token, - }) + buf: &mut [u8], + ) -> Result<(heapless::String<{ MAX_THING_NAME_LEN + 28 }>, usize), JobError> { + let payload_len = serde_json_core::to_slice( + &StartNextPendingJobExecutionRequest { + step_timeout_in_minutes: self.step_timeout_in_minutes, + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; - Ok((JobTopic::StartNext.format(client_id)?, payload)) - } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = self.topic_payload(mqtt.client_id())?; - - mqtt.publish(topic.as_str(), &payload, qos)?; - - Ok(()) + Ok((JobTopic::StartNext.format(client_id)?, payload_len)) } } @@ -136,14 +125,15 @@ mod test { #[test] fn topic_payload() { - let (topic, payload) = StartNext::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = StartNext::new() .client_token("test_client:token_next_pending") .step_timeout_in_minutes(43) - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); assert_eq!( - payload, + &buf[..payload_len], br#"{"stepTimeoutInMinutes":43,"clientToken":"test_client:token_next_pending"}"# ); diff --git a/src/jobs/subscribe.rs b/src/jobs/subscribe.rs index 1f740cc..cf796eb 100644 --- a/src/jobs/subscribe.rs +++ b/src/jobs/subscribe.rs @@ -1,8 +1,4 @@ -use mqttrust::{Mqtt, QoS, SubscribeTopic}; - -use crate::jobs::JobError; - -use super::{JobTopic, MAX_JOB_ID_LEN}; +use super::JobTopic; #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -72,163 +68,3 @@ impl<'a> From<&Topic<'a>> for JobTopic<'a> { } } } - -#[derive(Default)] -pub struct Subscribe<'a, const N: usize> { - topics: heapless::Vec<(Topic<'a>, QoS), N>, -} - -impl<'a, const N: usize> Subscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>, qos: QoS) -> Self { - match topic { - Topic::DescribeAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::DescribeRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - _ => {} - } - - if self.topics.iter().any(|(t, _)| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push((topic, qos)).ok(); - - Self { topics } - } - - pub fn topics( - self, - client_id: &str, - ) -> Result, QoS), N>, JobError> { - // assert!(client_id.len() <= super::MAX_THING_NAME_LEN); - self.topics - .iter() - .map(|(topic, qos)| Ok((JobTopic::from(topic).format(client_id)?, *qos))) - .collect() - } - - pub fn send(self, mqtt: &M) -> Result<(), JobError> { - if self.topics.is_empty() { - return Ok(()); - } - let topic_paths = self.topics(mqtt.client_id())?; - - let topics: heapless::Vec<_, N> = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - qos: *qos, - }) - .collect(); - - debug!("Subscribing!"); - - for t in topics.chunks(5) { - mqtt.subscribe(t)?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use mqttrust::{encoding::v4::decode_slice, Packet, QoS, SubscribeTopic}; - - use super::*; - - use crate::test::MockMqtt; - - #[test] - fn splits_subscribe_all() { - let mqtt = &MockMqtt::new(); - - Subscribe::<10>::new() - .topic(Topic::Notify, QoS::AtLeastOnce) - .topic(Topic::NotifyNext, QoS::AtLeastOnce) - 
.topic(Topic::GetAccepted, QoS::AtLeastOnce) - .topic(Topic::GetRejected, QoS::AtLeastOnce) - .topic(Topic::StartNextAccepted, QoS::AtLeastOnce) - .topic(Topic::StartNextRejected, QoS::AtLeastOnce) - .topic(Topic::DescribeAccepted("test_job"), QoS::AtLeastOnce) - .topic(Topic::DescribeRejected("test_job"), QoS::AtLeastOnce) - .topic(Topic::UpdateAccepted("test_job"), QoS::AtLeastOnce) - .topic(Topic::UpdateRejected("test_job"), QoS::AtLeastOnce) - .send(mqtt) - .unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 2); - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - let packet = decode_slice(bytes.as_slice()).unwrap(); - - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![ - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/notify", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/notify-next", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/get/accepted", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/get/rejected", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/start-next/accepted", - qos: QoS::AtLeastOnce - } - ] - ); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - let packet = decode_slice(bytes.as_slice()).unwrap(); - - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![ - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/start-next/rejected", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/get/accepted", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/get/rejected", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/update/accepted", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/update/rejected", - qos: QoS::AtLeastOnce - } - ] - ); - } -} diff --git a/src/jobs/unsubscribe.rs b/src/jobs/unsubscribe.rs deleted file mode 100644 index 79009ac..0000000 --- a/src/jobs/unsubscribe.rs +++ /dev/null @@ -1,124 +0,0 @@ -use mqttrust::Mqtt; - -use crate::jobs::JobTopic; - -use super::{subscribe::Topic, JobError, MAX_JOB_ID_LEN}; - -#[derive(Default)] -pub struct Unsubscribe<'a, const N: usize> { - topics: heapless::Vec, N>, -} - -impl<'a, const N: usize> Unsubscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>) -> Self { - match topic { - Topic::DescribeAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::DescribeRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - _ => {} - } - - if self.topics.iter().any(|t| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push(topic).ok(); - Self { topics } - } - - pub fn topics( - self, - client_id: &str, - ) -> Result, N>, JobError> { - // assert!(client_id.len() <= super::MAX_THING_NAME_LEN); - - self.topics - .iter() - .map(|topic| JobTopic::from(topic).format(client_id)) - .collect() - } - - pub fn send(self, mqtt: &M) -> Result<(), JobError> { - if self.topics.is_empty() { - return Ok(()); - 
} - let topic_paths = self.topics(mqtt.client_id())?; - let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - - for t in topics.chunks(5) { - mqtt.unsubscribe(t)?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use mqttrust::{encoding::v4::decode_slice, Packet}; - - use super::*; - use crate::test::MockMqtt; - - #[test] - fn splits_unsubscribe_all() { - let mqtt = &MockMqtt::new(); - - Unsubscribe::<10>::new() - .topic(Topic::Notify) - .topic(Topic::NotifyNext) - .topic(Topic::GetAccepted) - .topic(Topic::GetRejected) - .topic(Topic::StartNextAccepted) - .topic(Topic::StartNextRejected) - .topic(Topic::DescribeAccepted("test_job")) - .topic(Topic::DescribeRejected("test_job")) - .topic(Topic::UpdateAccepted("test_job")) - .topic(Topic::UpdateRejected("test_job")) - .send(mqtt) - .unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 2); - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Unsubscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![ - "$aws/things/test_client/jobs/notify", - "$aws/things/test_client/jobs/notify-next", - "$aws/things/test_client/jobs/get/accepted", - "$aws/things/test_client/jobs/get/rejected", - "$aws/things/test_client/jobs/start-next/accepted", - ] - ); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Unsubscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - assert_eq!( - topics, - vec![ - "$aws/things/test_client/jobs/start-next/rejected", - "$aws/things/test_client/jobs/test_job/get/accepted", - "$aws/things/test_client/jobs/test_job/get/rejected", - "$aws/things/test_client/jobs/test_job/update/accepted", - "$aws/things/test_client/jobs/test_job/update/rejected" - ] - ); - } -} diff --git a/src/jobs/update.rs b/src/jobs/update.rs index 5a3903d..e2c83f4 100644 --- a/src/jobs/update.rs +++ b/src/jobs/update.rs @@ -1,4 +1,3 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; use crate::jobs::{ @@ -152,34 +151,33 @@ impl<'a> Update<'a> { pub fn topic_payload( self, client_id: &str, + buf: &mut [u8], ) -> Result< ( heapless::String<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>, - heapless::Vec, + usize, ), JobError, > { - let payload = serde_json_core::to_vec(&UpdateJobExecutionRequest { - execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), - expected_version: self.expected_version, - include_job_execution_state: self.include_job_execution_state.then(|| true), - status: self.status, - status_details: self.status_details, - step_timeout_in_minutes: self.step_timeout_in_minutes, - client_token: self.client_token, - }) + let payload_len = serde_json_core::to_slice( + &UpdateJobExecutionRequest { + execution_number: self.execution_number, + include_job_document: self.include_job_document.then(|| true), + expected_version: self.expected_version, + include_job_execution_state: self.include_job_execution_state.then(|| true), + status: self.status, + status_details: self.status_details, + step_timeout_in_minutes: self.step_timeout_in_minutes, + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; - Ok((JobTopic::Update(self.job_id).format(client_id)?, payload)) - } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = 
self.topic_payload(mqtt.client_id())?;
-
-        mqtt.publish(topic.as_str(), &payload, qos)?;
-
-        Ok(())
+        Ok((
+            JobTopic::Update(self.job_id).format(client_id)?,
+            payload_len,
+        ))
     }
 }
 
@@ -208,17 +206,18 @@ mod test {
 
     #[test]
     fn topic_payload() {
-        let (topic, payload) = Update::new("test_job_id", JobStatus::Failed)
+        let mut buf = [0u8; 512];
+        let (topic, payload_len) = Update::new("test_job_id", JobStatus::Failed)
             .client_token("test_client:token_update")
             .step_timeout_in_minutes(50)
             .execution_number(5)
             .expected_version(2)
             .include_job_document()
             .include_job_execution_state()
-            .topic_payload("test_client")
+            .topic_payload("test_client", &mut buf)
             .unwrap();
 
-        assert_eq!(payload, br#"{"executionNumber":5,"expectedVersion":2,"includeJobDocument":true,"includeJobExecutionState":true,"status":"FAILED","stepTimeoutInMinutes":50,"clientToken":"test_client:token_update"}"#);
+        assert_eq!(&buf[..payload_len], br#"{"executionNumber":5,"expectedVersion":2,"includeJobDocument":true,"includeJobExecutionState":true,"status":"FAILED","stepTimeoutInMinutes":50,"clientToken":"test_client:token_update"}"#);
 
         assert_eq!(
             topic.as_str(),
diff --git a/src/lib.rs b/src/lib.rs
index 1b9acd2..2ed5a3a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,7 +5,7 @@
 // This mod MUST go first, so that the others see its macros.
 pub(crate) mod fmt;
 
-// pub mod jobs;
+pub mod jobs;
 // #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))]
 // pub mod ota;
 pub mod provisioning;

From cd14fc88c342f5b15e8799e4f01572c41c1cab34 Mon Sep 17 00:00:00 2001
From: Mathias
Date: Mon, 8 Jan 2024 12:29:02 +0100
Subject: [PATCH 06/36] Rewrite OTA to async

---
 Cargo.toml                        |    1 +
 src/jobs/mod.rs                   |    4 +-
 src/lib.rs                        |    4 +-
 src/ota/agent.rs                  |  154 ----
 src/ota/builder.rs                |  237 ------
 src/ota/config.rs                 |   16 +-
 src/ota/control_interface/mod.rs  |    6 +-
 src/ota/control_interface/mqtt.rs |   65 +-
 src/ota/data_interface/mod.rs     |   41 +-
 src/ota/data_interface/mqtt.rs    |   57 +-
 src/ota/encoding/json.rs          |    8 +-
 src/ota/encoding/mod.rs           |   40 +-
 src/ota/error.rs                  |   16 +-
 src/ota/mod.rs                    |  335 +++++++-
 src/ota/pal.rs                    |  158 +---
 src/ota/state.rs                  | 1181 -----------------------------
 src/ota/test/mock.rs              |   93 ---
 src/ota/test/mod.rs               |  523 -------------
 tests/common/file_handler.rs      |  142 ++--
 tests/common/mod.rs               |    2 +-
 tests/ota.rs                      |  405 +++++-----
 21 files changed, 723 insertions(+), 2765 deletions(-)
 delete mode 100644 src/ota/agent.rs
 delete mode 100644 src/ota/builder.rs
 delete mode 100644 src/ota/state.rs
 delete mode 100644 src/ota/test/mock.rs
 delete mode 100644 src/ota/test/mod.rs

diff --git a/Cargo.toml b/Cargo.toml
index 4a4ce18..8b9be4d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ embedded-storage-async = "0.4"
 embedded-mqtt = { path = "../embedded-mqtt" }
 futures = { version = "0.3.28", default-features = false }
 
+embassy-time = { version = "0.2" }
 embassy-sync = "0.5"
 
 log = { version = "^0.4", default-features = false, optional = true }
diff --git a/src/jobs/mod.rs b/src/jobs/mod.rs
index 2f63158..246d94b 100644
--- a/src/jobs/mod.rs
+++ b/src/jobs/mod.rs
@@ -123,11 +123,11 @@ pub const MAX_RUNNING_JOBS: usize = 1;
 pub type StatusDetails<'a> = heapless::LinearMap<&'a str, &'a str, 4>;
 pub type StatusDetailsOwned = heapless::LinearMap<heapless::String<15>, heapless::String<11>, 4>;
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub enum JobError {
     Overflow,
     Encoding,
-    Mqtt,
+    Mqtt(embedded_mqtt::Error),
 }
 
 #[derive(Debug, Clone, PartialEq)]
diff --git a/src/lib.rs b/src/lib.rs
index 2ed5a3a..2fac696 100644
---
a/src/lib.rs +++ b/src/lib.rs @@ -6,8 +6,8 @@ pub(crate) mod fmt; pub mod jobs; -// #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] -// pub mod ota; +#[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] +pub mod ota; pub mod provisioning; // pub mod shadows; diff --git a/src/ota/agent.rs b/src/ota/agent.rs deleted file mode 100644 index 6bf0060..0000000 --- a/src/ota/agent.rs +++ /dev/null @@ -1,154 +0,0 @@ -use super::{ - builder::{self, NoTimer}, - control_interface::ControlInterface, - data_interface::{DataInterface, NoInterface}, - encoding::json::OtaJob, - pal::OtaPal, - state::{Error, Events, JobEventData, SmContext, StateMachine, States}, -}; -use crate::jobs::StatusDetails; - -// OTA Agent driving the FSM of an OTA update -pub struct OtaAgent<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - pub(crate) state: StateMachine>, -} - -// Make sure any active OTA session is cleaned up, and the topics are -// unsubscribed on drop. -impl<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> Drop - for OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - fn drop(&mut self) { - let sm_context = self.state.context_mut(); - sm_context.ota_close().ok(); - sm_context.control.cleanup().ok(); - } -} - -impl<'a, C, DP, T, PAL, const TIMER_HZ: u32> - OtaAgent<'a, C, DP, NoInterface, T, NoTimer, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - T: fugit_timer::Timer, - PAL: OtaPal, -{ - pub fn builder( - control_interface: &'a C, - data_primary: DP, - request_timer: T, - pal: PAL, - ) -> builder::OtaAgentBuilder<'a, C, DP, NoInterface, T, NoTimer, PAL, TIMER_HZ> { - builder::OtaAgentBuilder::new(control_interface, data_primary, request_timer, pal) - } -} - -/// Public interface of the OTA Agent -impl<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - pub fn init(&mut self) { - if matches!(self.state(), &States::Ready) { - self.state.process_event(Events::Start).ok(); - } else { - self.state.process_event(Events::Resume).ok(); - } - } - - pub fn job_update( - &mut self, - job_name: &str, - ota_document: &OtaJob, - status_details: Option<&StatusDetails>, - ) -> Result<&States, Error> { - self.state - .process_event(Events::ReceivedJobDocument(JobEventData { - job_name, - ota_document, - status_details, - })) - } - - pub fn timer_callback(&mut self) -> Result<(), Error> { - let ctx = self.state.context_mut(); - if ctx.request_timer.wait().is_ok() { - return self.state.process_event(Events::RequestTimer).map(drop); - } - - if let Some(ref mut self_test_timer) = ctx.self_test_timer { - if self_test_timer.wait().is_ok() { - error!( - "Self test failed to complete within {} ms", - ctx.config.self_test_timeout_ms - ); - ctx.pal.reset_device().ok(); - } - } - Ok(()) - } - - pub fn process_event(&mut self) -> Result<&States, Error> { - if let Some(event) = self.state.context_mut().events.dequeue() { - self.state.process_event(event) - } else { - Ok(self.state()) - } - } - - pub fn handle_message(&mut self, payload: &mut [u8]) -> Result<&States, Error> { - 
self.state.process_event(Events::ReceivedFileBlock(payload)) - } - - pub fn check_for_update(&mut self) -> Result<&States, Error> { - if matches!( - self.state(), - States::WaitingForJob | States::RequestingJob | States::WaitingForFileBlock - ) { - self.state.process_event(Events::RequestJobDocument) - } else { - Ok(self.state()) - } - } - - pub fn abort(&mut self) -> Result<&States, Error> { - self.state.process_event(Events::UserAbort) - } - - pub fn suspend(&mut self) -> Result<&States, Error> { - // Stop the request timer - self.state.context_mut().request_timer.cancel().ok(); - - // Send event to OTA agent task. - self.state.process_event(Events::Suspend) - } - - pub fn resume(&mut self) -> Result<&States, Error> { - // Send event to OTA agent task - self.state.process_event(Events::Resume) - } - - pub fn state(&self) -> &States { - self.state.state() - } -} diff --git a/src/ota/builder.rs b/src/ota/builder.rs deleted file mode 100644 index 603d568..0000000 --- a/src/ota/builder.rs +++ /dev/null @@ -1,237 +0,0 @@ -use crate::ota::{ - config::Config, - control_interface::ControlInterface, - data_interface::DataInterface, - pal::OtaPal, - state::{SmContext, StateMachine}, -}; - -use super::{agent::OtaAgent, data_interface::NoInterface, pal::ImageState}; - -pub struct NoTimer; - -impl fugit_timer::Timer for NoTimer { - type Error = (); - - fn now(&mut self) -> fugit_timer::TimerInstantU32 { - todo!() - } - - fn start( - &mut self, - _duration: fugit_timer::TimerDurationU32, - ) -> Result<(), Self::Error> { - todo!() - } - - fn cancel(&mut self) -> Result<(), Self::Error> { - todo!() - } - - fn wait(&mut self) -> nb::Result<(), Self::Error> { - todo!() - } -} - -pub struct OtaAgentBuilder<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - control: &'a C, - data_primary: DP, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - data_secondary: Option, - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - data_secondary: core::marker::PhantomData, - pal: PAL, - request_timer: T, - self_test_timer: Option, - config: Config, -} - -impl<'a, C, DP, T, PAL, const TIMER_HZ: u32> - OtaAgentBuilder<'a, C, DP, NoInterface, T, NoTimer, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - T: fugit_timer::Timer, - PAL: OtaPal, -{ - pub fn new(control_interface: &'a C, data_primary: DP, request_timer: T, pal: PAL) -> Self { - Self { - control: control_interface, - data_primary, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - data_secondary: None, - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - data_secondary: core::marker::PhantomData, - pal, - request_timer, - self_test_timer: None, - config: Config::default(), - } - } -} - -impl<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> - OtaAgentBuilder<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - pub fn data_secondary( - self, - interface: D, - ) -> OtaAgentBuilder<'a, C, DP, D, T, ST, PAL, TIMER_HZ> { - OtaAgentBuilder { - control: self.control, - data_primary: self.data_primary, - data_secondary: Some(interface), - pal: self.pal, - request_timer: self.request_timer, - self_test_timer: self.self_test_timer, - config: 
self.config, - } - } - - pub fn block_size(self, block_size: usize) -> Self { - Self { - config: Config { - block_size, - ..self.config - }, - ..self - } - } - - pub fn max_request_momentum(self, max_request_momentum: u8) -> Self { - Self { - config: Config { - max_request_momentum, - ..self.config - }, - ..self - } - } - - pub fn activate_delay(self, activate_delay: u8) -> Self { - Self { - config: Config { - activate_delay, - ..self.config - }, - ..self - } - } - - pub fn request_wait_ms(self, request_wait_ms: u32) -> Self { - Self { - config: Config { - request_wait_ms, - ..self.config - }, - ..self - } - } - - pub fn status_update_frequency(self, status_update_frequency: u32) -> Self { - Self { - config: Config { - status_update_frequency, - ..self.config - }, - ..self - } - } - - pub fn allow_downgrade(self) -> Self { - Self { - config: Config { - allow_downgrade: true, - ..self.config - }, - ..self - } - } - - pub fn with_self_test_timeout( - self, - timer: NST, - timeout_ms: u32, - ) -> OtaAgentBuilder<'a, C, DP, DS, T, NST, PAL, TIMER_HZ> - where - NST: fugit_timer::Timer, - { - OtaAgentBuilder { - control: self.control, - data_primary: self.data_primary, - data_secondary: self.data_secondary, - pal: self.pal, - request_timer: self.request_timer, - self_test_timer: Some(timer), - config: Config { - self_test_timeout_ms: timeout_ms, - ..self.config - }, - } - } - - pub fn build(self) -> OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> { - OtaAgent { - state: StateMachine::new(SmContext { - events: heapless::spsc::Queue::new(), - control: self.control, - data_secondary: self.data_secondary, - data_primary: self.data_primary, - active_interface: None, - request_momentum: 0, - request_timer: self.request_timer, - self_test_timer: self.self_test_timer, - pal: self.pal, - config: self.config, - image_state: ImageState::Unknown, - }), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - ota::test::mock::{MockPal, MockTimer}, - test::MockMqtt, - }; - - #[test] - fn enables_allow_downgrade() { - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let builder = OtaAgentBuilder::new(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .allow_downgrade(); - - assert!(builder.config.allow_downgrade); - assert!(builder.self_test_timer.is_some()); - assert_eq!(builder.config.self_test_timeout_ms, 32000); - - let agent = builder.build(); - - assert!(agent.state.context().config.allow_downgrade); - } -} diff --git a/src/ota/config.rs b/src/ota/config.rs index 7862273..0d69854 100644 --- a/src/ota/config.rs +++ b/src/ota/config.rs @@ -1,12 +1,11 @@ +use embassy_time::Duration; + pub struct Config { pub(crate) block_size: usize, pub(crate) max_request_momentum: u8, - pub(crate) activate_delay: u8, - pub(crate) request_wait_ms: u32, + pub(crate) request_wait: Duration, pub(crate) status_update_frequency: u32, - pub(crate) allow_downgrade: bool, - pub(crate) unsubscribe_on_shutdown: bool, - pub(crate) self_test_timeout_ms: u32, + pub(crate) self_test_timeout: Option, } impl Default for Config { @@ -14,12 +13,9 @@ impl Default for Config { Self { block_size: 256, max_request_momentum: 3, - activate_delay: 5, - request_wait_ms: 8000, + request_wait: Duration::from_secs(8), status_update_frequency: 24, - allow_downgrade: false, - unsubscribe_on_shutdown: true, - self_test_timeout_ms: 16000, + self_test_timeout: None, } } } diff --git a/src/ota/control_interface/mod.rs 
b/src/ota/control_interface/mod.rs index e1b28e3..8962ea2 100644 --- a/src/ota/control_interface/mod.rs +++ b/src/ota/control_interface/mod.rs @@ -10,14 +10,12 @@ pub mod mqtt; // Interfaces required for OTA pub trait ControlInterface { - fn init(&self) -> Result<(), OtaError>; - fn request_job(&self) -> Result<(), OtaError>; - fn update_job_status( + async fn request_job(&self) -> Result<(), OtaError>; + async fn update_job_status( &self, file_ctx: &mut FileContext, config: &Config, status: JobStatus, reason: JobStatusReason, ) -> Result<(), OtaError>; - fn cleanup(&self) -> Result<(), OtaError>; } diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index f8c34b1..495064a 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,37 +1,43 @@ use core::fmt::Write; -use mqttrust::QoS; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{Publish, QoS}; use super::ControlInterface; use crate::jobs::data_types::JobStatus; -use crate::jobs::subscribe::Topic; use crate::jobs::Jobs; use crate::ota::config::Config; use crate::ota::encoding::json::JobStatusReason; use crate::ota::encoding::FileContext; use crate::ota::error::OtaError; -impl ControlInterface for T { - /// Initialize the control interface by subscribing to the OTA job - /// notification topics. - fn init(&self) -> Result<(), OtaError> { - Jobs::subscribe::<1>() - .topic(Topic::NotifyNext, QoS::AtLeastOnce) - .send(self)?; - Ok(()) - } - +impl<'a, M: RawMutex, const SUBS: usize> ControlInterface + for embedded_mqtt::MqttClient<'a, M, SUBS> +{ /// Check for next available OTA job from the job service by publishing a /// "get next job" message to the job service. - fn request_job(&self) -> Result<(), OtaError> { - Jobs::describe().send(self, QoS::AtLeastOnce)?; + async fn request_job(&self) -> Result<(), OtaError> { + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let mut buf = [0u8; 512]; + let (topic, payload_len) = Jobs::describe().topic_payload(self.client_id(), &mut buf)?; + + self.publish(Publish { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + pid: None, + topic_name: &topic, + payload: &buf[..payload_len], + properties: embedded_mqtt::Properties::Slice(&[]), + }) + .await?; Ok(()) } /// Update the job status on the service side with progress or completion /// info - fn update_job_status( + async fn update_job_status( &self, file_ctx: &mut FileContext, config: &Config, @@ -41,8 +47,8 @@ impl ControlInterface for T { file_ctx .status_details .insert( - heapless::String::from("self_test"), - heapless::String::from(reason.as_str()), + heapless::String::try_from("self_test").unwrap(), + heapless::String::try_from(reason.as_str()).unwrap(), ) .map_err(|_| OtaError::Overflow)?; @@ -73,7 +79,7 @@ impl ControlInterface for T { file_ctx .status_details - .insert(heapless::String::from("progress"), progress) + .insert(heapless::String::try_from("progress").unwrap(), progress) .map_err(|_| OtaError::Overflow)?; } @@ -84,18 +90,23 @@ impl ControlInterface for T { } } - Jobs::update(file_ctx.job_name.as_str(), status) + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let mut buf = [0u8; 512]; + let (topic, payload_len) = Jobs::update(file_ctx.job_name.as_str(), status) .status_details(&file_ctx.status_details) - .send(self, qos)?; + .topic_payload(self.client_id(), &mut buf)?; - Ok(()) - } + self.publish(Publish { + dup: false, + qos, + retain: false, + pid: None, + 
topic_name: &topic, + payload: &buf[..payload_len], + properties: embedded_mqtt::Properties::Slice(&[]), + }) + .await?; - /// Perform any cleanup operations required for control plane - fn cleanup(&self) -> Result<(), OtaError> { - Jobs::unsubscribe::<1>() - .topic(Topic::NotifyNext) - .send(self)?; Ok(()) } } diff --git a/src/ota/data_interface/mod.rs b/src/ota/data_interface/mod.rs index ec02550..d80a0c9 100644 --- a/src/ota/data_interface/mod.rs +++ b/src/ota/data_interface/mod.rs @@ -1,5 +1,5 @@ -#[cfg(feature = "ota_http_data")] -pub mod http; +// #[cfg(feature = "ota_http_data")] +// pub mod http; #[cfg(feature = "ota_mqtt_data")] pub mod mqtt; @@ -44,46 +44,15 @@ impl<'a> FileBlock<'a> { pub trait DataInterface { const PROTOCOL: Protocol; - fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError>; - fn request_file_block( + async fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError>; + async fn request_file_block( &self, file_ctx: &mut FileContext, config: &Config, ) -> Result<(), OtaError>; - fn decode_file_block<'a>( + async fn decode_file_block<'a>( &self, file_ctx: &mut FileContext, payload: &'a mut [u8], ) -> Result, OtaError>; - fn cleanup(&self, file_ctx: &mut FileContext, config: &Config) -> Result<(), OtaError>; -} - -pub struct NoInterface; - -impl DataInterface for NoInterface { - const PROTOCOL: Protocol = Protocol::Mqtt; - - fn init_file_transfer(&self, _file_ctx: &mut FileContext) -> Result<(), OtaError> { - unreachable!() - } - - fn request_file_block( - &self, - _file_ctx: &mut FileContext, - _config: &Config, - ) -> Result<(), OtaError> { - unreachable!() - } - - fn decode_file_block<'a>( - &self, - _file_ctx: &mut FileContext, - _payload: &'a mut [u8], - ) -> Result, OtaError> { - unreachable!() - } - - fn cleanup(&self, _file_ctx: &mut FileContext, _config: &Config) -> Result<(), OtaError> { - unreachable!() - } } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 87e5999..3f4af47 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -1,7 +1,8 @@ use core::fmt::{Display, Write}; use core::str::FromStr; -use mqttrust::{Mqtt, QoS, SubscribeTopic}; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{MqttClient, Properties, Publish, RetainHandling, Subscribe, SubscribeTopic}; use crate::ota::error::OtaError; use crate::{ @@ -116,30 +117,32 @@ impl<'a> OtaTopic<'a> { } } -impl<'a, M> DataInterface for &'a M -where - M: Mqtt, -{ +impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> { const PROTOCOL: Protocol = Protocol::Mqtt; /// Init file transfer by subscribing to the OTA data stream topic - fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError> { + async fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError> { let topic_path = OtaTopic::Data(Encoding::Cbor, file_ctx.stream_name.as_str()) .format::<256>(self.client_id())?; + let topic = SubscribeTopic { topic_path: topic_path.as_str(), - qos: mqttrust::QoS::AtLeastOnce, + maximum_qos: embedded_mqtt::QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, }; debug!("Subscribing to: [{:?}]", &topic_path); - self.subscribe(&[topic])?; + // FIXME: + self.subscribe::<1>(Subscribe::new(&[topic])).await?; Ok(()) } /// Request file block by publishing to the get stream topic - fn request_file_block( + async fn request_file_block( &self, file_ctx: 
&mut FileContext, config: &Config, @@ -147,6 +150,7 @@ where // Reset number of blocks requested file_ctx.request_block_remaining = file_ctx.bitmap.len() as u32; + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API let buf = &mut [0u8; 32]; let len = cbor::to_slice( &cbor::GetStreamRequest { @@ -163,19 +167,24 @@ where ) .map_err(|_| OtaError::Encoding)?; - self.publish( - OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) + self.publish(Publish { + dup: false, + qos: embedded_mqtt::QoS::AtMostOnce, + retain: false, + pid: None, + topic_name: OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>(self.client_id())? .as_str(), - &buf[..len], - QoS::AtMostOnce, - )?; + payload: &buf[..len], + properties: Properties::Slice(&[]), + }) + .await?; Ok(()) } /// Decode a cbor encoded fileblock received from streaming service - fn decode_file_block<'c>( + async fn decode_file_block<'c>( &self, _file_ctx: &mut FileContext, payload: &'c mut [u8], @@ -186,19 +195,6 @@ where .into(), ) } - - /// Perform any cleanup operations required for data plane - fn cleanup(&self, file_ctx: &mut FileContext, config: &Config) -> Result<(), OtaError> { - if config.unsubscribe_on_shutdown { - // Unsubscribe from data stream topics - self.unsubscribe(&[ - OtaTopic::Data(Encoding::Cbor, file_ctx.stream_name.as_str()) - .format::<256>(self.client_id())? - .as_str(), - ])?; - } - Ok(()) - } } #[cfg(test)] @@ -234,7 +230,10 @@ mod tests { topics, vec![SubscribeTopic { topic_path: "$aws/things/test_client/streams/test_stream/data/cbor", - qos: QoS::AtLeastOnce + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime }] ); } diff --git a/src/ota/encoding/json.rs b/src/ota/encoding/json.rs index ae08e4d..45ea3f2 100644 --- a/src/ota/encoding/json.rs +++ b/src/ota/encoding/json.rs @@ -61,16 +61,16 @@ pub struct FileDescription<'a> { impl<'a> FileDescription<'a> { pub fn signature(&self) -> Signature { if let Some(sig) = self.sha1_rsa { - return Signature::Sha1Rsa(heapless::String::from(sig)); + return Signature::Sha1Rsa(heapless::String::try_from(sig).unwrap()); } if let Some(sig) = self.sha256_rsa { - return Signature::Sha256Rsa(heapless::String::from(sig)); + return Signature::Sha256Rsa(heapless::String::try_from(sig).unwrap()); } if let Some(sig) = self.sha1_ecdsa { - return Signature::Sha1Ecdsa(heapless::String::from(sig)); + return Signature::Sha1Ecdsa(heapless::String::try_from(sig).unwrap()); } if let Some(sig) = self.sha256_ecdsa { - return Signature::Sha256Ecdsa(heapless::String::from(sig)); + return Signature::Sha256Ecdsa(heapless::String::try_from(sig).unwrap()); } unreachable!() } diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index 257a1ea..bb9f473 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -3,15 +3,14 @@ pub mod cbor; pub mod json; use core::ops::{Deref, DerefMut}; -use core::str::FromStr; use serde::{Serialize, Serializer}; use crate::jobs::StatusDetailsOwned; use self::json::{JobStatusReason, OtaJob, Signature}; +use super::config::Config; use super::error::OtaError; -use super::{config::Config, pal::Version}; #[derive(Clone, PartialEq)] pub struct Bitmap(bitmaps::Bitmap<32>); @@ -81,7 +80,6 @@ impl FileContext { status_details: Option, file_idx: usize, config: &Config, - current_version: Version, ) -> Result { let file_desc = ota_job .files @@ -94,12 +92,12 @@ impl FileContext { 
details } else { let mut status = StatusDetailsOwned::new(); - status - .insert( - heapless::String::from("updated_by"), - current_version.to_string(), - ) - .map_err(|_| OtaError::Overflow)?; + // status + // .insert( + // heapless::String::try_from("updated_by").unwrap(), + // current_version.to_string(), + // ) + // .map_err(|_| OtaError::Overflow)?; status }; @@ -109,29 +107,33 @@ impl FileContext { let bitmap = Bitmap::new(file_desc.filesize, config.block_size, block_offset); Ok(FileContext { - filepath: heapless::String::from(file_desc.filepath), + filepath: heapless::String::try_from(file_desc.filepath).unwrap(), filesize: file_desc.filesize, fileid: file_desc.fileid, - certfile: heapless::String::from(file_desc.certfile), - update_data_url: file_desc.update_data_url.map(heapless::String::from), - auth_scheme: file_desc.auth_scheme.map(heapless::String::from), + certfile: heapless::String::try_from(file_desc.certfile).unwrap(), + update_data_url: file_desc + .update_data_url + .map(|s| heapless::String::try_from(s).unwrap()), + auth_scheme: file_desc + .auth_scheme + .map(|s| heapless::String::try_from(s).unwrap()), signature, file_type: file_desc.file_type, status_details: status, - job_name: heapless::String::from(job_name), + job_name: heapless::String::try_from(job_name).unwrap(), block_offset, request_block_remaining: bitmap.len() as u32, blocks_remaining: (file_desc.filesize + config.block_size - 1) / config.block_size, - stream_name: heapless::String::from(ota_job.streamname), + stream_name: heapless::String::try_from(ota_job.streamname).unwrap(), bitmap, }) } pub fn self_test(&self) -> bool { self.status_details - .get(&heapless::String::from("self_test")) + .get(&heapless::String::try_from("self_test").unwrap()) .and_then(|f| f.parse().ok()) .map(|reason: JobStatusReason| { reason == JobStatusReason::SigCheckPassed @@ -139,12 +141,6 @@ impl FileContext { }) .unwrap_or(false) } - - pub fn updated_by(&self) -> Option { - self.status_details - .get(&heapless::String::from("updated_by")) - .and_then(|s| Version::from_str(s.as_str()).ok()) - } } #[cfg(test)] diff --git a/src/ota/error.rs b/src/ota/error.rs index ad533c4..3dd7101 100644 --- a/src/ota/error.rs +++ b/src/ota/error.rs @@ -2,7 +2,7 @@ use crate::jobs::JobError; use super::pal::OtaPalError; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum OtaError { NoActiveJob, @@ -15,10 +15,10 @@ pub enum OtaError { ZeroFileSize, Overflow, InvalidFile, - Mqtt(mqttrust::MqttError), + Mqtt(embedded_mqtt::Error), Encoding, Pal, - Timer, + Timeout, } impl OtaError { @@ -27,14 +27,14 @@ impl OtaError { } } -impl From for OtaError { - fn from(e: mqttrust::MqttError) -> Self { +impl From for OtaError { + fn from(e: embedded_mqtt::Error) -> Self { Self::Mqtt(e) } } -impl From> for OtaError { - fn from(_e: OtaPalError) -> Self { +impl From for OtaError { + fn from(_e: OtaPalError) -> Self { Self::Pal } } @@ -44,7 +44,7 @@ impl From for OtaError { match e { JobError::Overflow => Self::Overflow, JobError::Encoding => Self::Encoding, - JobError::Mqtt(m) => Self::Mqtt(m), + JobError::Mqtt(e) => Self::Mqtt(e), } } } diff --git a/src/ota/mod.rs b/src/ota/mod.rs index a627e88..a366b6d 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -30,18 +30,343 @@ //! - Code sign verification //! 
- CBOR deserializer -pub mod agent; -pub mod builder; pub mod config; pub mod control_interface; pub mod data_interface; pub mod encoding; pub mod error; pub mod pal; -pub mod state; #[cfg(feature = "ota_mqtt_data")] pub use data_interface::mqtt::{Encoding, Topic}; -#[cfg(test)] -pub mod test; +use crate::{jobs::data_types::JobStatus, ota::encoding::json::JobStatusReason}; + +use self::{ + control_interface::ControlInterface, + data_interface::DataInterface, + encoding::FileContext, + pal::{ImageState, ImageStateReason}, +}; + +#[derive(PartialEq)] +pub struct JobEventData<'a> { + pub job_name: &'a str, + pub ota_document: &'a encoding::json::OtaJob<'a>, + pub status_details: Option<&'a crate::jobs::StatusDetails<'a>>, +} + +pub struct Updater; + +impl Updater { + pub async fn perform_ota<'a, C: ControlInterface, D: DataInterface>( + control: &C, + data: &D, + job_data: JobEventData<'a>, + pal: &mut impl pal::OtaPal, + config: config::Config, + ) -> Result<(), error::OtaError> { + let mut request_momentum = 0; + + // TODO: Handle request_momentum? + control.request_job().await?; + + let JobEventData { + job_name, + ota_document, + status_details, + } = job_data; + + let file_idx = 0; + + if ota_document + .files + .get(file_idx) + .map(|f| f.filesize) + .unwrap_or_default() + == 0 + { + return Err(error::OtaError::ZeroFileSize); + } + + let mut file_ctx = FileContext::new_from( + job_name, + ota_document, + status_details.map(|s| { + s.iter() + .map(|(&k, &v)| { + ( + heapless::String::try_from(k).unwrap(), + heapless::String::try_from(v).unwrap(), + ) + }) + .collect() + }), + file_idx, + &config, + )?; + + // If the job is in self test mode, don't start an OTA update but + // instead do the following: + // + // If the firmware that performed the update was older than the + // currently running firmware, set the image state to "Testing." This is + // the success path. + // + // If it's the same or newer, reject the job since either the firmware + // was not accepted during self test or an incorrect image was sent by + // the OTA operator. + let platform_self_test = pal + .get_platform_image_state() + .await + .map_or(false, |i| i == pal::PalImageState::PendingCommit); + + match (file_ctx.self_test(), platform_self_test) { + (true, true) => { + // Run self-test! + Self::set_image_state_with_reason( + control, + pal, + &config, + &mut file_ctx, + ImageState::Testing(ImageStateReason::VersionCheck), + ) + .await?; + + info!("Beginning self-test"); + + let test_fut = pal.complete_callback(pal::OtaEvent::StartTest); + + match config.self_test_timeout { + Some(timeout) => embassy_time::with_timeout(timeout, test_fut) + .await + .map_err(|_| error::OtaError::Timeout)?, + None => test_fut.await, + }?; + + control + .update_job_status( + &mut file_ctx, + &config, + JobStatus::Succeeded, + JobStatusReason::Accepted, + ) + .await?; + + return Ok(()); + } + (false, false) => {} + (false, true) => { + // Received a job that is not in self-test but platform is, so + // reboot the device to allow roll back to previous image. + error!("Rejecting new image and rebooting: The platform is in the self-test state while the job is not."); + pal.reset_device().await?; + } + (true, false) => { + // The job is in self test but the platform image state is not so it + // could be an attack on the platform image state. Reject the update + // (this should also cause the image to be erased), aborting the job + // and reset the device. 
+ error!("Rejecting new image and rebooting: the job is in the self-test state while the platform is not."); + Self::set_image_state_with_reason( + control, + pal, + &config, + &mut file_ctx, + ImageState::Rejected(ImageStateReason::ImageStateMismatch), + ) + .await?; + pal.reset_device().await?; + } + } + + if !ota_document.protocols.contains(&D::PROTOCOL) { + error!("Unable to handle current OTA job with given data interface ({:?}). Supported protocols: {:?}. Aborting current update.", D::PROTOCOL, ota_document.protocols); + Self::set_image_state_with_reason( + control, + pal, + &config, + &mut file_ctx, + ImageState::Aborted(ImageStateReason::InvalidDataProtocol), + ) + .await?; + return Err(error::OtaError::InvalidInterface); + } + + info!("Job document was accepted. Attempting to begin the update"); + + // Create/Open the OTA file on the file system + if let Err(e) = pal.create_file_for_rx(&file_ctx).await { + Self::set_image_state_with_reason( + control, + pal, + &config, + &mut file_ctx, + ImageState::Aborted(ImageStateReason::Pal(e)), + ) + .await?; + + pal.close_file(&file_ctx).await?; + return Err(e.into()); + } + + // Prepare the storage layer on receiving a new file + match data.init_file_transfer(&mut file_ctx).await { + Err(e) => { + return if request_momentum < config.max_request_momentum { + // Start request timer + // self.request_timer + // .start(config.request_wait.millis()) + // .map_err(|_| error::OtaError::Timer)?; + + request_momentum += 1; + Err(e) + } else { + // Stop request timer + // self.request_timer + // .cancel() + // .map_err(|_| error::OtaError::Timer)?; + + // Too many requests have been sent without a response or + // too many failures when trying to publish the request + // message. Abort. + + Err(error::OtaError::MomentumAbort) + }; + } + Ok(_) => { + // Reset the request momentum + request_momentum = 0; + + // TODO: Reset the OTA statistics + + info!("Initialized file handler! Requesting file blocks"); + } + } + + // Request data + if file_ctx.blocks_remaining > 0 { + if request_momentum <= config.max_request_momentum { + // Each request increases the momentum until a response is + // received. Too much momentum is interpreted as a failure to + // communicate and will cause us to abort the OTA. + request_momentum += 1; + + // Request data blocks + data.request_file_block(&mut file_ctx, &config).await?; + } else { + // Stop the request timer + // self.request_timer.cancel().map_err(|_| error::OtaError::Timer)?; + + // Failed to send data request abort and close file. + Self::set_image_state_with_reason( + control, + pal, + &config, + &mut file_ctx, + ImageState::Aborted(ImageStateReason::MomentumAbort), + ) + .await?; + + // Reset the request momentum + request_momentum = 0; + + // Too many requests have been sent without a response or too + // many failures when trying to publish the request message. + // Abort. 
+ return Err(error::OtaError::MomentumAbort); + } + } else { + return Err(error::OtaError::BlockOutOfRange); + } + + Ok(()) + } + + async fn set_image_state_with_reason<'a, C: ControlInterface, PAL: pal::OtaPal>( + control: &C, + pal: &mut PAL, + config: &config::Config, + file_ctx: &mut FileContext, + image_state: ImageState, + ) -> Result<(), error::OtaError> { + // debug!("set_image_state_with_reason {:?}", image_state); + // Call the platform specific code to set the image state + + // FIXME: + let image_state = match pal.set_platform_image_state(image_state).await { + Err(e) if !matches!(image_state, ImageState::Aborted(_)) => { + // If the platform image state couldn't be set correctly, force + // fail the update by setting the image state to "Rejected" + // unless it's already in "Aborted". + + // Capture the failure reason if not already set (and we're not + // already Aborted as checked above). Otherwise Keep the + // original reject reason code since it is possible for the PAL + // to fail to update the image state in some cases (e.g. a reset + // already caused the bundle rollback and we failed to rollback + // again). + + // Intentionally override reason since we failed within this + // function + ImageState::Rejected(ImageStateReason::Pal(e)) + } + _ => image_state, + }; + + // Now update the image state and job status on server side + match image_state { + ImageState::Testing(_) => { + // We discovered we're ready for test mode, put job status + // in self_test active + control + .update_job_status( + file_ctx, + config, + JobStatus::InProgress, + JobStatusReason::SelfTestActive, + ) + .await?; + } + ImageState::Accepted => { + // Now that we have accepted the firmware update, we can + // complete the job + control + .update_job_status( + file_ctx, + config, + JobStatus::Succeeded, + JobStatusReason::Accepted, + ) + .await?; + } + ImageState::Rejected(_) => { + // The firmware update was rejected, complete the job as + // FAILED (Job service will not allow us to set REJECTED + // after the job has been started already). + control + .update_job_status( + file_ctx, + config, + JobStatus::Failed, + JobStatusReason::Rejected, + ) + .await?; + } + _ => { + // The firmware update was aborted, complete the job as + // FAILED (Job service will not allow us to set REJECTED + // after the job has been started already). + control + .update_job_status( + file_ctx, + config, + JobStatus::Failed, + JobStatusReason::Aborted, + ) + .await?; + } + } + Ok(()) + } +} diff --git a/src/ota/pal.rs b/src/ota/pal.rs index 2163a5d..409f5ed 100644 --- a/src/ota/pal.rs +++ b/src/ota/pal.rs @@ -1,23 +1,33 @@ //! 
Platform abstraction trait for OTA updates - -use core::fmt::Write; -use core::str::FromStr; - use super::encoding::FileContext; -use super::state::ImageStateReason; +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ImageStateReason { + NewerJob, + FailedIngest, + MomentumAbort, + ImageStateMismatch, + SignatureCheckPassed, + InvalidDataProtocol, + UserAbort, + VersionCheck, + Pal(OtaPalError), +} + +#[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum ImageState { +pub enum ImageState { Unknown, - Aborted(ImageStateReason), - Rejected(ImageStateReason), + Aborted(ImageStateReason), + Rejected(ImageStateReason), Accepted, - Testing(ImageStateReason), + Testing(ImageStateReason), } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum OtaPalError { +pub enum OtaPalError { SignatureCheckFailed, FileWriteFailed, FileTooLarge, @@ -27,13 +37,7 @@ pub enum OtaPalError { BadImageState, CommitFailed, VersionCheck, - Custom(E), -} - -impl From for OtaPalError { - fn from(value: E) -> Self { - Self::Custom(value) - } + Other, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -62,93 +66,8 @@ pub enum OtaEvent { UpdateComplete, } -#[derive(Debug, Clone, Eq)] -pub struct Version { - major: u8, - minor: u8, - patch: u8, -} - -#[cfg(feature = "defmt")] -impl defmt::Format for Version { - fn format(&self, fmt: defmt::Formatter) { - defmt::write!(fmt, "{=u8}.{=u8}.{=u8}", self.major, self.minor, self.patch) - } -} - -impl Default for Version { - fn default() -> Self { - Self::new(0, 0, 0) - } -} - -impl FromStr for Version { - type Err = (); - - fn from_str(s: &str) -> Result { - let mut iter = s.split('.'); - Ok(Self { - major: iter.next().and_then(|v| v.parse().ok()).ok_or(())?, - minor: iter.next().and_then(|v| v.parse().ok()).ok_or(())?, - patch: iter.next().and_then(|v| v.parse().ok()).ok_or(())?, - }) - } -} - -impl Version { - pub fn new(major: u8, minor: u8, patch: u8) -> Self { - Self { - major, - minor, - patch, - } - } - - pub fn to_string(&self) -> heapless::String { - let mut s = heapless::String::new(); - s.write_fmt(format_args!("{}.{}.{}", self.major, self.minor, self.patch)) - .unwrap(); - s - } -} - -impl core::cmp::PartialEq for Version { - #[inline] - fn eq(&self, other: &Self) -> bool { - self.major == other.major && self.minor == other.minor && self.patch == other.patch - } -} - -impl core::cmp::PartialOrd for Version { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl core::cmp::Ord for Version { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - match self.major.cmp(&other.major) { - core::cmp::Ordering::Equal => {} - r => return r, - } - - match self.minor.cmp(&other.minor) { - core::cmp::Ordering::Equal => {} - r => return r, - } - - match self.patch.cmp(&other.patch) { - core::cmp::Ordering::Equal => {} - r => return r, - } - - core::cmp::Ordering::Equal - } -} /// Platform abstraction layer for OTA jobs pub trait OtaPal { - type Error; - /// OTA abort. /// /// The user may register a callback function when initializing the OTA @@ -156,7 +75,7 @@ pub trait OtaPal { /// aborted. /// /// - `file`: [`FileContext`] File description of the job being aborted - fn abort(&mut self, file: &FileContext) -> Result<(), OtaPalError>; + async fn abort(&mut self, file: &FileContext) -> Result<(), OtaPalError>; /// Activate the newest MCU image received via OTA. 
/// @@ -168,8 +87,8 @@ pub trait OtaPal { /// /// **return**: The OTA PAL layer error code combined with the MCU specific /// error code. - fn activate_new_image(&mut self) -> Result<(), OtaPalError> { - self.reset_device() + async fn activate_new_image(&mut self) -> Result<(), OtaPalError> { + self.reset_device().await } /// OTA create file to store received data. @@ -179,7 +98,7 @@ pub trait OtaPal { /// is created. /// /// - `file`: [`FileContext`] File description of the job being aborted - fn create_file_for_rx(&mut self, file: &FileContext) -> Result<(), OtaPalError>; + async fn create_file_for_rx(&mut self, file: &FileContext) -> Result<(), OtaPalError>; /// Get the state of the OTA update image. /// @@ -196,7 +115,7 @@ pub trait OtaPal { /// timer is not started. /// /// **return** An [`PalImageState`]. - fn get_platform_image_state(&mut self) -> Result>; + async fn get_platform_image_state(&mut self) -> Result; /// Attempt to set the state of the OTA update image. /// @@ -208,10 +127,10 @@ pub trait OtaPal { /// /// **return** The [`OtaPalError`] error code combined with the MCU specific /// error code. - fn set_platform_image_state( + async fn set_platform_image_state( &mut self, - image_state: ImageState, - ) -> Result<(), OtaPalError>; + image_state: ImageState, + ) -> Result<(), OtaPalError>; /// Reset the device. /// @@ -222,7 +141,7 @@ pub trait OtaPal { /// /// **return** The OTA PAL layer error code combined with the MCU specific /// error code. - fn reset_device(&mut self) -> Result<(), OtaPalError>; + async fn reset_device(&mut self) -> Result<(), OtaPalError>; /// Authenticate and close the underlying receive file in the specified OTA /// context. @@ -234,7 +153,7 @@ pub trait OtaPal { /// /// **return** The OTA PAL layer error code combined with the MCU specific /// error code. - fn close_file(&mut self, file: &FileContext) -> Result<(), OtaPalError>; + async fn close_file(&mut self, file: &FileContext) -> Result<(), OtaPalError>; /// Write a block of data to the specified file at the given offset. /// @@ -245,12 +164,12 @@ pub trait OtaPal { /// /// **return** The number of bytes written on a success, or a negative error /// code from the platform abstraction layer. - fn write_block( + async fn write_block( &mut self, file: &FileContext, block_offset: usize, block_payload: &[u8], - ) -> Result>; + ) -> Result; /// OTA update complete. /// @@ -284,9 +203,9 @@ pub trait OtaPal { /// the OTA update job has failed in some way and should be rejected. /// /// - `event` [`OtaEvent`] An OTA update event from the `OtaEvent` enum. - fn complete_callback(&mut self, event: OtaEvent) -> Result<(), OtaPalError> { + async fn complete_callback(&mut self, event: OtaEvent) -> Result<(), OtaPalError> { match event { - OtaEvent::Activate => self.activate_new_image(), + OtaEvent::Activate => self.activate_new_image().await, OtaEvent::Fail | OtaEvent::UpdateComplete => { // Nothing special to do. The OTA agent handles it Ok(()) @@ -294,7 +213,7 @@ pub trait OtaPal { OtaEvent::StartTest => { // Accept the image since it was a good transfer // and networking and services are all working. 
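+                // Accepting the image on the platform side typically commits the
+                // currently running slot so it is no longer pending and a later
+                // reset will not roll it back. As a rough illustration only (the
+                // `mark_image_ok` call below is a placeholder for whatever the
+                // bootloader/flash layer provides), a PAL could implement it as:
+                //
+                //     async fn set_platform_image_state(
+                //         &mut self,
+                //         image_state: ImageState,
+                //     ) -> Result<(), OtaPalError> {
+                //         match image_state {
+                //             ImageState::Accepted => self
+                //                 .mark_image_ok()
+                //                 .await
+                //                 .map_err(|_| OtaPalError::CommitFailed),
+                //             _ => Ok(()),
+                //         }
+                //     }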
- self.set_platform_image_state(ImageState::Accepted)?; + self.set_platform_image_state(ImageState::Accepted).await?; Ok(()) } OtaEvent::SelfTestFailed => { @@ -308,7 +227,4 @@ pub trait OtaPal { } } } - - /// - fn get_active_firmware_version(&self) -> Result>; } diff --git a/src/ota/state.rs b/src/ota/state.rs deleted file mode 100644 index 3f4e576..0000000 --- a/src/ota/state.rs +++ /dev/null @@ -1,1181 +0,0 @@ -use smlang::statemachine; - -use super::config::Config; -use super::control_interface::ControlInterface; -use super::data_interface::{DataInterface, Protocol}; -use super::encoding::json::JobStatusReason; -use super::encoding::json::OtaJob; -use super::encoding::FileContext; -use super::pal::OtaPal; -use super::pal::OtaPalError; - -use crate::jobs::{data_types::JobStatus, StatusDetails}; -use crate::ota::encoding::Bitmap; -use crate::ota::pal::OtaEvent; - -use fugit_timer::ExtU32; - -use super::{ - error::OtaError, - pal::{ImageState, PalImageState}, -}; - -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum ImageStateReason { - NewerJob, - FailedIngest, - MomentumAbort, - ImageStateMismatch, - SignatureCheckPassed, - InvalidDataProtocol, - UserAbort, - VersionCheck, - Pal(OtaPalError), -} - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum RestartReason { - Activate(u8), - Restart(u8), -} - -impl RestartReason { - #[must_use] - pub fn inc(self) -> Self { - match self { - Self::Activate(cnt) => Self::Activate(cnt + 1), - Self::Restart(cnt) => Self::Restart(cnt + 1), - } - } -} - -#[derive(PartialEq)] -pub struct JobEventData<'a> { - pub job_name: &'a str, - pub ota_document: &'a OtaJob<'a>, - pub status_details: Option<&'a StatusDetails<'a>>, -} - -statemachine! { - guard_error: OtaError, - transitions: { - *Ready + Start [start_handler] = RequestingJob, - RequestingJob | WaitingForFileBlock + RequestJobDocument [request_job_handler] = WaitingForJob, - RequestingJob + RequestTimer [request_job_handler] = WaitingForJob, - RequestingJob + ContinueJob = WaitingForFileBlock, - RequestingJob + ReplacementJob(JobEventData<'a>) [process_job_handler] = CreatingFile, - WaitingForJob + RequestJobDocument [request_job_handler] = WaitingForJob, - WaitingForJob + ReceivedJobDocument(JobEventData<'a>) [process_job_handler] = CreatingFile, - CreatingFile + StartSelfTest [in_self_test_handler] = WaitingForJob, - CreatingFile + CreateFile [init_file_handler] = RequestingFileBlock, - CreatingFile + RequestTimer [init_file_handler] = RequestingFileBlock, - CreatingFile | WaitingForJob | Restarting + Restart(RestartReason) [restart_handler] = Restarting, - RequestingFileBlock | WaitingForFileBlock + RequestFileBlock [request_data_handler] = WaitingForFileBlock, - RequestingFileBlock | WaitingForFileBlock + RequestTimer [request_data_handler] = WaitingForFileBlock, - WaitingForFileBlock + ReceivedFileBlock(&'a mut [u8]) [process_data_handler] = WaitingForFileBlock, - WaitingForFileBlock + ReceivedJobDocument(JobEventData<'a>) [job_notification_handler] = RequestingJob, - WaitingForFileBlock + CloseFile [close_file_handler] = WaitingForJob, - Suspended | RequestingJob | WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + Resume [resume_job_handler] = RequestingJob, - Ready | RequestingJob | WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + Suspend = Suspended, - Ready | RequestingJob | WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + UserAbort [user_abort_handler] = WaitingForJob, - Ready | RequestingJob | 
WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + Shutdown [shutdown_handler] = Ready, - } -} - -#[cfg(feature = "defmt")] -impl defmt::Format for Error { - fn format(&self, fmt: defmt::Formatter) { - match self { - Error::InvalidEvent => defmt::write!(fmt, "Error::InvalidEvent"), - Error::GuardFailed(e) => defmt::write!(fmt, "Error::GuardFailed({:?})", e), - } - } -} - -pub(crate) enum Interface { - Primary(FileContext), - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Secondary(FileContext), -} - -impl Interface { - pub const fn file_ctx(&self) -> &FileContext { - match self { - Interface::Primary(i) => i, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Interface::Secondary(i) => i, - } - } - - pub fn mut_file_ctx(&mut self) -> &mut FileContext { - match self { - Interface::Primary(i) => i, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Interface::Secondary(i) => i, - } - } -} - -macro_rules! data_interface { - ($self:ident.$func:ident $(,$y:expr),*) => { - match $self.active_interface { - Some(Interface::Primary(ref mut ctx)) => $self.data_primary.$func(ctx, $($y),*), - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Some(Interface::Secondary(ref mut ctx)) => $self.data_secondary.as_mut().ok_or(OtaError::InvalidInterface)?.$func(ctx, $($y),*), - _ => Err(OtaError::InvalidInterface) - } - }; -} - -// Context of current OTA Job, keeping state -pub(crate) struct SmContext<'a, C, DP, DS, T, ST, PAL, const L: usize, const TIMER_HZ: u32> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - pub(crate) events: heapless::spsc::Queue, L>, - pub(crate) control: &'a C, - pub(crate) data_primary: DP, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - pub(crate) data_secondary: Option, - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - pub(crate) data_secondary: core::marker::PhantomData, - pub(crate) active_interface: Option, - pub(crate) pal: PAL, - pub(crate) request_momentum: u8, - pub(crate) request_timer: T, - pub(crate) self_test_timer: Option, - pub(crate) config: Config, - pub(crate) image_state: ImageState, -} - -impl<'a, C, DP, DS, T, ST, PAL, const L: usize, const TIMER_HZ: u32> - SmContext<'a, C, DP, DS, T, ST, PAL, L, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - /// Called to update the filecontext structure from the job - fn get_file_context_from_job( - &mut self, - job_name: &str, - ota_document: &OtaJob, - status_details: Option, - ) -> Result { - let file_idx = 0; - - if ota_document - .files - .get(file_idx) - .map(|f| f.filesize) - .unwrap_or_default() - == 0 - { - return Err(OtaError::ZeroFileSize); - } - - // If there's an active job, verify that it's the same as what's being - // reported now - let cur_file_ctx = self.active_interface.as_mut().map(|i| i.mut_file_ctx()); - let file_ctx = if let Some(file_ctx) = cur_file_ctx { - if file_ctx.stream_name != ota_document.streamname { - info!("New job document received, aborting current job"); - - // Abort the current job - // TODO:?? 
- self.pal - .set_platform_image_state(ImageState::Aborted(ImageStateReason::NewerJob))?; - - // Abort any active file access and release the file resource, - // if needed - self.pal.abort(file_ctx)?; - - // Cleanup related to selected protocol - data_interface!(self.cleanup, &self.config)?; - - // Set new active job - Ok(FileContext::new_from( - job_name, - ota_document, - status_details.map(|s| { - s.iter() - .map(|(&k, &v)| (heapless::String::from(k), heapless::String::from(v))) - .collect() - }), - file_idx, - &self.config, - self.pal.get_active_firmware_version()?, - )?) - } else { - // The same job is being reported so update the url - info!("New job document ID is identical to the current job: Updating the URL based on the new job document"); - file_ctx.update_data_url = ota_document - .files - .get(0) - .map(|f| f.update_data_url.map(heapless::String::from)) - .ok_or(OtaError::InvalidFile)?; - - Err(file_ctx.clone()) - } - } else { - Ok(FileContext::new_from( - job_name, - ota_document, - status_details.map(|s| { - s.iter() - .map(|(&k, &v)| (heapless::String::from(k), heapless::String::from(v))) - .collect() - }), - file_idx, - &self.config, - self.pal.get_active_firmware_version()?, - )?) - }; - - // If the job is in self test mode, don't start an OTA update but - // instead do the following: - // - // If the firmware that performed the update was older than the - // currently running firmware, set the image state to "Testing." This is - // the success path. - // - // If it's the same or newer, reject the job since either the firmware - // was not accepted during self test or an incorrect image was sent by - // the OTA operator. - let mut file_ctx = match file_ctx { - Ok(mut file_ctx) if file_ctx.self_test() => { - self.handle_self_test_job(&mut file_ctx)?; - return Ok(file_ctx); - } - Ok(file_ctx) => { - info!("Job document was accepted. Attempting to begin the update"); - file_ctx - } - Err(file_ctx) => { - info!("Job document for receiving an update received"); - // Don't create file again on update. 
- return Ok(file_ctx); - } - }; - - // Create/Open the OTA file on the file system - if let Err(e) = self.pal.create_file_for_rx(&file_ctx) { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::Pal(e)), - )?; - - self.ota_close()?; - // FIXME: - return Err(OtaError::Pal); - // return Err(e.into()); - } - - Ok(file_ctx) - } - - fn select_interface( - &self, - file_ctx: FileContext, - protocols: &[Protocol], - ) -> Result { - if protocols.contains(&DP::PROTOCOL) { - Ok(Interface::Primary(file_ctx)) - } else { - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - if protocols.contains(&DS::PROTOCOL) && self.data_secondary.is_some() { - Ok(Interface::Secondary(file_ctx)) - } else { - Err(file_ctx) - } - - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - Err(file_ctx) - } - } - - /// Check if the current image is `PendingCommit` and thus is in selftest - fn platform_in_selftest(&mut self) -> bool { - // Get the platform state from the OTA pal layer - self.pal - .get_platform_image_state() - .map_or(false, |i| i == PalImageState::PendingCommit) - } - - /// Validate update version when receiving job doc in self test state - fn handle_self_test_job(&mut self, file_ctx: &mut FileContext) -> Result<(), OtaError> { - info!("In self test mode"); - - let active_version = self.pal.get_active_firmware_version().unwrap_or_default(); - - let version_check = if file_ctx.fileid == 0 && file_ctx.file_type == Some(0) { - // Only check for versions if the target is self & always allow - // updates if updated_by is not present. - file_ctx - .updated_by() - .map_or(true, |updated_by| active_version > updated_by) - } else { - true - }; - - info!("Version check: {:?}", version_check); - - if self.config.allow_downgrade || version_check { - // The running firmware version is newer than the firmware that - // performed the update or downgrade is allowed so this means we're - // ready to start the self test phase. - // - // Set image state accordingly and update job status with self test - // identifier. - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Testing(ImageStateReason::VersionCheck), - )?; - - Ok(()) - } else { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Rejected(ImageStateReason::VersionCheck), - )?; - - self.pal.complete_callback(OtaEvent::SelfTestFailed)?; - - // Handle self-test failure in the platform specific implementation, - // example, reset the device in case of firmware upgrade. - self.events - .enqueue(Events::Restart(RestartReason::Restart(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - Ok(()) - } - } - - fn set_image_state_with_reason( - control: &C, - _pal: &mut PAL, - config: &Config, - file_ctx: &mut FileContext, - image_state: ImageState, - ) -> Result, OtaError> { - // debug!("set_image_state_with_reason {:?}", image_state); - // Call the platform specific code to set the image state - - // FIXME: - // let image_state = match pal.set_platform_image_state(image_state) { - // Err(e) if !matches!(image_state, ImageState::Aborted(_)) => { - // If the platform image state couldn't be set correctly, force - // fail the update by setting the image state to "Rejected" - // unless it's already in "Aborted". 
- - // Capture the failure reason if not already set (and we're not - // already Aborted as checked above). Otherwise Keep the - // original reject reason code since it is possible for the PAL - // to fail to update the image state in some cases (e.g. a reset - // already caused the bundle rollback and we failed to rollback - // again). - // - // Intentionally override reason since we failed within this - // function - // ImageState::Rejected(ImageStateReason::Pal(e)) - // } - // _ => image_state, - // }; - - // Now update the image state and job status on server side - match image_state { - ImageState::Testing(_) => { - // We discovered we're ready for test mode, put job status - // in self_test active - control.update_job_status( - file_ctx, - config, - JobStatus::InProgress, - JobStatusReason::SelfTestActive, - )?; - } - ImageState::Accepted => { - // Now that we have accepted the firmware update, we can - // complete the job - control.update_job_status( - file_ctx, - config, - JobStatus::Succeeded, - JobStatusReason::Accepted, - )?; - } - ImageState::Rejected(_) => { - // The firmware update was rejected, complete the job as - // FAILED (Job service will not allow us to set REJECTED - // after the job has been started already). - control.update_job_status( - file_ctx, - config, - JobStatus::Failed, - JobStatusReason::Rejected, - )?; - } - _ => { - // The firmware update was aborted, complete the job as - // FAILED (Job service will not allow us to set REJECTED - // after the job has been started already). - control.update_job_status( - file_ctx, - config, - JobStatus::Failed, - JobStatusReason::Aborted, - )?; - } - } - Ok(image_state) - } - - pub fn ota_close(&mut self) -> Result<(), OtaError> { - // Cleanup related to selected protocol. - data_interface!(self.cleanup, &self.config)?; - - // Abort any active file access and release the file resource, if needed - let file_ctx = self - .active_interface - .as_ref() - .ok_or(OtaError::InvalidInterface)? - .file_ctx(); - - self.pal.abort(file_ctx)?; - - self.active_interface = None; - Ok(()) - } - - fn ingest_data_block(&mut self, payload: &mut [u8]) -> Result { - let block = data_interface!(self.decode_file_block, payload)?; - - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - if block.validate(self.config.block_size, file_ctx.filesize) { - if block.block_id < file_ctx.block_offset as usize - || !file_ctx - .bitmap - .get(block.block_id - file_ctx.block_offset as usize) - { - info!( - "Block {:?} is a DUPLICATE. {:?} blocks remaining.", - block.block_id, file_ctx.blocks_remaining - ); - - // Just return same progress as before - return Ok(false); - } - - info!( - "Received block {}. {:?} blocks remaining.", - block.block_id, file_ctx.blocks_remaining - ); - - self.pal.write_block( - file_ctx, - block.block_id * self.config.block_size, - block.block_payload, - )?; - - file_ctx - .bitmap - .set(block.block_id - file_ctx.block_offset as usize, false); - - file_ctx.blocks_remaining -= 1; - - if file_ctx.blocks_remaining == 0 { - info!("Received final expected block of file."); - - // Stop the request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - self.pal.close_file(file_ctx)?; - - // Return true to indicate end of file. 
- Ok(true) - } else { - if file_ctx.bitmap.is_empty() { - file_ctx.block_offset += 31; - file_ctx.bitmap = Bitmap::new( - file_ctx.filesize, - self.config.block_size, - file_ctx.block_offset, - ); - } - - Ok(false) - } - } else { - error!( - "Error! Block {:?} out of expected range! Size {:?}", - block.block_id, block.block_size - ); - - Err(OtaError::BlockOutOfRange) - } - } -} - -impl<'a, C, DP, DS, T, ST, PAL, const L: usize, const TIMER_HZ: u32> StateMachineContext - for SmContext<'a, C, DP, DS, T, ST, PAL, L, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - fn restart_handler(&mut self, reason: &RestartReason) -> Result<(), OtaError> { - debug!("restart_handler"); - match reason { - RestartReason::Activate(cnt) if *cnt > self.config.activate_delay => { - info!("Application callback! OtaEvent::Activate"); - self.pal.complete_callback(OtaEvent::Activate)?; - } - RestartReason::Restart(cnt) if *cnt > self.config.activate_delay => { - self.pal.reset_device()?; - } - r => { - self.events - .enqueue(Events::Restart(r.inc())) - .map_err(|_| OtaError::SignalEventFailed)?; - } - } - Ok(()) - } - - /// Start timers and initiate request for job document - fn start_handler(&mut self) -> Result<(), OtaError> { - debug!("start_handler"); - // Start self-test timer, if platform is in self-test. - if self.platform_in_selftest() { - // Start self-test timer - if let Some(ref mut self_test_timer) = self.self_test_timer { - self_test_timer - .start(self.config.self_test_timeout_ms.millis()) - .map_err(|_| OtaError::Timer)?; - } - } - - // Initialize the control interface - self.control.init()?; - - // Send event to OTA task to get job document - self.events - .enqueue(Events::RequestJobDocument) - .map_err(|_| OtaError::SignalEventFailed) - } - - fn resume_job_handler(&mut self) -> Result<(), OtaError> { - debug!("resume_job_handler"); - - // Initialize the control interface - self.control.init()?; - - // Send signal to request job document - self.events - .enqueue(Events::RequestJobDocument) - .map_err(|_| OtaError::SignalEventFailed) - } - - /// Initiate a request for a job - fn request_job_handler(&mut self) -> Result<(), OtaError> { - debug!("request_job_handler"); - match self.control.request_job() { - Err(e) => { - if self.request_momentum < self.config.max_request_momentum { - // Start request timer - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - self.request_momentum += 1; - Err(e) - } else { - // Stop request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Send shutdown event to the OTA Agent task - self.events - .enqueue(Events::Shutdown) - .map_err(|_| OtaError::SignalEventFailed)?; - - // Too many requests have been sent without a response or - // too many failures when trying to publish the request - // message. Abort. 
- Err(OtaError::MomentumAbort) - } - } - Ok(_) => { - // Stop request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Reset the request momentum - self.request_momentum = 0; - Ok(()) - } - } - } - - /// Initialize and handle file transfer - fn init_file_handler(&mut self) -> Result<(), OtaError> { - debug!("init_file_handler"); - match data_interface!(self.init_file_transfer) { - Err(e) => { - if self.request_momentum < self.config.max_request_momentum { - // Start request timer - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - self.request_momentum += 1; - Err(e) - } else { - // Stop request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Send shutdown event to the OTA Agent task - self.events - .enqueue(Events::Shutdown) - .map_err(|_| OtaError::SignalEventFailed)?; - - // Too many requests have been sent without a response or - // too many failures when trying to publish the request - // message. Abort. - - Err(OtaError::MomentumAbort) - } - } - Ok(_) => { - // Reset the request momentum - self.request_momentum = 0; - - // TODO: Reset the OTA statistics - - info!("Initialized file handler! Requesting file blocks"); - - self.events - .enqueue(Events::RequestFileBlock) - .map_err(|_| OtaError::SignalEventFailed)?; - - Ok(()) - } - } - } - - /// Handle self test - fn in_self_test_handler(&mut self) -> Result<(), OtaError> { - info!("Beginning self-test"); - // Check the platform's OTA update image state. It should also be in - // self test - let in_self_test = self.platform_in_selftest(); - // Clear self-test flag - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - if in_self_test { - self.pal.complete_callback(OtaEvent::StartTest)?; - info!("Application callback! OtaEvent::StartTest"); - - self.image_state = ImageState::Accepted; - self.control.update_job_status( - file_ctx, - &self.config, - JobStatus::Succeeded, - JobStatusReason::Accepted, - )?; - - file_ctx - .status_details - .insert( - heapless::String::from("self_test"), - heapless::String::from(JobStatusReason::Accepted.as_str()), - ) - .map_err(|_| OtaError::Overflow)?; - - // Stop the self test timer as it is no longer required - if let Some(ref mut self_test_timer) = self.self_test_timer { - self_test_timer.cancel().map_err(|_| OtaError::Timer)?; - } - } else { - // The job is in self test but the platform image state is not so it - // could be an attack on the platform image state. Reject the update - // (this should also cause the image to be erased), aborting the job - // and reset the device. 
- error!("Rejecting new image and rebooting: the job is in the self-test state while the platform is not."); - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Rejected(ImageStateReason::ImageStateMismatch), - )?; - - self.events - .enqueue(Events::Restart(RestartReason::Restart(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - } - Ok(()) - } - - /// Update file context from job document - fn process_job_handler(&mut self, data: &JobEventData<'_>) -> Result<(), OtaError> { - let JobEventData { - job_name, - ota_document, - status_details, - } = data; - - let file_ctx = self.get_file_context_from_job( - job_name, - ota_document, - status_details.map(Clone::clone), - )?; - - match self.select_interface(file_ctx, &ota_document.protocols) { - Ok(interface) => { - info!("Setting OTA data interface"); - self.active_interface = Some(interface); - } - Err(mut file_ctx) => { - // Failed to set the data interface so abort the OTA. If there - // is a valid job id, then a job status update will be sent. - - error!("Failed to set OTA data interface. Aborting current update."); - - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::InvalidDataProtocol), - )?; - return Err(OtaError::InvalidInterface); - } - } - - if self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .file_ctx() - .self_test() - { - // If the OTA job is in the self_test state, alert the application layer. - if matches!(self.image_state, ImageState::Testing(_)) { - self.events - .enqueue(Events::StartSelfTest) - .map_err(|_| OtaError::SignalEventFailed)?; - - Ok(()) - } else { - Err(OtaError::InvalidFile) - } - } else { - if !self.platform_in_selftest() { - // Received a valid context so send event to request file blocks - self.events - .enqueue(Events::CreateFile) - .map_err(|_| OtaError::SignalEventFailed)?; - } else { - // Received a job that is not in self-test but platform is, so - // reboot the device to allow roll back to previous image. - error!("Rejecting new image and rebooting: The platform is in the self-test state while the job is not."); - self.events - .enqueue(Events::Restart(RestartReason::Restart(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - } - Ok(()) - } - } - - /// Request for data blocks - fn request_data_handler(&mut self) -> Result<(), OtaError> { - debug!("request_data_handler"); - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - if file_ctx.blocks_remaining > 0 { - // Start the request timer - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - if self.request_momentum <= self.config.max_request_momentum { - // Each request increases the momentum until a response is - // received. Too much momentum is interpreted as a failure to - // communicate and will cause us to abort the OTA. - self.request_momentum += 1; - - // Request data blocks - data_interface!(self.request_file_block, &self.config) - } else { - // Stop the request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Failed to send data request abort and close file. 
- self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Aborted(ImageStateReason::MomentumAbort), - )?; - - warn!("Shutdown [request_data_handler]"); - self.events - .enqueue(Events::Shutdown) - .map_err(|_| OtaError::SignalEventFailed)?; - - // Reset the request momentum - self.request_momentum = 0; - - // Too many requests have been sent without a response or too - // many failures when trying to publish the request message. - // Abort. - Err(OtaError::MomentumAbort) - } - } else { - Err(OtaError::BlockOutOfRange) - } - } - - /// Upon receiving a new job document cancel current job if present and - /// initiate new download - fn job_notification_handler(&mut self, data: &JobEventData<'_>) -> Result<(), OtaError> { - if let Some(ref mut interface) = self.active_interface { - if interface.file_ctx().job_name.as_str() == data.job_name { - self.events - .enqueue(Events::ContinueJob) - .map_err(|_| OtaError::SignalEventFailed)?; - return Ok(()); - } else { - // Stop the request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Abort the current job - // TODO: This should never write to current image flags?! - self.pal - .set_platform_image_state(ImageState::Aborted(ImageStateReason::NewerJob))?; - self.ota_close()?; - } - } - - // Start the new job! - Ok(()) - } - - /// Process incoming data blocks - fn process_data_handler(&mut self, payload: &mut [u8]) -> Result<(), OtaError> { - debug!("process_data_handler"); - // Decode the file block received - match self.ingest_data_block(payload) { - Ok(true) => { - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - // File is completed! Update progress accordingly. - let (status, reason, event) = if let Some(0) = file_ctx.file_type { - ( - JobStatus::InProgress, - JobStatusReason::SigCheckPassed, - OtaEvent::Activate, - ) - } else { - ( - JobStatus::Succeeded, - JobStatusReason::Accepted, - OtaEvent::UpdateComplete, - ) - }; - - self.control - .update_job_status(file_ctx, &self.config, status, reason)?; - - // Send event to close file. - self.events - .enqueue(Events::CloseFile) - .map_err(|_| OtaError::SignalEventFailed)?; - - // TODO: Last file block processed, increment the statistics - // otaAgent.statistics.otaPacketsProcessed++; - - match event { - OtaEvent::Activate => { - self.events - .enqueue(Events::Restart(RestartReason::Activate(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - } - event => self.pal.complete_callback(event)?, - }; - } - Ok(false) => { - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - // File block processed, increment the statistics. - // otaAgent.statistics.otaPacketsProcessed++; - - // Reset the momentum counter since we received a good block - self.request_momentum = 0; - - // We're actively receiving a file so update the job status as - // needed - self.control.update_job_status( - file_ctx, - &self.config, - JobStatus::InProgress, - JobStatusReason::Receiving, - )?; - - if file_ctx.request_block_remaining > 1 { - file_ctx.request_block_remaining -= 1; - } else { - // Start the request timer. - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - self.events - .enqueue(Events::RequestFileBlock) - .map_err(|_| OtaError::SignalEventFailed)?; - } - } - Err(e) if e.is_retryable() => { - warn!("Failed to ingest data block, Error is retryable! 
ingest_data_block returned error {:?}", e); - } - Err(e) => { - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - error!("Failed to ingest data block, rejecting image: ingest_data_block returned error {:?}", e); - - // Call the platform specific code to reject the image - // TODO: This should never write to current image flags?! - self.pal.set_platform_image_state(ImageState::Rejected( - ImageStateReason::FailedIngest, - ))?; - - // TODO: Pal reason - self.control.update_job_status( - file_ctx, - &self.config, - JobStatus::Failed, - JobStatusReason::Pal(0), - )?; - - // Stop the request timer. - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Send event to close file. - self.events - .enqueue(Events::CloseFile) - .map_err(|_| OtaError::SignalEventFailed)?; - - self.pal.complete_callback(OtaEvent::Fail)?; - info!("Application callback! OtaEvent::Fail"); - return Err(e); - } - } - - // TODO: Application callback for event processed. - // otaAgent.OtaAppCallback( OtaJobEventProcessed, ( const void * ) pEventData ); - Ok(()) - } - - /// Close file opened for download - fn close_file_handler(&mut self) -> Result<(), OtaError> { - self.ota_close() - } - - /// Handle user interrupt to abort task - fn user_abort_handler(&mut self) -> Result<(), OtaError> { - warn!("User abort OTA!"); - if let Some(ref mut interface) = self.active_interface { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - interface.mut_file_ctx(), - ImageState::Aborted(ImageStateReason::UserAbort), - )?; - self.ota_close() - } else { - Err(OtaError::NoActiveJob) - } - } - - /// Handle user interrupt to abort task - fn shutdown_handler(&mut self) -> Result<(), OtaError> { - warn!("Shutting down OTA!"); - if let Some(ref mut interface) = self.active_interface { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - interface.mut_file_ctx(), - ImageState::Aborted(ImageStateReason::UserAbort), - )?; - self.ota_close()?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - ota::{ - agent::OtaAgent, - pal::Version, - test::{ - mock::{MockPal, MockTimer}, - test_job_doc, - }, - }, - test::MockMqtt, - }; - - use super::*; - - #[test] - fn version_check_success() { - // The version check is run after swapping & rebooting, so the PAL will - // return the version of the newly flashed firmware, and `FileContext` - // will contain the `updated_by` version, which is the old firmware - // version. 
- - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let mut agent = OtaAgent::builder(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .build(); - - let ota_job = test_job_doc(); - let mut file_ctx = FileContext::new_from( - "Job-name", - &ota_job, - None, - 0, - &Config::default(), - Version::new(0, 1, 0), - ) - .unwrap(); - - let context = agent.state.context_mut(); - - assert_eq!(context.handle_self_test_job(&mut file_ctx), Ok(())); - - assert!( - matches!(context.image_state, ImageState::Testing(_)), - "Unexpected image state" - ); - } - - #[test] - fn version_check_rejected() { - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let mut agent = OtaAgent::builder(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .build(); - - let ota_job = test_job_doc(); - let mut file_ctx = FileContext::new_from( - "Job-name", - &ota_job, - None, - 0, - &Config::default(), - Version::new(1, 1, 0), - ) - .unwrap(); - - let context = agent.state.context_mut(); - - assert_eq!(context.handle_self_test_job(&mut file_ctx), Ok(())); - - assert!( - matches!(context.image_state, ImageState::Rejected(_)), - "Unexpected image state" - ); - } - - #[test] - fn version_check_allow_donwgrade() { - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let mut agent = OtaAgent::builder(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .allow_downgrade() - .build(); - - let ota_job = test_job_doc(); - let mut file_ctx = FileContext::new_from( - "Job-name", - &ota_job, - None, - 0, - &Config::default(), - Version::new(1, 1, 0), - ) - .unwrap(); - - let context = agent.state.context_mut(); - - assert_eq!(context.handle_self_test_job(&mut file_ctx), Ok(())); - - assert!( - matches!(context.image_state, ImageState::Testing(_)), - "Unexpected image state" - ); - } -} diff --git a/src/ota/test/mock.rs b/src/ota/test/mock.rs deleted file mode 100644 index 42e4ca3..0000000 --- a/src/ota/test/mock.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::ota::{ - encoding::FileContext, - pal::{ImageState, OtaPal, OtaPalError, PalImageState, Version}, -}; - -use super::TEST_TIMER_HZ; - -/// -/// Mock timer used for unit tests. Implements `fugit_timer::Timer` trait. -/// -pub struct MockTimer { - pub is_started: bool, -} -impl MockTimer { - pub fn new() -> Self { - Self { is_started: false } - } -} - -impl fugit_timer::Timer for MockTimer { - type Error = (); - - fn now(&mut self) -> fugit_timer::TimerInstantU32 { - todo!() - } - - fn start( - &mut self, - _duration: fugit_timer::TimerDurationU32, - ) -> Result<(), Self::Error> { - self.is_started = true; - Ok(()) - } - - fn cancel(&mut self) -> Result<(), Self::Error> { - self.is_started = false; - Ok(()) - } - - fn wait(&mut self) -> nb::Result<(), Self::Error> { - Ok(()) - } -} - -/// -/// Mock Platform abstration layer used for unit tests. Implements `OtaPal` -/// trait. 
-/// -pub struct MockPal {} - -impl OtaPal for MockPal { - type Error = (); - - fn abort(&mut self, _file: &FileContext) -> Result<(), OtaPalError> { - Ok(()) - } - - fn create_file_for_rx(&mut self, _file: &FileContext) -> Result<(), OtaPalError> { - Ok(()) - } - - fn get_platform_image_state(&mut self) -> Result> { - Ok(PalImageState::Valid) - } - - fn set_platform_image_state( - &mut self, - _image_state: ImageState, - ) -> Result<(), OtaPalError> { - Ok(()) - } - - fn reset_device(&mut self) -> Result<(), OtaPalError> { - Ok(()) - } - - fn close_file(&mut self, _file: &FileContext) -> Result<(), OtaPalError> { - Ok(()) - } - - fn write_block( - &mut self, - _file: &FileContext, - _block_offset: usize, - block_payload: &[u8], - ) -> Result> { - Ok(block_payload.len()) - } - - fn get_active_firmware_version(&self) -> Result> { - Ok(Version::new(1, 0, 0)) - } -} diff --git a/src/ota/test/mod.rs b/src/ota/test/mod.rs deleted file mode 100644 index c535337..0000000 --- a/src/ota/test/mod.rs +++ /dev/null @@ -1,523 +0,0 @@ -use super::{ - config::Config, - data_interface::Protocol, - encoding::{ - json::{FileDescription, OtaJob}, - FileContext, - }, - pal::Version, -}; - -pub mod mock; - -pub const TEST_TIMER_HZ: u32 = 8_000_000; - -pub fn test_job_doc() -> OtaJob<'static> { - OtaJob { - protocols: heapless::Vec::from_slice(&[Protocol::Mqtt]).unwrap(), - streamname: "test_stream", - files: heapless::Vec::from_slice(&[FileDescription { - filepath: "", - filesize: 123456, - fileid: 0, - certfile: "cert", - update_data_url: None, - auth_scheme: None, - sha1_rsa: Some(""), - file_type: Some(0), - sha256_rsa: None, - sha1_ecdsa: None, - sha256_ecdsa: None, - }]) - .unwrap(), - } -} - -pub fn test_file_ctx(config: &Config) -> FileContext { - let ota_job = test_job_doc(); - FileContext::new_from("Job-name", &ota_job, None, 0, config, Version::default()).unwrap() -} - -pub mod ota_tests { - use crate::jobs::data_types::{DescribeJobExecutionResponse, JobExecution, JobStatus}; - use crate::ota::data_interface::Protocol; - use crate::ota::encoding::json::{FileDescription, OtaJob}; - use crate::ota::error::OtaError; - use crate::ota::state::{Error, Events, States}; - use crate::ota::test::test_job_doc; - use crate::ota::{ - agent::OtaAgent, - control_interface::ControlInterface, - data_interface::{DataInterface, NoInterface}, - pal::OtaPal, - test::mock::{MockPal, MockTimer}, - }; - use crate::test::MockMqtt; - use mqttrust::encoding::v4::{decode_slice, utils::Pid, PacketType}; - use mqttrust::{MqttError, Packet, QoS, SubscribeTopic}; - use serde::Deserialize; - use serde_json_core::from_slice; - - use super::TEST_TIMER_HZ; - - /// All known job document that the device knows how to process. 
- #[derive(Debug, PartialEq, Deserialize)] - pub enum JobDetails<'a> { - #[serde(rename = "afr_ota")] - #[serde(borrow)] - Ota(OtaJob<'a>), - - #[serde(other)] - Unknown, - } - - fn new_agent( - mqtt: &MockMqtt, - ) -> OtaAgent<'_, MockMqtt, &MockMqtt, NoInterface, MockTimer, MockTimer, MockPal, TEST_TIMER_HZ> - { - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - OtaAgent::builder(mqtt, mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 16000) - .build() - } - - fn run_to_state<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32>( - agent: &mut OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ>, - state: States, - ) where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, - { - if agent.state.state() == &state { - return; - } - - match state { - States::Ready => { - println!( - "Running to 'States::Ready', events: {}", - agent.state.context().events.len() - ); - agent.state.process_event(Events::Shutdown).unwrap(); - } - States::CreatingFile => { - println!( - "Running to 'States::CreatingFile', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::WaitingForJob); - - let job_doc = test_job_doc(); - agent.job_update("Test-job", &job_doc, None).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::RequestingFileBlock => { - println!( - "Running to 'States::RequestingFileBlock', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::CreatingFile); - agent.state.process_event(Events::CreateFile).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::RequestingJob => { - println!( - "Running to 'States::RequestingJob', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::Ready); - agent.state.process_event(Events::Start).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::Suspended => { - println!( - "Running to 'States::Suspended', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::Ready); - agent.suspend().unwrap(); - } - States::WaitingForFileBlock => { - println!( - "Running to 'States::Suspended', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::RequestingFileBlock); - agent.state.process_event(Events::RequestFileBlock).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::WaitingForJob => { - println!( - "Running to 'States::WaitingForJob', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::RequestingJob); - agent.check_for_update().unwrap(); - } - States::Restarting => {} - } - } - - pub fn set_pid(buf: &mut [u8], pid: Pid) -> Result<(), ()> { - let mut offset = 0; - let (header, _) = mqttrust::encoding::v4::decoder::read_header(buf, &mut offset) - .map_err(|_| ())? 
- .ok_or(())?; - - match (header.typ, header.qos) { - (PacketType::Publish, QoS::AtLeastOnce | QoS::ExactlyOnce) => { - if buf[offset..].len() < 2 { - return Err(()); - } - let len = ((buf[offset] as usize) << 8) | buf[offset + 1] as usize; - - offset += 2; - if len > buf[offset..].len() { - return Err(()); - } else { - offset += len; - } - } - (PacketType::Subscribe | PacketType::Unsubscribe | PacketType::Suback, _) => {} - ( - PacketType::Puback - | PacketType::Pubrec - | PacketType::Pubrel - | PacketType::Pubcomp - | PacketType::Unsuback, - _, - ) => {} - _ => return Ok(()), - } - - pid.to_buffer(buf, &mut offset).map_err(|_| ()) - } - - #[test] - fn ready_when_stopped() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - assert!(matches!(ota_agent.state.state(), &States::Ready)); - run_to_state(&mut ota_agent, States::Ready); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(ota_agent.state.context().events.len(), 0); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - fn abort_when_stopped() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::Ready); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert_eq!( - ota_agent.abort().err(), - Some(Error::GuardFailed(OtaError::NoActiveJob)) - ); - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - fn resume_when_stopped() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::Ready); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert!(matches!( - ota_agent.resume().err().unwrap(), - Error::InvalidEvent - )); - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - fn resume_when_suspended() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::Suspended); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert!(matches!( - ota_agent.resume().unwrap(), - &States::RequestingJob - )); - assert_eq!(mqtt.tx.borrow_mut().len(), 1); - } - - #[test] - fn check_for_update() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::RequestingJob); - assert!(matches!(ota_agent.state.state(), &States::RequestingJob)); - - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert!(matches!( - ota_agent.check_for_update().unwrap(), - &States::WaitingForJob - )); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/notify-next", - qos: QoS::AtLeastOnce - }] - ); - - let mut bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - set_pid(bytes.as_mut_slice(), Pid::new()).expect("Failed to set valid PID"); - let packet = decode_slice(bytes.as_slice()).unwrap(); - - let publish = match packet { - Some(Packet::Publish(p)) => p, - _ => panic!(), - }; - - assert_eq!( - publish, - mqttrust::encoding::v4::publish::Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - topic_name: "$aws/things/test_client/jobs/$next/get", - payload: &[123, 
125], - pid: Some(Pid::new()), - } - ); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - #[ignore] - fn request_job_retry_fail() { - let mut mqtt = MockMqtt::new(); - - // Let MQTT publish fail so request job will also fail - mqtt.publish_fail(); - - let mut ota_agent = new_agent(&mqtt); - - // Place the OTA Agent into the state for requesting a job - run_to_state(&mut ota_agent, States::RequestingJob); - assert!(matches!(ota_agent.state.state(), &States::RequestingJob)); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert_eq!( - ota_agent.check_for_update().err(), - Some(Error::GuardFailed(OtaError::Mqtt(MqttError::Full))) - ); - - // Fail the maximum number of attempts to request a job document - for _ in 0..ota_agent.state.context().config.max_request_momentum { - ota_agent.process_event().unwrap(); - assert!(ota_agent.state.context().request_timer.is_started); - ota_agent.timer_callback().ok(); - assert!(matches!(ota_agent.state.state(), &States::RequestingJob)); - } - - // Attempt to request another job document after failing the maximum - // number of times, triggering a shutdown event. - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(mqtt.tx.borrow_mut().len(), 4); - } - - #[test] - fn init_file_transfer_mqtt() { - let mqtt = MockMqtt::new(); - - let mut ota_agent = new_agent(&mqtt); - - // Place the OTA Agent into the state for creating file - run_to_state(&mut ota_agent, States::CreatingFile); - assert!(matches!(ota_agent.state.state(), &States::CreatingFile)); - assert_eq!(ota_agent.state.context().events.len(), 0); - - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::CreatingFile)); - ota_agent.process_event().unwrap(); - - ota_agent.state.process_event(Events::CreateFile).unwrap(); - - // Above will automatically enqueue `RequestFileBlock` - assert!(matches!( - ota_agent.state.state(), - &States::RequestingFileBlock - )); - - // Check the latest MQTT message - let bytes = mqtt.tx.borrow_mut().pop_back().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![SubscribeTopic { - topic_path: "$aws/things/test_client/streams/test_stream/data/cbor", - qos: QoS::AtLeastOnce - }] - ); - - // Should still contain: - // - subscription to `$aws/things/test_client/jobs/notify-next` - // - publish to `$aws/things/test_client/jobs/$next/get` - assert_eq!(mqtt.tx.borrow_mut().len(), 2); - } - - #[test] - fn request_file_block_mqtt() { - let mqtt = MockMqtt::new(); - - let mut ota_agent = new_agent(&mqtt); - - // Place the OTA Agent into the state for requesting file block - run_to_state(&mut ota_agent, States::RequestingFileBlock); - assert!(matches!( - ota_agent.state.state(), - &States::RequestingFileBlock - )); - assert_eq!(ota_agent.state.context().events.len(), 0); - - ota_agent - .state - .process_event(Events::RequestFileBlock) - .unwrap(); - - assert!(matches!( - ota_agent.state.state(), - &States::WaitingForFileBlock - )); - - let bytes = mqtt.tx.borrow_mut().pop_back().unwrap(); - - let publish = match decode_slice(bytes.as_slice()).unwrap() { - Some(Packet::Publish(p)) => p, - _ => panic!(), - }; - - // Check the latest MQTT message - assert_eq!( - publish, - mqttrust::encoding::v4::publish::Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic_name: 
"$aws/things/test_client/streams/test_stream/get/cbor", - payload: &[ - 164, 97, 102, 0, 97, 108, 25, 1, 0, 97, 111, 0, 97, 98, 68, 255, 255, 255, 127 - ], - pid: None - } - ); - - // Should still contain: - // - subscription to `$aws/things/test_client/jobs/notify-next` - // - publish to `$aws/things/test_client/jobs/$next/get` - // - subscription to - // `$aws/things/test_client/streams/test_stream/data/cbor` - assert_eq!(mqtt.tx.borrow_mut().len(), 3); - } - - #[test] - fn deserialize_describe_job_execution_response_ota() { - let payload = br#"{ - "clientToken":"0:rustot-test", - "timestamp":1624445100, - "execution":{ - "jobId":"AFR_OTA-rustot_test_1", - "status":"QUEUED", - "queuedAt":1624440618, - "lastUpdatedAt":1624440618, - "versionNumber":1, - "executionNumber":1, - "jobDocument":{ - "afr_ota":{ - "protocols":["MQTT"], - "streamname":"AFR_OTA-0ba01295-9417-4ba7-9a99-4b31fb03d252", - "files":[{ - "filepath":"IMG_test.jpg", - "filesize":2674792, - "fileid":0, - "certfile":"nope", - "fileType":0, - "sig-sha256-ecdsa":"This is my signature! Better believe it!" - }] - } - } - } - }"#; - - let (response, _) = - from_slice::>(payload).unwrap(); - - assert_eq!( - response, - DescribeJobExecutionResponse { - execution: Some(JobExecution { - execution_number: Some(1), - job_document: Some(JobDetails::Ota(OtaJob { - protocols: heapless::Vec::from_slice(&[Protocol::Mqtt]).unwrap(), - streamname: "AFR_OTA-0ba01295-9417-4ba7-9a99-4b31fb03d252", - files: heapless::Vec::from_slice(&[FileDescription { - filepath: "IMG_test.jpg", - filesize: 2674792, - fileid: 0, - certfile: "nope", - update_data_url: None, - auth_scheme: None, - sha1_rsa: None, - sha256_rsa: None, - sha1_ecdsa: None, - sha256_ecdsa: Some("This is my signature! Better believe it!"), - file_type: Some(0), - }]) - .unwrap(), - })), - job_id: "AFR_OTA-rustot_test_1", - last_updated_at: 1624440618, - queued_at: 1624440618, - status_details: None, - status: JobStatus::Queued, - version_number: 1, - approximate_seconds_before_timed_out: None, - started_at: None, - thing_name: None, - }), - timestamp: 1624445100, - client_token: Some("0:rustot-test"), - } - ); - } -} diff --git a/tests/common/file_handler.rs b/tests/common/file_handler.rs index 9cae2e6..335e1af 100644 --- a/tests/common/file_handler.rs +++ b/tests/common/file_handler.rs @@ -1,6 +1,6 @@ -use rustot::ota::pal::{OtaPal, OtaPalError, PalImageState}; -use std::fs::File; -use std::io::{Cursor, Write}; +// use rustot::ota::pal::{OtaPal, OtaPalError, PalImageState}; +use tokio::fs::File; +use tokio::io::{Cursor, Write}; pub struct FileHandler { filebuf: Option>>, @@ -12,80 +12,80 @@ impl FileHandler { } } -impl OtaPal for FileHandler { - type Error = (); +// impl OtaPal for FileHandler { +// type Error = (); - fn abort( - &mut self, - _file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { - Ok(()) - } +// fn abort( +// &mut self, +// _file: &rustot::ota::encoding::FileContext, +// ) -> Result<(), OtaPalError> { +// Ok(()) +// } - fn create_file_for_rx( - &mut self, - file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { - self.filebuf = Some(Cursor::new(Vec::with_capacity(file.filesize))); - Ok(()) - } +// fn create_file_for_rx( +// &mut self, +// file: &rustot::ota::encoding::FileContext, +// ) -> Result<(), OtaPalError> { +// self.filebuf = Some(Cursor::new(Vec::with_capacity(file.filesize))); +// Ok(()) +// } - fn get_platform_image_state(&mut self) -> Result> { - Ok(PalImageState::Valid) - } +// fn get_platform_image_state(&mut 
self) -> Result> { +// Ok(PalImageState::Valid) +// } - fn set_platform_image_state( - &mut self, - _image_state: rustot::ota::pal::ImageState<()>, - ) -> Result<(), OtaPalError> { - Ok(()) - } +// fn set_platform_image_state( +// &mut self, +// _image_state: rustot::ota::pal::ImageState<()>, +// ) -> Result<(), OtaPalError> { +// Ok(()) +// } - fn reset_device(&mut self) -> Result<(), OtaPalError> { - Ok(()) - } +// fn reset_device(&mut self) -> Result<(), OtaPalError> { +// Ok(()) +// } - fn close_file( - &mut self, - file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { - if let Some(ref mut buf) = &mut self.filebuf { - log::debug!( - "Closing completed file. Len: {}/{} -> {}", - buf.get_ref().len(), - file.filesize, - file.filepath.as_str() - ); - let mut file = - File::create(file.filepath.as_str()).map_err(|_| OtaPalError::FileWriteFailed)?; - file.write_all(buf.get_ref()) - .map_err(|_| OtaPalError::FileWriteFailed)?; +// fn close_file( +// &mut self, +// file: &rustot::ota::encoding::FileContext, +// ) -> Result<(), OtaPalError> { +// if let Some(ref mut buf) = &mut self.filebuf { +// log::debug!( +// "Closing completed file. Len: {}/{} -> {}", +// buf.get_ref().len(), +// file.filesize, +// file.filepath.as_str() +// ); +// let mut file = +// File::create(file.filepath.as_str()).map_err(|_| OtaPalError::FileWriteFailed)?; +// file.write_all(buf.get_ref()) +// .map_err(|_| OtaPalError::FileWriteFailed)?; - Ok(()) - } else { - Err(OtaPalError::BadFileHandle) - } - } +// Ok(()) +// } else { +// Err(OtaPalError::BadFileHandle) +// } +// } - fn write_block( - &mut self, - _file: &rustot::ota::encoding::FileContext, - block_offset: usize, - block_payload: &[u8], - ) -> Result> { - if let Some(ref mut buf) = &mut self.filebuf { - buf.set_position(block_offset as u64); - buf.write(block_payload) - .map_err(|_e| OtaPalError::FileWriteFailed)?; - Ok(block_payload.len()) - } else { - Err(OtaPalError::BadFileHandle) - } - } +// fn write_block( +// &mut self, +// _file: &rustot::ota::encoding::FileContext, +// block_offset: usize, +// block_payload: &[u8], +// ) -> Result> { +// if let Some(ref mut buf) = &mut self.filebuf { +// buf.set_position(block_offset as u64); +// buf.write(block_payload) +// .map_err(|_e| OtaPalError::FileWriteFailed)?; +// Ok(block_payload.len()) +// } else { +// Err(OtaPalError::BadFileHandle) +// } +// } - fn get_active_firmware_version( - &self, - ) -> Result> { - Ok(rustot::ota::pal::Version::new(0, 1, 0)) - } -} +// fn get_active_firmware_version( +// &self, +// ) -> Result> { +// Ok(rustot::ota::pal::Version::new(0, 1, 0)) +// } +// } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 594f1b0..fc87c3f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,3 @@ pub mod credentials; -// pub mod file_handler; +pub mod file_handler; pub mod network; diff --git a/tests/ota.rs b/tests/ota.rs index e645b99..5992196 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -1,235 +1,170 @@ -// mod common; - -// use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification, PublishNotification}; -// use native_tls::TlsConnector; -// use rustot::ota::state::States; -// use serde::Deserialize; -// use sha2::{Digest, Sha256}; -// use std::{fs::File, io::Read, ops::Deref}; - -// use common::{clock::SysClock, credentials, file_handler::FileHandler, network::Network}; -// use rustot::{ -// jobs::{ -// self, -// data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, -// StatusDetails, -// }, -// ota::{self, 
agent::OtaAgent, encoding::json::OtaJob}, -// }; - -// static mut Q: BBBuffer<{ 1024 * 10 }> = BBBuffer::new(); - -// #[derive(Debug, Deserialize)] -// pub enum Jobs<'a> { -// #[serde(rename = "afr_ota")] -// #[serde(borrow)] -// Ota(OtaJob<'a>), -// } - -// impl<'a> Jobs<'a> { -// pub fn ota_job(self) -> Option> { -// match self { -// Jobs::Ota(ota_job) => Some(ota_job), -// } -// } -// } - -// enum OtaUpdate<'a> { -// JobUpdate(&'a str, OtaJob<'a>, Option>), -// Data(&'a mut [u8]), -// } - -// fn handle_ota<'a>(publish: &'a mut PublishNotification) -> Result, ()> { -// match jobs::Topic::from_str(publish.topic_name.as_str()) { -// Some(jobs::Topic::NotifyNext) => { -// let (execution_changed, _) = -// serde_json_core::from_slice::>(&publish.payload) -// .map_err(drop)?; -// let job = execution_changed.execution.ok_or(())?; -// let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; -// return Ok(OtaUpdate::JobUpdate( -// job.job_id, -// ota_job, -// job.status_details, -// )); -// } -// Some(jobs::Topic::DescribeAccepted(_)) => { -// let (execution_changed, _) = -// serde_json_core::from_slice::>(&publish.payload) -// .map_err(drop)?; -// let job = execution_changed.execution.ok_or(())?; -// let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; -// return Ok(OtaUpdate::JobUpdate( -// job.job_id, -// ota_job, -// job.status_details, -// )); -// } -// _ => {} -// } - -// match ota::Topic::from_str(publish.topic_name.as_str()) { -// Some(ota::Topic::Data(_, _)) => { -// return Ok(OtaUpdate::Data(&mut publish.payload)); -// } -// _ => {} -// } -// Err(()) -// } - -// pub struct FileInfo { -// pub file_path: String, -// pub filesize: usize, -// pub signature: ota::encoding::json::Signature, -// } - -// #[test] -// fn test_mqtt_ota() { -// // Make sure this times out in case something went wrong setting up the OTA -// // job in AWS IoT before starting. -// timebomb::timeout_ms(test_mqtt_ota_inner, 100_000) -// } - -// fn test_mqtt_ota_inner() { -// env_logger::init(); - -// let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - -// log::info!("Starting OTA test..."); - -// let hostname = credentials::HOSTNAME.unwrap(); -// let (thing_name, identity) = credentials::identity(); - -// let connector = TlsConnector::builder() -// .identity(identity) -// .add_root_certificate(credentials::root_ca()) -// .build() -// .unwrap(); - -// let mut network = Network::new_tls(connector, String::from(hostname)); - -// let mut mqtt_eventloop = EventLoop::new( -// c, -// SysClock::new(), -// MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), -// ); - -// let mqtt_client = mqttrust_core::Client::new(p, thing_name); - -// let file_handler = FileHandler::new(); - -// let mut ota_agent = -// OtaAgent::builder(&mqtt_client, &mqtt_client, SysClock::new(), file_handler) -// .request_wait_ms(3000) -// .block_size(256) -// .build(); - -// let mut file_info = None; - -// loop { -// match mqtt_eventloop.connect(&mut network) { -// Ok(true) => { -// log::info!("Successfully connected to broker"); -// ota_agent.init(); -// } -// Ok(false) => {} -// Err(nb::Error::WouldBlock) => continue, -// Err(e) => panic!("{:?}", e), -// } - -// match mqtt_eventloop.yield_event(&mut network) { -// Ok(Notification::Publish(mut publish)) => { -// // Check if the received file is a jobs topic, that we -// // want to react to. -// match handle_ota(&mut publish) { -// Ok(OtaUpdate::JobUpdate(job_id, job_doc, status_details)) => { -// log::debug!("Received job! Starting OTA! 
{:?}", job_doc.streamname); - -// let file = &job_doc.files[0]; -// file_info.replace(FileInfo { -// file_path: file.filepath.to_string(), -// filesize: file.filesize, -// signature: file.signature(), -// }); -// ota_agent -// .job_update(job_id, &job_doc, status_details.as_ref()) -// .expect("Failed to start OTA job"); -// } -// Ok(OtaUpdate::Data(payload)) => { -// if ota_agent.handle_message(payload).is_err() { -// match ota_agent.state() { -// States::CreatingFile => log::info!("State: CreatingFile"), -// States::Ready => log::info!("State: Ready"), -// States::RequestingFileBlock => { -// log::info!("State: RequestingFileBlock") -// } -// States::RequestingJob => log::info!("State: RequestingJob"), -// States::Restarting => log::info!("State: Restarting"), -// States::Suspended => log::info!("State: Suspended"), -// States::WaitingForFileBlock => { -// log::info!("State: WaitingForFileBlock") -// } -// States::WaitingForJob => log::info!("State: WaitingForJob"), -// } -// } -// } -// Err(_) => {} -// } -// } -// Ok(n) => { -// log::trace!("{:?}", n); -// } -// _ => {} -// } - -// ota_agent.timer_callback().expect("Failed timer callback!"); - -// match ota_agent.process_event() { -// // Use the restarting state to indicate finished -// Ok(States::Restarting) => break, -// _ => {} -// } -// } - -// let mut expected_file = File::open("tests/assets/ota_file").unwrap(); -// let mut expected_data = Vec::new(); -// expected_file.read_to_end(&mut expected_data).unwrap(); -// let mut expected_hasher = Sha256::new(); -// expected_hasher.update(&expected_data); -// let expected_hash = expected_hasher.finalize(); - -// let file_info = file_info.unwrap(); - -// log::info!( -// "Comparing {:?} with {:?}", -// "tests/assets/ota_file", -// file_info.file_path -// ); -// let mut file = File::open(file_info.file_path.clone()).unwrap(); -// let mut data = Vec::new(); -// file.read_to_end(&mut data).unwrap(); -// drop(file); -// std::fs::remove_file(file_info.file_path).unwrap(); - -// assert_eq!(data.len(), file_info.filesize); - -// let mut hasher = Sha256::new(); -// hasher.update(&data); -// assert_eq!(hasher.finalize().deref(), expected_hash.deref()); - -// // Check file signature -// match file_info.signature { -// ota::encoding::json::Signature::Sha1Rsa(_) => { -// panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") -// } -// ota::encoding::json::Signature::Sha256Rsa(_) => { -// panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") -// } -// ota::encoding::json::Signature::Sha1Ecdsa(_) => { -// panic!("Unexpected signature format: Sha1Ecdsa. 
Expected Sha256Ecdsa") -// } -// ota::encoding::json::Signature::Sha256Ecdsa(sig) => { -// assert_eq!(&sig, "This is my custom signature\\n") -// } -// } -// } +#![allow(async_fn_in_trait)] +#![feature(type_alias_impl_trait)] + +mod common; + +use std::{net::ToSocketAddrs, process}; + +use common::credentials; +use common::file_handler::FileHandler; +use common::network::TlsNetwork; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; +use embassy_time::Duration; +use embedded_mqtt::{ + Config, DomainBroker, IpBroker, Message, Publish, QoS, RetainHandling, State, Subscribe, + SubscribeTopic, +}; +use futures::StreamExt; +use serde::{Deserialize, Serialize}; +use static_cell::make_static; + +use rustot::{ + jobs::{ + self, + data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, + JobTopic, StatusDetails, + }, + ota::{self, encoding::json::OtaJob, JobEventData, Updater}, +}; + +#[derive(Debug, Deserialize)] +pub enum Jobs<'a> { + #[serde(rename = "afr_ota")] + #[serde(borrow)] + Ota(OtaJob<'a>), +} + +fn handle_job<'a, M: RawMutex, const SUBS: usize>( + message: &'a Message<'_, M, SUBS>, +) -> Option> { + match jobs::Topic::from_str(message.topic_name()) { + Some(jobs::Topic::NotifyNext) => { + let (execution_changed, _) = + serde_json_core::from_slice::>(&message.payload()) + .map_err(drop)?; + let job = execution_changed.execution.ok_or(())?; + let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; + Some(JobEventData { + job_name: job.job_id, + ota_document: ota_job, + status_details: job.status_details, + }) + } + Some(jobs::Topic::DescribeAccepted(_)) => { + let (execution_changed, _) = serde_json_core::from_slice::< + DescribeJobExecutionResponse, + >(&message.payload()) + .map_err(drop)?; + let job = execution_changed.execution.ok_or(())?; + let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; + Some(JobEventData { + job_name: job.job_id, + ota_document: ota_job, + status_details: job.status_details, + }) + } + _ => None, + } +} + +pub struct FileInfo { + pub file_path: String, + pub filesize: usize, + pub signature: ota::encoding::json::Signature, +} + +#[tokio::test(flavor = "current_thread")] +async fn test_mqtt_ota() { + env_logger::init(); + + log::info!("Starting OTA test..."); + + let (thing_name, identity) = credentials::identity(); + + let hostname = credentials::HOSTNAME.unwrap(); + let network = make_static!(TlsNetwork::new(hostname.to_owned(), identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + let config = + Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); + + let state = make_static!(State::::new()); + let (mut stack, client) = embedded_mqtt::new(state, config, network); + + let client = make_static!(client); + + let ota_fut = async { + let jobs_subscription = client + .subscribe(Subscribe::new(&[SubscribeTopic { + topic_path: jobs::JobTopic::NotifyNext + .format::<64>(thing_name)? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }])) + .await?; + + while let Some(message) = jobs_subscription.next().await { + if let Some(job_details) = handle_job(&message) { + // We have an OTA job, leeeets go! 
+ let config = ota::config::Config::default(); + let mut file_handler = FileHandler::new(); + Updater::perform_ota(client, client, job_details, &mut file_hander, config).await; + } + } + }; + + match select::select(stack.run(), ota_fut).await { + select::Either::First(_) => { + unreachable!() + } + select::Either::Second(result) => result.unwrap(), + }; + + // let mut expected_file = File::open("tests/assets/ota_file").unwrap(); + // let mut expected_data = Vec::new(); + // expected_file.read_to_end(&mut expected_data).unwrap(); + // let mut expected_hasher = Sha256::new(); + // expected_hasher.update(&expected_data); + // let expected_hash = expected_hasher.finalize(); + + // let file_info = file_info.unwrap(); + + // log::info!( + // "Comparing {:?} with {:?}", + // "tests/assets/ota_file", + // file_info.file_path + // ); + // let mut file = File::open(file_info.file_path.clone()).unwrap(); + // let mut data = Vec::new(); + // file.read_to_end(&mut data).unwrap(); + // drop(file); + // std::fs::remove_file(file_info.file_path).unwrap(); + + // assert_eq!(data.len(), file_info.filesize); + + // let mut hasher = Sha256::new(); + // hasher.update(&data); + // assert_eq!(hasher.finalize().deref(), expected_hash.deref()); + + // // Check file signature + // match file_info.signature { + // ota::encoding::json::Signature::Sha1Rsa(_) => { + // panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") + // } + // ota::encoding::json::Signature::Sha256Rsa(_) => { + // panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") + // } + // ota::encoding::json::Signature::Sha1Ecdsa(_) => { + // panic!("Unexpected signature format: Sha1Ecdsa. Expected Sha256Ecdsa") + // } + // ota::encoding::json::Signature::Sha256Ecdsa(sig) => { + // assert_eq!(&sig, "This is my custom signature\\n") + // } + // } +} From 4a09047fb06cffe2f5effe0fe2a29a6dbe8dba47 Mon Sep 17 00:00:00 2001 From: Mathias Date: Tue, 9 Jan 2024 14:50:58 +0100 Subject: [PATCH 07/36] Working OTA --- Cargo.toml | 1 + src/jobs/data_types.rs | 2 +- src/jobs/update.rs | 2 +- src/lib.rs | 1 + src/ota/data_interface/mod.rs | 16 +- src/ota/data_interface/mqtt.rs | 242 +++-------------------- src/ota/encoding/mod.rs | 58 +++--- src/ota/mod.rs | 337 ++++++++++++++++++++++----------- src/ota/pal.rs | 4 +- tests/common/file_handler.rs | 174 ++++++++++------- tests/ota.rs | 162 +++++++++------- 11 files changed, 501 insertions(+), 498 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8b9be4d..a2518d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ futures = { version = "0.3.28", default-features = false } embassy-time = { version = "0.2" } embassy-sync = "0.5" +embassy-futures = "0.1" log = { version = "^0.4", default-features = false, optional = true } defmt = { version = "^0.3", optional = true } diff --git a/src/jobs/data_types.rs b/src/jobs/data_types.rs index 23101eb..5910469 100644 --- a/src/jobs/data_types.rs +++ b/src/jobs/data_types.rs @@ -417,7 +417,7 @@ mod test { queued_jobs .push(JobExecutionSummary { execution_number: Some(1), - job_id: Some(String::from("test")), + job_id: Some(String::try_from("test").unwrap()), last_updated_at: Some(1587036256), queued_at: Some(1587036256), started_at: None, diff --git a/src/jobs/update.rs b/src/jobs/update.rs index e2c83f4..db34d90 100644 --- a/src/jobs/update.rs +++ b/src/jobs/update.rs @@ -217,7 +217,7 @@ mod test { .topic_payload("test_client", &mut buf) .unwrap(); - assert_eq!(&ubf[..payload_len], 
br#"{"executionNumber":5,"expectedVersion":2,"includeJobDocument":true,"includeJobExecutionState":true,"status":"FAILED","stepTimeoutInMinutes":50,"clientToken":"test_client:token_update"}"#); + assert_eq!(&buf[..payload_len], br#"{"executionNumber":5,"expectedVersion":2,"includeJobDocument":true,"includeJobExecutionState":true,"status":"FAILED","stepTimeoutInMinutes":50,"clientToken":"test_client:token_update"}"#); assert_eq!( topic.as_str(), diff --git a/src/lib.rs b/src/lib.rs index 2fac696..eb9c203 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![cfg_attr(not(any(test, feature = "std")), no_std)] #![allow(incomplete_features)] +#![allow(async_fn_in_trait)] #![feature(generic_const_exprs)] // This mod MUST go first, so that the others see its macros. diff --git a/src/ota/data_interface/mod.rs b/src/ota/data_interface/mod.rs index d80a0c9..bfdff2a 100644 --- a/src/ota/data_interface/mod.rs +++ b/src/ota/data_interface/mod.rs @@ -3,6 +3,8 @@ #[cfg(feature = "ota_mqtt_data")] pub mod mqtt; +use core::ops::DerefMut; + use serde::Deserialize; use crate::ota::config::Config; @@ -41,18 +43,26 @@ impl<'a> FileBlock<'a> { } } +pub trait BlockTransfer { + async fn next_block(&mut self) -> Result, OtaError>; +} + pub trait DataInterface { const PROTOCOL: Protocol; - async fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError>; + type ActiveTransfer<'t>: BlockTransfer where Self: 't; + + async fn init_file_transfer(&self, file_ctx: &FileContext) -> Result, OtaError>; + async fn request_file_block( &self, file_ctx: &mut FileContext, config: &Config, ) -> Result<(), OtaError>; - async fn decode_file_block<'a>( + + fn decode_file_block<'a>( &self, - file_ctx: &mut FileContext, + file_ctx: &FileContext, payload: &'a mut [u8], ) -> Result, OtaError>; } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 3f4af47..077cb59 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -1,8 +1,12 @@ use core::fmt::{Display, Write}; +use core::ops::DerefMut; use core::str::FromStr; use embassy_sync::blocking_mutex::raw::RawMutex; -use embedded_mqtt::{MqttClient, Properties, Publish, RetainHandling, Subscribe, SubscribeTopic}; +use embedded_mqtt::{ + MqttClient, Properties, Publish, RetainHandling, Subscribe, SubscribeTopic, Subscription, +}; +use futures::StreamExt; use crate::ota::error::OtaError; use crate::{ @@ -14,6 +18,8 @@ use crate::{ }, }; +use super::BlockTransfer; + #[derive(Debug, Clone, Copy, PartialEq)] pub enum Encoding { Cbor, @@ -117,11 +123,22 @@ impl<'a> OtaTopic<'a> { } } +impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> { + async fn next_block(&mut self) -> Result, OtaError> { + Ok(self.next().await.ok_or(OtaError::Encoding)?) 
+ } +} + impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> { const PROTOCOL: Protocol = Protocol::Mqtt; + type ActiveTransfer<'t> = Subscription<'a, 't, M, SUBS, 1> where Self: 't; + /// Init file transfer by subscribing to the OTA data stream topic - async fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError> { + async fn init_file_transfer( + &self, + file_ctx: &FileContext, + ) -> Result, OtaError> { let topic_path = OtaTopic::Data(Encoding::Cbor, file_ctx.stream_name.as_str()) .format::<256>(self.client_id())?; @@ -135,10 +152,7 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB debug!("Subscribing to: [{:?}]", &topic_path); - // FIXME: - self.subscribe::<1>(Subscribe::new(&[topic])).await?; - - Ok(()) + Ok(self.subscribe::<1>(Subscribe::new(&[topic])).await?) } /// Request file block by publishing to the get stream topic @@ -147,7 +161,6 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB file_ctx: &mut FileContext, config: &Config, ) -> Result<(), OtaError> { - // Reset number of blocks requested file_ctx.request_block_remaining = file_ctx.bitmap.len() as u32; // FIXME: Serialize directly into the publish payload through `DeferredPublish` API @@ -184,9 +197,9 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB } /// Decode a cbor encoded fileblock received from streaming service - async fn decode_file_block<'c>( + fn decode_file_block<'c>( &self, - _file_ctx: &mut FileContext, + _file_ctx: &FileContext, payload: &'c mut [u8], ) -> Result, OtaError> { Ok( @@ -196,214 +209,3 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB ) } } - -#[cfg(test)] -mod tests { - use mqttrust::{encoding::v4::decode_slice, Packet, SubscribeTopic}; - - use super::*; - use crate::{ota::test::test_file_ctx, test::MockMqtt}; - - #[test] - fn protocol_fits() { - assert_eq!(<&MockMqtt as DataInterface>::PROTOCOL, Protocol::Mqtt); - } - - #[test] - fn init_file_transfer_subscribes() { - let mqtt = &MockMqtt::new(); - - let mut file_ctx = test_file_ctx(&Config::default()); - - mqtt.init_file_transfer(&mut file_ctx).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 1); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - assert_eq!( - topics, - vec![SubscribeTopic { - topic_path: "$aws/things/test_client/streams/test_stream/data/cbor", - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime - }] - ); - } - - #[test] - fn request_file_block_publish() { - let mqtt = &MockMqtt::new(); - - let config = Config::default(); - let mut file_ctx = test_file_ctx(&config); - - mqtt.request_file_block(&mut file_ctx, &config).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 1); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let publish = match decode_slice(bytes.as_slice()).unwrap() { - Some(Packet::Publish(s)) => s, - _ => panic!(), - }; - - assert_eq!( - publish, - mqttrust::encoding::v4::publish::Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic_name: "$aws/things/test_client/streams/test_stream/get/cbor", - payload: &[ - 164, 97, 102, 0, 97, 108, 25, 1, 0, 97, 111, 0, 97, 98, 68, 255, 255, 255, 127 - ], - pid: None - } - ); - } - - 
#[test] - fn decode_file_block_cbor() { - let mqtt = &MockMqtt::new(); - - let mut file_ctx = test_file_ctx(&Config::default()); - - let payload = &mut [ - 191, 97, 102, 0, 97, 105, 0, 97, 108, 25, 4, 0, 97, 112, 89, 4, 0, 141, 62, 28, 246, - 80, 193, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 255, - ]; - - let file_blk = mqtt.decode_file_block(&mut file_ctx, payload).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - assert_eq!(file_blk.file_id, 0); - 
assert_eq!(file_blk.block_id, 0); - assert_eq!( - file_blk.block_payload, - &[ - 141, 62, 28, 246, 80, 193, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ] - ); - assert_eq!(file_blk.block_size, 1024); - assert_eq!(file_blk.client_token, None); - } - - #[test] - fn cleanup_unsubscribe() { - let mqtt = &MockMqtt::new(); - - let config = Config::default(); - - let mut file_ctx = test_file_ctx(&config); - - mqtt.cleanup(&mut file_ctx, &config).unwrap(); - - 
assert_eq!(mqtt.tx.borrow_mut().len(), 1); - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Unsubscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec!["$aws/things/test_client/streams/test_stream/data/cbor"] - ); - } - - #[test] - fn cleanup_no_unsubscribe() { - let mqtt = &MockMqtt::new(); - - let mut config = Config::default(); - config.unsubscribe_on_shutdown = false; - - let mut file_ctx = test_file_ctx(&config); - - mqtt.cleanup(&mut file_ctx, &config).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } -} diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index bb9f473..bc68c67 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -7,10 +7,12 @@ use serde::{Serialize, Serializer}; use crate::jobs::StatusDetailsOwned; -use self::json::{JobStatusReason, OtaJob, Signature}; +use self::json::{JobStatusReason, Signature}; use super::config::Config; +use super::data_interface::Protocol; use super::error::OtaError; +use super::JobEventData; #[derive(Clone, PartialEq)] pub struct Bitmap(bitmaps::Bitmap<32>); @@ -63,6 +65,7 @@ pub struct FileContext { pub auth_scheme: Option>, pub signature: Signature, pub file_type: Option, + pub protocols: heapless::Vec, pub status_details: StatusDetailsOwned, pub block_offset: u32, @@ -75,32 +78,28 @@ pub struct FileContext { impl FileContext { pub fn new_from( - job_name: &str, - ota_job: &OtaJob, - status_details: Option, + job_data: JobEventData<'_>, file_idx: usize, config: &Config, ) -> Result { - let file_desc = ota_job + if job_data + .ota_document + .files + .get(file_idx) + .map(|f| f.filesize) + .unwrap_or_default() + == 0 + { + return Err(OtaError::ZeroFileSize); + } + + let file_desc = job_data + .ota_document .files .get(file_idx) .ok_or(OtaError::InvalidFile)? 
.clone(); - // Initialize new `status_details' if not already present - let status = if let Some(details) = status_details { - details - } else { - let mut status = StatusDetailsOwned::new(); - // status - // .insert( - // heapless::String::try_from("updated_by").unwrap(), - // current_version.to_string(), - // ) - // .map_err(|_| OtaError::Overflow)?; - status - }; - let signature = file_desc.signature(); let block_offset = 0; @@ -109,6 +108,7 @@ impl FileContext { Ok(FileContext { filepath: heapless::String::try_from(file_desc.filepath).unwrap(), filesize: file_desc.filesize, + protocols: job_data.ota_document.protocols, fileid: file_desc.fileid, certfile: heapless::String::try_from(file_desc.certfile).unwrap(), update_data_url: file_desc @@ -120,13 +120,25 @@ impl FileContext { signature, file_type: file_desc.file_type, - status_details: status, - - job_name: heapless::String::try_from(job_name).unwrap(), + status_details: job_data + .status_details + .map(|s| { + s.iter() + .map(|(&k, &v)| { + ( + heapless::String::try_from(k).unwrap(), + heapless::String::try_from(v).unwrap(), + ) + }) + .collect() + }) + .unwrap_or_else(|| StatusDetailsOwned::new()), + + job_name: heapless::String::try_from(job_data.job_name).unwrap(), block_offset, request_block_remaining: bitmap.len() as u32, blocks_remaining: (file_desc.filesize + config.block_size - 1) / config.block_size, - stream_name: heapless::String::try_from(ota_job.streamname).unwrap(), + stream_name: heapless::String::try_from(job_data.ota_document.streamname).unwrap(), bitmap, }) } diff --git a/src/ota/mod.rs b/src/ota/mod.rs index a366b6d..793f27a 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -37,10 +37,19 @@ pub mod encoding; pub mod error; pub mod pal; +use core::{ + ops::DerefMut, + sync::atomic::{AtomicU8, Ordering}, +}; + #[cfg(feature = "ota_mqtt_data")] pub use data_interface::mqtt::{Encoding, Topic}; +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; -use crate::{jobs::data_types::JobStatus, ota::encoding::json::JobStatusReason}; +use crate::{ + jobs::data_types::JobStatus, + ota::{data_interface::BlockTransfer, encoding::json::JobStatusReason}, +}; use self::{ control_interface::ControlInterface, @@ -52,59 +61,27 @@ use self::{ #[derive(PartialEq)] pub struct JobEventData<'a> { pub job_name: &'a str, - pub ota_document: &'a encoding::json::OtaJob<'a>, - pub status_details: Option<&'a crate::jobs::StatusDetails<'a>>, + pub ota_document: encoding::json::OtaJob<'a>, + pub status_details: Option>, } pub struct Updater; impl Updater { - pub async fn perform_ota<'a, C: ControlInterface, D: DataInterface>( + pub async fn check_for_job<'a, C: ControlInterface>( + control: &C, + ) -> Result<(), error::OtaError> { + control.request_job().await?; + Ok(()) + } + + pub async fn perform_ota<'a, 'b, C: ControlInterface, D: DataInterface>( control: &C, data: &D, - job_data: JobEventData<'a>, + mut file_ctx: FileContext, pal: &mut impl pal::OtaPal, config: config::Config, ) -> Result<(), error::OtaError> { - let mut request_momentum = 0; - - // TODO: Handle request_momentum? 
- control.request_job().await?; - - let JobEventData { - job_name, - ota_document, - status_details, - } = job_data; - - let file_idx = 0; - - if ota_document - .files - .get(file_idx) - .map(|f| f.filesize) - .unwrap_or_default() - == 0 - { - return Err(error::OtaError::ZeroFileSize); - } - - let mut file_ctx = FileContext::new_from( - job_name, - ota_document, - status_details.map(|s| { - s.iter() - .map(|(&k, &v)| { - ( - heapless::String::try_from(k).unwrap(), - heapless::String::try_from(v).unwrap(), - ) - }) - .collect() - }), - file_idx, - &config, - )?; // If the job is in self test mode, don't start an OTA update but // instead do the following: @@ -180,8 +157,8 @@ impl Updater { } } - if !ota_document.protocols.contains(&D::PROTOCOL) { - error!("Unable to handle current OTA job with given data interface ({:?}). Supported protocols: {:?}. Aborting current update.", D::PROTOCOL, ota_document.protocols); + if !file_ctx.protocols.contains(&D::PROTOCOL) { + error!("Unable to handle current OTA job with given data interface ({:?}). Supported protocols: {:?}. Aborting current update.", D::PROTOCOL, file_ctx.protocols); Self::set_image_state_with_reason( control, pal, @@ -211,76 +188,226 @@ impl Updater { } // Prepare the storage layer on receiving a new file - match data.init_file_transfer(&mut file_ctx).await { - Err(e) => { - return if request_momentum < config.max_request_momentum { - // Start request timer - // self.request_timer - // .start(config.request_wait.millis()) - // .map_err(|_| error::OtaError::Timer)?; - - request_momentum += 1; - Err(e) - } else { - // Stop request timer - // self.request_timer - // .cancel() - // .map_err(|_| error::OtaError::Timer)?; - - // Too many requests have been sent without a response or - // too many failures when trying to publish the request - // message. Abort. - - Err(error::OtaError::MomentumAbort) - }; + let mut subscription = data.init_file_transfer(&mut file_ctx).await?; + + info!("Initialized file handler! Requesting file blocks"); + + let request_momentum = AtomicU8::new(0); + + // let momentum_fut = async { + // while file_ctx.lock().await.blocks_remaining > 0 { + // if request_momentum.load(Ordering::Relaxed) <= config.max_request_momentum { + // // Each request increases the momentum until a response is + // // received. Too much momentum is interpreted as a failure to + // // communicate and will cause us to abort the OTA. + // request_momentum.fetch_add(1, Ordering::Relaxed); + + // // Reset number of blocks requested + // let mut ctx = file_ctx.lock().await; + // ctx.request_block_remaining = ctx.bitmap.len() as u32; + + // // Request data blocks + // data.request_file_block(&ctx, &config).await?; + // } else { + // // Too many requests have been sent without a response or too + // // many failures when trying to publish the request message. + // // Abort. + // return Err(error::OtaError::MomentumAbort); + // } + + // embassy_time::Timer::after(config.request_wait).await; + // } + + // Ok(()) + // }; + + let data_fut = async { + data.request_file_block(&mut file_ctx, &config).await?; + + while let Ok(mut payload) = subscription.next_block().await { + debug!("process_data_handler"); + // Decode the file block received + match Self::ingest_data_block(data, pal, &config, &mut file_ctx, payload.deref_mut()) + .await + { + Ok(true) => { + // File is completed! Update progress accordingly. 
+ match pal.close_file(&file_ctx).await { + Err(e) => { + control + .update_job_status( + &mut file_ctx, + &config, + JobStatus::Failed, + JobStatusReason::Pal(0), + ) + .await?; + + return Err(e.into()); + } + Ok(_) => { + let (status, reason, event) = if let Some(0) = file_ctx.file_type { + ( + JobStatus::InProgress, + JobStatusReason::SigCheckPassed, + pal::OtaEvent::Activate, + ) + } else { + ( + JobStatus::Succeeded, + JobStatusReason::Accepted, + pal::OtaEvent::UpdateComplete, + ) + }; + + control + .update_job_status(&mut file_ctx, &config, status, reason) + .await?; + + return Ok(event); + } + } + } + Ok(false) => { + debug!("Ingested one block!"); + // Reset the momentum counter since we received a good block + request_momentum.store(0, Ordering::Relaxed); + + // We're actively receiving a file so update the job status as + // needed + control + .update_job_status( + &mut file_ctx, + &config, + JobStatus::InProgress, + JobStatusReason::Receiving, + ) + .await?; + + if file_ctx.request_block_remaining > 1 { + file_ctx.request_block_remaining -= 1; + } else { + data.request_file_block(&mut file_ctx, &config).await?; + } + } + Err(e) if e.is_retryable() => { + warn!("Failed to ingest data block, Error is retryable! ingest_data_block returned error {:?}", e); + } + Err(e) => { + error!("Failed to ingest data block, rejecting image: ingest_data_block returned error {:?}", e); + + // Call the platform specific code to reject the image + // TODO: This should never write to current image flags?! + // pal.set_platform_image_state(ImageState::Rejected( + // ImageStateReason::FailedIngest, + // )) + // .await?; + + // TODO: Pal reason + control + .update_job_status( + &mut file_ctx, + &config, + JobStatus::Failed, + JobStatusReason::Pal(0), + ) + .await?; + + pal.complete_callback(pal::OtaEvent::Fail).await?; + info!("Application callback! OtaEvent::Fail"); + return Err(e); + } + } } - Ok(_) => { - // Reset the request momentum - request_momentum = 0; - // TODO: Reset the OTA statistics + Err(error::OtaError::Mqtt(embedded_mqtt::Error::EOF)) + }; + + // let (momentum_res, data_res) = embassy_futures::join::join(momentum_fut, data_fut).await; + + let data_res = data_fut.await; + + // if let Err(e) = momentum_res { + // // Failed to send data request abort and close file. + // Self::set_image_state_with_reason( + // control, + // pal, + // &config, + // &mut file_ctx, + // ImageState::Aborted(ImageStateReason::MomentumAbort), + // ) + // .await?; + + // return Err(e); + // }; + + pal.complete_callback(data_res?).await?; + + Ok(()) + } - info!("Initialized file handler! Requesting file blocks"); + async fn ingest_data_block<'a, D: DataInterface, PAL: pal::OtaPal>( + data: &D, + pal: &mut PAL, + config: &config::Config, + file_ctx: &mut FileContext, + payload: &mut [u8], + ) -> Result { + let block = data.decode_file_block(&file_ctx, payload)?; + if block.validate(config.block_size, file_ctx.filesize) { + if block.block_id < file_ctx.block_offset as usize + || !file_ctx.bitmap.get(block.block_id - file_ctx.block_offset as usize) + { + info!( + "Block {:?} is a DUPLICATE. {:?} blocks remaining.", + block.block_id, file_ctx.blocks_remaining + ); + + // Just return same progress as before + return Ok(false); } - } - // Request data - if file_ctx.blocks_remaining > 0 { - if request_momentum <= config.max_request_momentum { - // Each request increases the momentum until a response is - // received. Too much momentum is interpreted as a failure to - // communicate and will cause us to abort the OTA. 
- request_momentum += 1; + info!( + "Received block {}. {:?} blocks remaining.", + block.block_id, file_ctx.blocks_remaining + ); - // Request data blocks - data.request_file_block(&mut file_ctx, &config).await?; - } else { - // Stop the request timer - // self.request_timer.cancel().map_err(|_| error::OtaError::Timer)?; + pal.write_block( + file_ctx, + block.block_id * config.block_size, + block.block_payload, + ) + .await?; - // Failed to send data request abort and close file. - Self::set_image_state_with_reason( - control, - pal, - &config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::MomentumAbort), - ) - .await?; - // Reset the request momentum - request_momentum = 0; + let block_offset = file_ctx.block_offset; + file_ctx.bitmap + .set(block.block_id - block_offset as usize, false); + + file_ctx.blocks_remaining -= 1; + + if file_ctx.blocks_remaining == 0 { + info!("Received final expected block of file."); - // Too many requests have been sent without a response or too - // many failures when trying to publish the request message. - // Abort. - return Err(error::OtaError::MomentumAbort); + // Return true to indicate end of file. + Ok(true) + } else { + if file_ctx.bitmap.is_empty() { + file_ctx.block_offset += 31; + file_ctx.bitmap = + encoding::Bitmap::new(file_ctx.filesize, config.block_size, file_ctx.block_offset); + } + + Ok(false) } } else { - return Err(error::OtaError::BlockOutOfRange); - } + error!( + "Error! Block {:?} out of expected range! Size {:?}", + block.block_id, block.block_size + ); - Ok(()) + Err(error::OtaError::BlockOutOfRange) + } } async fn set_image_state_with_reason<'a, C: ControlInterface, PAL: pal::OtaPal>( @@ -290,10 +417,8 @@ impl Updater { file_ctx: &mut FileContext, image_state: ImageState, ) -> Result<(), error::OtaError> { - // debug!("set_image_state_with_reason {:?}", image_state); // Call the platform specific code to set the image state - // FIXME: let image_state = match pal.set_platform_image_state(image_state).await { Err(e) if !matches!(image_state, ImageState::Aborted(_)) => { // If the platform image state couldn't be set correctly, force diff --git a/src/ota/pal.rs b/src/ota/pal.rs index 409f5ed..063bb0b 100644 --- a/src/ota/pal.rs +++ b/src/ota/pal.rs @@ -1,4 +1,6 @@ //! Platform abstraction trait for OTA updates +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; + use super::encoding::FileContext; #[derive(Debug, Clone, Copy)] @@ -166,7 +168,7 @@ pub trait OtaPal { /// code from the platform abstraction layer. 
async fn write_block( &mut self, - file: &FileContext, + file: &mut FileContext, block_offset: usize, block_payload: &[u8], ) -> Result; diff --git a/tests/common/file_handler.rs b/tests/common/file_handler.rs index 335e1af..a23b796 100644 --- a/tests/common/file_handler.rs +++ b/tests/common/file_handler.rs @@ -1,91 +1,119 @@ -// use rustot::ota::pal::{OtaPal, OtaPalError, PalImageState}; -use tokio::fs::File; -use tokio::io::{Cursor, Write}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embassy_sync::mutex::Mutex; +use rustot::ota::{pal::{OtaPal, OtaPalError, PalImageState}, self}; +use sha2::{Sha256, Digest}; +use std::{io::{Cursor, Write, Read}, fs::File}; +use core::ops::Deref; pub struct FileHandler { filebuf: Option>>, + compare_file_path: String, } impl FileHandler { - pub fn new() -> Self { - FileHandler { filebuf: None } + pub fn new(compare_file_path: String) -> Self { + FileHandler { + filebuf: None, + compare_file_path, + } } } -// impl OtaPal for FileHandler { -// type Error = (); +impl OtaPal for FileHandler { + async fn abort( + &mut self, + _file: &rustot::ota::encoding::FileContext, + ) -> Result<(), OtaPalError> { + Ok(()) + } + + async fn create_file_for_rx( + &mut self, + file: &rustot::ota::encoding::FileContext, + ) -> Result<(), OtaPalError> { + self.filebuf = Some(Cursor::new(Vec::with_capacity(file.filesize))); + Ok(()) + } + + async fn get_platform_image_state(&mut self) -> Result { + Ok(PalImageState::Valid) + } -// fn abort( -// &mut self, -// _file: &rustot::ota::encoding::FileContext, -// ) -> Result<(), OtaPalError> { -// Ok(()) -// } + async fn set_platform_image_state( + &mut self, + _image_state: rustot::ota::pal::ImageState, + ) -> Result<(), OtaPalError> { + Ok(()) + } -// fn create_file_for_rx( -// &mut self, -// file: &rustot::ota::encoding::FileContext, -// ) -> Result<(), OtaPalError> { -// self.filebuf = Some(Cursor::new(Vec::with_capacity(file.filesize))); -// Ok(()) -// } + async fn reset_device(&mut self) -> Result<(), OtaPalError> { + Ok(()) + } -// fn get_platform_image_state(&mut self) -> Result> { -// Ok(PalImageState::Valid) -// } + async fn close_file( + &mut self, + file: &rustot::ota::encoding::FileContext, + ) -> Result<(), OtaPalError> { + if let Some(ref mut buf) = &mut self.filebuf { + log::debug!( + "Closing completed file. Len: {}/{} -> {}", + buf.get_ref().len(), + file.filesize, + file.filepath.as_str() + ); -// fn set_platform_image_state( -// &mut self, -// _image_state: rustot::ota::pal::ImageState<()>, -// ) -> Result<(), OtaPalError> { -// Ok(()) -// } + + let mut expected_data = std::fs::read(self.compare_file_path.as_str()).unwrap(); + let mut expected_hasher = ::new(); + expected_hasher.update(&expected_data); + let expected_hash = expected_hasher.finalize(); -// fn reset_device(&mut self) -> Result<(), OtaPalError> { -// Ok(()) -// } + log::info!( + "Comparing {:?} with {:?}", + self.compare_file_path, + file.filepath.as_str() + ); + assert_eq!(buf.get_ref().len(), file.filesize); -// fn close_file( -// &mut self, -// file: &rustot::ota::encoding::FileContext, -// ) -> Result<(), OtaPalError> { -// if let Some(ref mut buf) = &mut self.filebuf { -// log::debug!( -// "Closing completed file. 
Len: {}/{} -> {}", -// buf.get_ref().len(), -// file.filesize, -// file.filepath.as_str() -// ); -// let mut file = -// File::create(file.filepath.as_str()).map_err(|_| OtaPalError::FileWriteFailed)?; -// file.write_all(buf.get_ref()) -// .map_err(|_| OtaPalError::FileWriteFailed)?; + let mut hasher = ::new(); + hasher.update(&buf.get_ref()); + assert_eq!(hasher.finalize().deref(), expected_hash.deref()); -// Ok(()) -// } else { -// Err(OtaPalError::BadFileHandle) -// } -// } + // Check file signature + match &file.signature { + ota::encoding::json::Signature::Sha1Rsa(_) => { + panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") + } + ota::encoding::json::Signature::Sha256Rsa(_) => { + panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") + } + ota::encoding::json::Signature::Sha1Ecdsa(_) => { + panic!("Unexpected signature format: Sha1Ecdsa. Expected Sha256Ecdsa") + } + ota::encoding::json::Signature::Sha256Ecdsa(sig) => { + assert_eq!(sig.as_str(), "This is my custom signature\\n") + } + } -// fn write_block( -// &mut self, -// _file: &rustot::ota::encoding::FileContext, -// block_offset: usize, -// block_payload: &[u8], -// ) -> Result> { -// if let Some(ref mut buf) = &mut self.filebuf { -// buf.set_position(block_offset as u64); -// buf.write(block_payload) -// .map_err(|_e| OtaPalError::FileWriteFailed)?; -// Ok(block_payload.len()) -// } else { -// Err(OtaPalError::BadFileHandle) -// } -// } + Ok(()) + } else { + Err(OtaPalError::BadFileHandle) + } + } -// fn get_active_firmware_version( -// &self, -// ) -> Result> { -// Ok(rustot::ota::pal::Version::new(0, 1, 0)) -// } -// } + async fn write_block( + &mut self, + _file: &mut rustot::ota::encoding::FileContext, + block_offset: usize, + block_payload: &[u8], + ) -> Result { + if let Some(ref mut buf) = &mut self.filebuf { + buf.set_position(block_offset as u64); + buf.write(block_payload) + .map_err(|_e| OtaPalError::FileWriteFailed)?; + Ok(block_payload.len()) + } else { + Err(OtaPalError::BadFileHandle) + } + } +} diff --git a/tests/ota.rs b/tests/ota.rs index 5992196..3363b0c 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -25,7 +25,11 @@ use rustot::{ data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, JobTopic, StatusDetails, }, - ota::{self, encoding::json::OtaJob, JobEventData, Updater}, + ota::{ + self, + encoding::{json::OtaJob, FileContext}, + JobEventData, Updater, + }, }; #[derive(Debug, Deserialize)] @@ -35,6 +39,14 @@ pub enum Jobs<'a> { Ota(OtaJob<'a>), } +impl<'a> Jobs<'a> { + pub fn ota_job(self) -> Option> { + match self { + Jobs::Ota(ota_job) => Some(ota_job), + } + } +} + fn handle_job<'a, M: RawMutex, const SUBS: usize>( message: &'a Message<'_, M, SUBS>, ) -> Option> { @@ -42,9 +54,9 @@ fn handle_job<'a, M: RawMutex, const SUBS: usize>( Some(jobs::Topic::NotifyNext) => { let (execution_changed, _) = serde_json_core::from_slice::>(&message.payload()) - .map_err(drop)?; - let job = execution_changed.execution.ok_or(())?; - let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; + .ok()?; + let job = execution_changed.execution?; + let ota_job = job.job_document?.ota_job()?; Some(JobEventData { job_name: job.job_id, ota_document: ota_job, @@ -55,9 +67,9 @@ fn handle_job<'a, M: RawMutex, const SUBS: usize>( let (execution_changed, _) = serde_json_core::from_slice::< DescribeJobExecutionResponse, >(&message.payload()) - .map_err(drop)?; - let job = execution_changed.execution.ok_or(())?; - let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; + 
.ok()?; + let job = execution_changed.execution?; + let ota_job = job.job_document?.ota_job()?; Some(JobEventData { job_name: job.job_id, ota_document: ota_job, @@ -68,12 +80,6 @@ fn handle_job<'a, M: RawMutex, const SUBS: usize>( } } -pub struct FileInfo { - pub file_path: String, - pub filesize: usize, - pub signature: ota::encoding::json::Signature, -} - #[tokio::test(flavor = "current_thread")] async fn test_mqtt_ota() { env_logger::init(); @@ -91,32 +97,90 @@ async fn test_mqtt_ota() { let config = Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); - let state = make_static!(State::::new()); + let state = make_static!(State::::new()); let (mut stack, client) = embedded_mqtt::new(state, config, network); let client = make_static!(client); let ota_fut = async { - let jobs_subscription = client - .subscribe(Subscribe::new(&[SubscribeTopic { - topic_path: jobs::JobTopic::NotifyNext - .format::<64>(thing_name)? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) + let mut jobs_subscription = client + .subscribe::<2>(Subscribe::new(&[ + SubscribeTopic { + topic_path: jobs::JobTopic::NotifyNext + .format::<64>(thing_name)? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }, + SubscribeTopic { + topic_path: jobs::JobTopic::DescribeAccepted("$next") + .format::<64>(thing_name)? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }, + ])) .await?; + Updater::check_for_job(client).await?; + while let Some(message) = jobs_subscription.next().await { - if let Some(job_details) = handle_job(&message) { + let config = ota::config::Config::default(); + if let Some(mut file_ctx) = match jobs::Topic::from_str(message.topic_name()) { + Some(jobs::Topic::NotifyNext) => { + let (execution_changed, _) = serde_json_core::from_slice::< + NextJobExecutionChanged, + >(&message.payload()) + .ok() + .unwrap(); + let job = execution_changed.execution.unwrap(); + let ota_job = job.job_document.unwrap().ota_job().unwrap(); + FileContext::new_from( + JobEventData { + job_name: job.job_id, + ota_document: ota_job, + status_details: job.status_details, + }, + 0, + &config, + ) + .ok() + } + Some(jobs::Topic::DescribeAccepted(_)) => { + let (execution_changed, _) = serde_json_core::from_slice::< + DescribeJobExecutionResponse, + >(&message.payload()) + .ok() + .unwrap(); + let job = execution_changed.execution.unwrap(); + let ota_job = job.job_document.unwrap().ota_job().unwrap(); + FileContext::new_from( + JobEventData { + job_name: job.job_id, + ota_document: ota_job, + status_details: job.status_details, + }, + 0, + &config, + ) + .ok() + } + _ => None, + } { + drop(message); // We have an OTA job, leeeets go! 
- let config = ota::config::Config::default(); - let mut file_handler = FileHandler::new(); - Updater::perform_ota(client, client, job_details, &mut file_hander, config).await; + let mut file_handler = FileHandler::new("tests/assets/ota_file".to_owned()); + Updater::perform_ota(client, client, file_ctx, &mut file_handler, config).await?; + + return Ok(()); } } + + Ok::<_, ota::error::OtaError>(()) }; match select::select(stack.run(), ota_fut).await { @@ -125,46 +189,4 @@ async fn test_mqtt_ota() { } select::Either::Second(result) => result.unwrap(), }; - - // let mut expected_file = File::open("tests/assets/ota_file").unwrap(); - // let mut expected_data = Vec::new(); - // expected_file.read_to_end(&mut expected_data).unwrap(); - // let mut expected_hasher = Sha256::new(); - // expected_hasher.update(&expected_data); - // let expected_hash = expected_hasher.finalize(); - - // let file_info = file_info.unwrap(); - - // log::info!( - // "Comparing {:?} with {:?}", - // "tests/assets/ota_file", - // file_info.file_path - // ); - // let mut file = File::open(file_info.file_path.clone()).unwrap(); - // let mut data = Vec::new(); - // file.read_to_end(&mut data).unwrap(); - // drop(file); - // std::fs::remove_file(file_info.file_path).unwrap(); - - // assert_eq!(data.len(), file_info.filesize); - - // let mut hasher = Sha256::new(); - // hasher.update(&data); - // assert_eq!(hasher.finalize().deref(), expected_hash.deref()); - - // // Check file signature - // match file_info.signature { - // ota::encoding::json::Signature::Sha1Rsa(_) => { - // panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") - // } - // ota::encoding::json::Signature::Sha256Rsa(_) => { - // panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") - // } - // ota::encoding::json::Signature::Sha1Ecdsa(_) => { - // panic!("Unexpected signature format: Sha1Ecdsa. 
Expected Sha256Ecdsa") - // } - // ota::encoding::json::Signature::Sha256Ecdsa(sig) => { - // assert_eq!(&sig, "This is my custom signature\\n") - // } - // } } From 531d6f1069af16530c6b8c641b3ede66152b6c3d Mon Sep 17 00:00:00 2001 From: Mathias Date: Tue, 9 Jan 2024 14:51:22 +0100 Subject: [PATCH 08/36] Formatting --- src/ota/data_interface/mod.rs | 9 +++++++-- src/ota/mod.rs | 26 ++++++++++++++++++-------- tests/common/file_handler.rs | 15 ++++++++++----- 3 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/ota/data_interface/mod.rs b/src/ota/data_interface/mod.rs index bfdff2a..3c65af7 100644 --- a/src/ota/data_interface/mod.rs +++ b/src/ota/data_interface/mod.rs @@ -50,9 +50,14 @@ pub trait BlockTransfer { pub trait DataInterface { const PROTOCOL: Protocol; - type ActiveTransfer<'t>: BlockTransfer where Self: 't; + type ActiveTransfer<'t>: BlockTransfer + where + Self: 't; - async fn init_file_transfer(&self, file_ctx: &FileContext) -> Result, OtaError>; + async fn init_file_transfer( + &self, + file_ctx: &FileContext, + ) -> Result, OtaError>; async fn request_file_block( &self, diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 793f27a..2b9c3e4 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -82,7 +82,6 @@ impl Updater { pal: &mut impl pal::OtaPal, config: config::Config, ) -> Result<(), error::OtaError> { - // If the job is in self test mode, don't start an OTA update but // instead do the following: // @@ -227,8 +226,14 @@ impl Updater { while let Ok(mut payload) = subscription.next_block().await { debug!("process_data_handler"); // Decode the file block received - match Self::ingest_data_block(data, pal, &config, &mut file_ctx, payload.deref_mut()) - .await + match Self::ingest_data_block( + data, + pal, + &config, + &mut file_ctx, + payload.deref_mut(), + ) + .await { Ok(true) => { // File is completed! Update progress accordingly. @@ -356,7 +361,9 @@ impl Updater { let block = data.decode_file_block(&file_ctx, payload)?; if block.validate(config.block_size, file_ctx.filesize) { if block.block_id < file_ctx.block_offset as usize - || !file_ctx.bitmap.get(block.block_id - file_ctx.block_offset as usize) + || !file_ctx + .bitmap + .get(block.block_id - file_ctx.block_offset as usize) { info!( "Block {:?} is a DUPLICATE. 
{:?} blocks remaining.", @@ -379,9 +386,9 @@ impl Updater { ) .await?; - let block_offset = file_ctx.block_offset; - file_ctx.bitmap + file_ctx + .bitmap .set(block.block_id - block_offset as usize, false); file_ctx.blocks_remaining -= 1; @@ -394,8 +401,11 @@ impl Updater { } else { if file_ctx.bitmap.is_empty() { file_ctx.block_offset += 31; - file_ctx.bitmap = - encoding::Bitmap::new(file_ctx.filesize, config.block_size, file_ctx.block_offset); + file_ctx.bitmap = encoding::Bitmap::new( + file_ctx.filesize, + config.block_size, + file_ctx.block_offset, + ); } Ok(false) diff --git a/tests/common/file_handler.rs b/tests/common/file_handler.rs index a23b796..4b5c316 100644 --- a/tests/common/file_handler.rs +++ b/tests/common/file_handler.rs @@ -1,9 +1,15 @@ +use core::ops::Deref; use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_sync::mutex::Mutex; -use rustot::ota::{pal::{OtaPal, OtaPalError, PalImageState}, self}; -use sha2::{Sha256, Digest}; -use std::{io::{Cursor, Write, Read}, fs::File}; -use core::ops::Deref; +use rustot::ota::{ + self, + pal::{OtaPal, OtaPalError, PalImageState}, +}; +use sha2::{Digest, Sha256}; +use std::{ + fs::File, + io::{Cursor, Read, Write}, +}; pub struct FileHandler { filebuf: Option>>, @@ -62,7 +68,6 @@ impl OtaPal for FileHandler { file.filepath.as_str() ); - let mut expected_data = std::fs::read(self.compare_file_path.as_str()).unwrap(); let mut expected_hasher = ::new(); expected_hasher.update(&expected_data); From eaa82b833109ddd9ef231c2d08f458a2966fe3b7 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 11 Jan 2024 19:59:11 +0100 Subject: [PATCH 09/36] Fully working OTA integration test --- tests/ota.rs | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/tests/ota.rs b/tests/ota.rs index 3363b0c..9b40edd 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -47,39 +47,6 @@ impl<'a> Jobs<'a> { } } -fn handle_job<'a, M: RawMutex, const SUBS: usize>( - message: &'a Message<'_, M, SUBS>, -) -> Option> { - match jobs::Topic::from_str(message.topic_name()) { - Some(jobs::Topic::NotifyNext) => { - let (execution_changed, _) = - serde_json_core::from_slice::>(&message.payload()) - .ok()?; - let job = execution_changed.execution?; - let ota_job = job.job_document?.ota_job()?; - Some(JobEventData { - job_name: job.job_id, - ota_document: ota_job, - status_details: job.status_details, - }) - } - Some(jobs::Topic::DescribeAccepted(_)) => { - let (execution_changed, _) = serde_json_core::from_slice::< - DescribeJobExecutionResponse, - >(&message.payload()) - .ok()?; - let job = execution_changed.execution?; - let ota_job = job.job_document?.ota_job()?; - Some(JobEventData { - job_name: job.job_id, - ota_document: ota_job, - status_details: job.status_details, - }) - } - _ => None, - } -} - #[tokio::test(flavor = "current_thread")] async fn test_mqtt_ota() { env_logger::init(); @@ -154,7 +121,6 @@ async fn test_mqtt_ota() { let (execution_changed, _) = serde_json_core::from_slice::< DescribeJobExecutionResponse, >(&message.payload()) - .ok() .unwrap(); let job = execution_changed.execution.unwrap(); let ota_job = job.job_document.unwrap().ota_job().unwrap(); From 20f9db983873f28f6184b1598b990a46a5e6bc81 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 11 Jan 2024 21:03:38 +0100 Subject: [PATCH 10/36] Clean up OTA integration test, make sure it reports success and add more assertions --- Cargo.toml | 15 +++-- rust-toolchain.toml | 3 +- src/ota/mod.rs | 6 +- src/ota/pal.rs | 2 - tests/common/file_handler.rs | 21 ++++++- 
tests/ota.rs | 106 ++++++++++++++++++++--------------- 6 files changed, 93 insertions(+), 60 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a2518d2..8ceb109 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["shadow_derive"] [package] name = "rustot" -version = "0.4.1" +version = "0.5.0" authors = ["Mathias Koch "] description = "AWS IoT" readme = "README.md" @@ -22,10 +22,10 @@ name = "rustot" maintenance = { status = "actively-developed" } [dependencies] -bitmaps = { version = "^3.1", default-features = false } +bitmaps = { version = "3.1", default-features = false } heapless = { version = "0.8", features = ["serde"] } -serde = { version = "1.0.126", default-features = false, features = ["derive"] } -serde_cbor = { version = "^0.11", default-features = false, optional = true } +serde = { version = "1.0", default-features = false, features = ["derive"] } +serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" @@ -36,11 +36,11 @@ embassy-time = { version = "0.2" } embassy-sync = "0.5" embassy-futures = "0.1" -log = { version = "^0.4", default-features = false, optional = true } -defmt = { version = "^0.3", optional = true } +log = { version = "0.4", default-features = false, optional = true } +defmt = { version = "0.3", optional = true } [dev-dependencies] -native-tls = { version = "^0.2" } +native-tls = { version = "0.2" } embedded-nal-async = "0.7" env_logger = "0.10" sha2 = "0.10.1" @@ -54,7 +54,6 @@ embedded-io-adapters = { version = "0.6.0", features = ["tokio-1"] } ecdsa = { version = "0.16", features = ["pkcs8", "pem"] } p256 = "0.13" pkcs8 = { version = "0.10", features = ["encryption", "pem"] } -timebomb = "0.1.2" hex = { version = "0.4.3", features = ["alloc"] } [features] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index b79d547..ff76255 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -3,5 +3,6 @@ channel = "nightly-2023-12-24" components = [ "rust-src", "rustfmt", "llvm-tools", "clippy" ] targets = [ "x86_64-unknown-linux-gnu", - "thumbv7em-none-eabihf" + "thumbv7em-none-eabihf", + "thumbv6m-none-eabi" ] diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 2b9c3e4..2e381fc 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -44,7 +44,6 @@ use core::{ #[cfg(feature = "ota_mqtt_data")] pub use data_interface::mqtt::{Encoding, Topic}; -use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; use crate::{ jobs::data_types::JobStatus, @@ -80,7 +79,7 @@ impl Updater { data: &D, mut file_ctx: FileContext, pal: &mut impl pal::OtaPal, - config: config::Config, + config: &config::Config, ) -> Result<(), error::OtaError> { // If the job is in self test mode, don't start an OTA update but // instead do the following: @@ -137,6 +136,7 @@ impl Updater { // reboot the device to allow roll back to previous image. 
error!("Rejecting new image and rebooting: The platform is in the self-test state while the job is not."); pal.reset_device().await?; + return Err(error::OtaError::ResetFailed); } (true, false) => { // The job is in self test but the platform image state is not so it @@ -153,6 +153,7 @@ impl Updater { ) .await?; pal.reset_device().await?; + return Err(error::OtaError::ResetFailed); } } @@ -193,6 +194,7 @@ impl Updater { let request_momentum = AtomicU8::new(0); + // FIXME: // let momentum_fut = async { // while file_ctx.lock().await.blocks_remaining > 0 { // if request_momentum.load(Ordering::Relaxed) <= config.max_request_momentum { diff --git a/src/ota/pal.rs b/src/ota/pal.rs index 063bb0b..06c6d64 100644 --- a/src/ota/pal.rs +++ b/src/ota/pal.rs @@ -1,6 +1,4 @@ //! Platform abstraction trait for OTA updates -use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; - use super::encoding::FileContext; #[derive(Debug, Clone, Copy)] diff --git a/tests/common/file_handler.rs b/tests/common/file_handler.rs index 4b5c316..942a082 100644 --- a/tests/common/file_handler.rs +++ b/tests/common/file_handler.rs @@ -11,9 +11,16 @@ use std::{ io::{Cursor, Read, Write}, }; +#[derive(Debug, PartialEq, Eq)] +pub enum State { + Swap, + Boot, +} + pub struct FileHandler { filebuf: Option>>, compare_file_path: String, + pub plateform_state: State, } impl FileHandler { @@ -21,6 +28,7 @@ impl FileHandler { FileHandler { filebuf: None, compare_file_path, + plateform_state: State::Boot, } } } @@ -42,13 +50,20 @@ impl OtaPal for FileHandler { } async fn get_platform_image_state(&mut self) -> Result { - Ok(PalImageState::Valid) + Ok(match self.plateform_state { + State::Swap => PalImageState::PendingCommit, + State::Boot => PalImageState::Valid, + }) } async fn set_platform_image_state( &mut self, - _image_state: rustot::ota::pal::ImageState, + image_state: rustot::ota::pal::ImageState, ) -> Result<(), OtaPalError> { + if matches!(image_state, rustot::ota::pal::ImageState::Accepted) { + self.plateform_state = State::Boot; + } + Ok(()) } @@ -100,6 +115,8 @@ impl OtaPal for FileHandler { } } + self.plateform_state = State::Swap; + Ok(()) } else { Err(OtaPalError::BadFileHandle) diff --git a/tests/ota.rs b/tests/ota.rs index 9b40edd..60bdb3b 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -6,7 +6,7 @@ mod common; use std::{net::ToSocketAddrs, process}; use common::credentials; -use common::file_handler::FileHandler; +use common::file_handler::{FileHandler, State as FileHandlerState}; use common::network::TlsNetwork; use embassy_futures::select; use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; @@ -47,6 +47,41 @@ impl<'a> Jobs<'a> { } } +fn handle_ota<'a, const SUBS: usize>( + message: Message<'a, NoopRawMutex, SUBS>, + config: &ota::config::Config, +) -> Option { + let job = match jobs::Topic::from_str(message.topic_name()) { + Some(jobs::Topic::NotifyNext) => { + let (execution_changed, _) = + serde_json_core::from_slice::>(&message.payload()) + .ok()?; + execution_changed.execution? + } + Some(jobs::Topic::DescribeAccepted(_)) => { + let (execution_changed, _) = serde_json_core::from_slice::< + DescribeJobExecutionResponse, + >(&message.payload()) + .ok()?; + execution_changed.execution? 
+ } + _ => return None, + }; + + let ota_job = job.job_document?.ota_job()?; + + FileContext::new_from( + JobEventData { + job_name: job.job_id, + ota_document: ota_job, + status_details: job.status_details, + }, + 0, + config, + ) + .ok() +} + #[tokio::test(flavor = "current_thread")] async fn test_mqtt_ota() { env_logger::init(); @@ -68,6 +103,7 @@ async fn test_mqtt_ota() { let (mut stack, client) = embedded_mqtt::new(state, config, network); let client = make_static!(client); + let mut file_handler = FileHandler::new("tests/assets/ota_file".to_owned()); let ota_fut = async { let mut jobs_subscription = client @@ -95,52 +131,24 @@ async fn test_mqtt_ota() { Updater::check_for_job(client).await?; + let config = ota::config::Config::default(); while let Some(message) = jobs_subscription.next().await { - let config = ota::config::Config::default(); - if let Some(mut file_ctx) = match jobs::Topic::from_str(message.topic_name()) { - Some(jobs::Topic::NotifyNext) => { - let (execution_changed, _) = serde_json_core::from_slice::< - NextJobExecutionChanged, - >(&message.payload()) - .ok() - .unwrap(); - let job = execution_changed.execution.unwrap(); - let ota_job = job.job_document.unwrap().ota_job().unwrap(); - FileContext::new_from( - JobEventData { - job_name: job.job_id, - ota_document: ota_job, - status_details: job.status_details, - }, - 0, - &config, + if let Some(mut file_ctx) = handle_ota(message, &config) { + // We have an OTA job, leeeets go! + Updater::perform_ota(client, client, file_ctx.clone(), &mut file_handler, &config) + .await?; + + assert_eq!(file_handler.plateform_state, FileHandlerState::Swap); + + // Run it twice in this particular integration test, in order to simulate image commit after bootloader swap + file_ctx + .status_details + .insert( + heapless::String::try_from("self_test").unwrap(), + heapless::String::try_from("active").unwrap(), ) - .ok() - } - Some(jobs::Topic::DescribeAccepted(_)) => { - let (execution_changed, _) = serde_json_core::from_slice::< - DescribeJobExecutionResponse, - >(&message.payload()) .unwrap(); - let job = execution_changed.execution.unwrap(); - let ota_job = job.job_document.unwrap().ota_job().unwrap(); - FileContext::new_from( - JobEventData { - job_name: job.job_id, - ota_document: ota_job, - status_details: job.status_details, - }, - 0, - &config, - ) - .ok() - } - _ => None, - } { - drop(message); - // We have an OTA job, leeeets go! 
- let mut file_handler = FileHandler::new("tests/assets/ota_file".to_owned()); - Updater::perform_ota(client, client, file_ctx, &mut file_handler, config).await?; + Updater::perform_ota(client, client, file_ctx, &mut file_handler, &config).await?; return Ok(()); } @@ -149,10 +157,18 @@ async fn test_mqtt_ota() { Ok::<_, ota::error::OtaError>(()) }; - match select::select(stack.run(), ota_fut).await { + match embassy_time::with_timeout( + embassy_time::Duration::from_secs(25), + select::select(stack.run(), ota_fut), + ) + .await + .unwrap() + { select::Either::First(_) => { unreachable!() } select::Either::Second(result) => result.unwrap(), }; + + assert_eq!(file_handler.plateform_state, FileHandlerState::Boot); } From f3742a46d7b2df22715816128382174a4db76b46 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 18 Jan 2024 12:53:30 +0100 Subject: [PATCH 11/36] Use stable 1.75 --- Cargo.toml | 4 ++-- rust-toolchain.toml | 2 +- src/lib.rs | 2 -- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8ceb109..9149e4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ embedded-storage-async = "0.4" embedded-mqtt = { path = "../embedded-mqtt" } futures = { version = "0.3.28", default-features = false } -embassy-time = { version = "0.2" } +embassy-time = { version = "0.3" } embassy-sync = "0.5" embassy-futures = "0.1" @@ -48,7 +48,7 @@ static_cell = { version = "2", features = ["nightly"]} tokio = { version = "1.33", default-features = false, features = ["macros", "rt", "net", "time", "io-std"] } tokio-native-tls = { version = "0.3.1" } embassy-futures = { version = "0.1.0" } -embassy-time = { version = "0.2", features = ["log", "std", "generic-queue"] } +embassy-time = { version = "0.3", features = ["log", "std", "generic-queue"] } embedded-io-adapters = { version = "0.6.0", features = ["tokio-1"] } ecdsa = { version = "0.16", features = ["pkcs8", "pem"] } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index ff76255..c2f08bb 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2023-12-24" +channel = "1.75" components = [ "rust-src", "rustfmt", "llvm-tools", "clippy" ] targets = [ "x86_64-unknown-linux-gnu", diff --git a/src/lib.rs b/src/lib.rs index eb9c203..b121586 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,5 @@ #![cfg_attr(not(any(test, feature = "std")), no_std)] -#![allow(incomplete_features)] #![allow(async_fn_in_trait)] -#![feature(generic_const_exprs)] // This mod MUST go first, so that the others see its macros. 
pub(crate) mod fmt; From 505bb2cb8ac3cb02a63e835eaffbaf9dac364373 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 18 Jan 2024 12:55:43 +0100 Subject: [PATCH 12/36] use git embedded-mqtt --- Cargo.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9149e4f..a017314 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,9 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { path = "../embedded-mqtt" } +embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "1396855", features = ["defmt"] } +# embedded-mqtt = { path = "../embedded-mqtt" } + futures = { version = "0.3.28", default-features = false } embassy-time = { version = "0.3" } From f225c96150614ca8a3b08a8e4480888c9abc2d9a Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 18 Jan 2024 13:00:59 +0100 Subject: [PATCH 13/36] Fix defmt feature --- Cargo.toml | 2 +- src/ota/data_interface/mod.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a017314..17c16f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "1396855", features = ["defmt"] } +embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "6a67789", features = ["defmt"] } # embedded-mqtt = { path = "../embedded-mqtt" } futures = { version = "0.3.28", default-features = false } diff --git a/src/ota/data_interface/mod.rs b/src/ota/data_interface/mod.rs index 3c65af7..3723c69 100644 --- a/src/ota/data_interface/mod.rs +++ b/src/ota/data_interface/mod.rs @@ -12,6 +12,7 @@ use crate::ota::config::Config; use super::{encoding::FileContext, error::OtaError}; #[derive(Debug, Clone, PartialEq, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Protocol { #[serde(rename = "MQTT")] Mqtt, From 7edb08cb601bc0414dfcebed54ebd52ccccfd300 Mon Sep 17 00:00:00 2001 From: Mathias Date: Fri, 15 Mar 2024 15:13:57 +0100 Subject: [PATCH 14/36] Add support for provisioning by CSR to the Fleet Provisioner --- src/provisioning/mod.rs | 175 +++++++++++++++++++++++++++++++++------- 1 file changed, 147 insertions(+), 28 deletions(-) diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 507fc02..a4faf53 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -14,6 +14,7 @@ use serde::{de::DeserializeOwned, Deserialize}; pub use error::Error; +use self::data_types::CreateCertificateFromCsrRequest; use self::{ data_types::{ CreateKeysAndCertificateResponse, ErrorResponse, RegisterThingRequest, @@ -52,6 +53,28 @@ impl FleetProvisioner { mqtt, template_name, parameters, + None, + credential_handler, + PayloadFormat::Json, + ) + .await + } + + pub async fn provision_csr<'a, C, M: RawMutex, const SUBS: usize>( + mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + template_name: &str, + parameters: Option, + csr: &str, + credential_handler: &mut impl CredentialHandler, + ) -> Result, Error> + where + C: DeserializeOwned, + { + Self::provision_inner( + mqtt, + template_name, + parameters, + Some(csr), credential_handler, PayloadFormat::Json, ) @@ 
-72,6 +95,29 @@ impl FleetProvisioner { mqtt, template_name, parameters, + None, + credential_handler, + PayloadFormat::Cbor, + ) + .await + } + + #[cfg(feature = "provision_cbor")] + pub async fn provision_csr_cbor<'a, C, M: RawMutex, const SUBS: usize>( + mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + template_name: &str, + parameters: Option, + csr: &str, + credential_handler: &mut impl CredentialHandler, + ) -> Result, Error> + where + C: DeserializeOwned, + { + Self::provision_inner( + mqtt, + template_name, + parameters, + Some(csr), credential_handler, PayloadFormat::Cbor, ) @@ -83,13 +129,16 @@ impl FleetProvisioner { mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, + csr: Option<&str>, credential_handler: &mut impl CredentialHandler, payload_format: PayloadFormat, ) -> Result, Error> where C: DeserializeOwned, { - let mut create_subscription = Self::begin(mqtt, payload_format).await?; + use crate::provisioning::data_types::CreateCertificateFromCsrResponse; + + let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; let mut message = create_subscription .next() @@ -114,13 +163,33 @@ impl FleetProvisioner { response.certificate_ownership_token } + Some(Topic::CreateCertificateFromCsrAccepted(format)) => { + let response = Self::deserialize::( + format, + &mut message, + )?; + + credential_handler + .store_credentials(Credentials { + certificate_id: response.certificate_id, + certificate_pem: response.certificate_pem, + private_key: None, + }) + .await?; + + response.certificate_ownership_token + } + // Error happened! - Some(Topic::CreateKeysAndCertificateRejected(format)) => { + Some( + Topic::CreateKeysAndCertificateRejected(format) + | Topic::CreateCertificateFromCsrRejected(format), + ) => { return Err(Self::handle_error(format, message).unwrap_err()); } t => { - trace!("{:?}", t); + warn!("Got unexpected packet on topic {:?}", t); return Err(Error::InvalidState); } @@ -207,36 +276,86 @@ impl FleetProvisioner { async fn begin<'a, M: RawMutex, const SUBS: usize>( mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + csr: Option<&str>, payload_format: PayloadFormat, - ) -> Result, Error> { - let subscription = mqtt - .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::CreateKeysAndCertificateAny(payload_format) - .format::<31>()? + ) -> Result, Error> { + if let Some(csr) = csr { + let request = CreateCertificateFromCsrRequest { + certificate_signing_request: csr, + }; + + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let payload = &mut [0u8; 1024]; + + let payload_len = match payload_format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => { + let mut serializer = + serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload)); + request.serialize(&mut serializer)?; + serializer.into_inner().bytes_written() + } + PayloadFormat::Json => serde_json_core::to_slice(&request, payload)?, + }; + + let subscription = mqtt + .subscribe::<1>(Subscribe::new(&[SubscribeTopic { + topic_path: Topic::CreateCertificateFromCsrAny(payload_format) + .format::<40>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }])) + .await + .map_err(|_| Error::Mqtt)?; + + mqtt.publish(Publish { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + pid: None, + topic_name: Topic::CreateCertificateFromCsr(payload_format) + .format::<38>()? 
.as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) + payload: &payload[..payload_len], + properties: embedded_mqtt::Properties::Slice(&[]), + }) .await .map_err(|_| Error::Mqtt)?; - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::CreateKeysAndCertificate(payload_format) - .format::<29>()? - .as_str(), - payload: b"", - properties: embedded_mqtt::Properties::Slice(&[]), - }) - .await - .map_err(|_| Error::Mqtt)?; + Ok(subscription) + } else { + let subscription = mqtt + .subscribe::<1>(Subscribe::new(&[SubscribeTopic { + topic_path: Topic::CreateKeysAndCertificateAny(payload_format) + .format::<31>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }])) + .await + .map_err(|_| Error::Mqtt)?; + + mqtt.publish(Publish { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + pid: None, + topic_name: Topic::CreateKeysAndCertificate(payload_format) + .format::<29>()? + .as_str(), + payload: b"", + properties: embedded_mqtt::Properties::Slice(&[]), + }) + .await + .map_err(|_| Error::Mqtt)?; - Ok(subscription) + Ok(subscription) + } } fn deserialize<'a, R: Deserialize<'a>, M: RawMutex, const SUBS: usize>( @@ -244,7 +363,7 @@ impl FleetProvisioner { message: &'a mut Message<'_, M, SUBS>, ) -> Result { trace!( - "Topic::CreateKeysAndCertificateAccepted {:?}. Payload len: {:?}", + "Accepted Topic {:?}. Payload len: {:?}", payload_format, message.payload().len() ); From e225b1f795cd063bddeb1ffb9ed62b998a6a166e Mon Sep 17 00:00:00 2001 From: Mathias Date: Fri, 31 May 2024 14:50:06 +0200 Subject: [PATCH 15/36] Update embedded-mqtt dependency --- Cargo.toml | 18 ++++++++++----- rust-toolchain.toml | 2 +- src/ota/config.rs | 10 ++++----- src/ota/data_interface/mqtt.rs | 9 ++++++-- src/ota/mod.rs | 2 ++ src/provisioning/error.rs | 7 +----- src/provisioning/mod.rs | 40 +++++++++++++++++++--------------- tests/ota.rs | 30 +++++++++++++++++-------- tests/provisioning.rs | 23 +++++++++++-------- 9 files changed, 86 insertions(+), 55 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 17c16f3..6d3d5f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,9 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "6a67789", features = ["defmt"] } +embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "f8d4ee8", features = [ + "defmt", +] } # embedded-mqtt = { path = "../embedded-mqtt" } futures = { version = "0.3.28", default-features = false } @@ -44,10 +46,16 @@ defmt = { version = "0.3", optional = true } [dev-dependencies] native-tls = { version = "0.2" } embedded-nal-async = "0.7" -env_logger = "0.10" +env_logger = "0.11" sha2 = "0.10.1" -static_cell = { version = "2", features = ["nightly"]} -tokio = { version = "1.33", default-features = false, features = ["macros", "rt", "net", "time", "io-std"] } +static_cell = { version = "2", features = ["nightly"] } +tokio = { version = "1.33", default-features = false, features = [ + "macros", + "rt", + "net", + "time", + "io-std", +] } tokio-native-tls = { version = "0.3.1" } embassy-futures = { version = "0.1.0" } embassy-time = 
{ version = "0.3", features = ["log", "std", "generic-queue"] } @@ -69,4 +77,4 @@ ota_http_data = [] std = ["serde/std", "serde_cbor?/std"] defmt = ["dep:defmt", "heapless/defmt-03", "embedded-mqtt/defmt"] -log = ["dep:log", "embedded-mqtt/log", ] +log = ["dep:log", "embedded-mqtt/log"] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index c2f08bb..17dc494 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.75" +channel = "1.78" components = [ "rust-src", "rustfmt", "llvm-tools", "clippy" ] targets = [ "x86_64-unknown-linux-gnu", diff --git a/src/ota/config.rs b/src/ota/config.rs index 0d69854..39fe685 100644 --- a/src/ota/config.rs +++ b/src/ota/config.rs @@ -1,11 +1,11 @@ use embassy_time::Duration; pub struct Config { - pub(crate) block_size: usize, - pub(crate) max_request_momentum: u8, - pub(crate) request_wait: Duration, - pub(crate) status_update_frequency: u32, - pub(crate) self_test_timeout: Option, + pub block_size: usize, + pub max_request_momentum: u8, + pub request_wait: Duration, + pub status_update_frequency: u32, + pub self_test_timeout: Option, } impl Default for Config { diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 077cb59..78caed4 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -144,7 +144,7 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB let topic = SubscribeTopic { topic_path: topic_path.as_str(), - maximum_qos: embedded_mqtt::QoS::AtLeastOnce, + maximum_qos: embedded_mqtt::QoS::AtMostOnce, no_local: false, retain_as_published: false, retain_handling: RetainHandling::SendAtSubscribeTime, @@ -180,9 +180,14 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB ) .map_err(|_| OtaError::Encoding)?; + debug!( + "Requesting more file blocks. 
Remaining: {}", + file_ctx.request_block_remaining + ); + self.publish(Publish { dup: false, - qos: embedded_mqtt::QoS::AtMostOnce, + qos: embedded_mqtt::QoS::AtLeastOnce, retain: false, pid: None, topic_name: OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 2e381fc..04fe2d3 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -225,6 +225,8 @@ impl Updater { let data_fut = async { data.request_file_block(&mut file_ctx, &config).await?; + info!("Awaiting file blocks!"); + while let Ok(mut payload) = subscription.next_block().await { debug!("process_data_handler"); // Decode the file block received diff --git a/src/provisioning/error.rs b/src/provisioning/error.rs index d0a128c..deb5725 100644 --- a/src/provisioning/error.rs +++ b/src/provisioning/error.rs @@ -6,15 +6,10 @@ pub enum Error { Mqtt, DeserializeJson(serde_json_core::de::Error), DeserializeCbor, + CertificateStorage, Response(u16), } -// impl From for Error { -// fn from(e: MqttError) -> Self { -// Self::Mqtt(e) -// } -// } - impl From for Error { fn from(_: serde_json_core::ser::Error) -> Self { Self::Overflow diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index a4faf53..2ed8349 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -27,7 +27,7 @@ pub trait CredentialHandler { fn store_credentials( &mut self, credentials: Credentials<'_>, - ) -> impl Future> + Send; + ) -> impl Future>; } #[derive(Debug)] @@ -104,7 +104,7 @@ impl FleetProvisioner { #[cfg(feature = "provision_cbor")] pub async fn provision_csr_cbor<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, csr: &str, @@ -126,7 +126,7 @@ impl FleetProvisioner { #[cfg(feature = "provision_cbor")] async fn provision_inner<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, csr: Option<&str>, @@ -147,7 +147,7 @@ impl FleetProvisioner { let ownership_token = match Topic::from_str(message.topic_name()) { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { - let response = Self::deserialize::( + let response = Self::deserialize::( format, &mut message, )?; @@ -164,7 +164,7 @@ impl FleetProvisioner { } Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - let response = Self::deserialize::( + let response = Self::deserialize::( format, &mut message, )?; @@ -230,7 +230,10 @@ impl FleetProvisioner { retain_handling: RetainHandling::SendAtSubscribeTime, }])) .await - .map_err(|_| Error::Mqtt)?; + .map_err(|e| { + error!("Failed subscription to RegisterThingAny! {}", e); + Error::Mqtt + })?; mqtt.publish(Publish { dup: false, @@ -244,7 +247,10 @@ impl FleetProvisioner { properties: embedded_mqtt::Properties::Slice(&[]), }) .await - .map_err(|_| Error::Mqtt)?; + .map_err(|e| { + error!("Failed publish to RegisterThing! 
{}", e); + Error::Mqtt + })?; let mut message = register_subscription .next() @@ -253,10 +259,8 @@ impl FleetProvisioner { match Topic::from_str(message.topic_name()) { Some(Topic::RegisterThingAccepted(_, format)) => { - let response = Self::deserialize::, M, SUBS>( - format, - &mut message, - )?; + let response = + Self::deserialize::, SUBS>(format, &mut message)?; Ok(response.device_configuration) } @@ -274,11 +278,11 @@ impl FleetProvisioner { } } - async fn begin<'a, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + async fn begin<'a, 'b, M: RawMutex, const SUBS: usize>( + mqtt: &'b embedded_mqtt::MqttClient<'a, M, SUBS>, csr: Option<&str>, payload_format: PayloadFormat, - ) -> Result, Error> { + ) -> Result, Error> { if let Some(csr) = csr { let request = CreateCertificateFromCsrRequest { certificate_signing_request: csr, @@ -358,9 +362,9 @@ impl FleetProvisioner { } } - fn deserialize<'a, R: Deserialize<'a>, M: RawMutex, const SUBS: usize>( + fn deserialize<'a, R: Deserialize<'a>, const SUBS: usize>( payload_format: PayloadFormat, - message: &'a mut Message<'_, M, SUBS>, + message: &'a mut Message<'_, SUBS>, ) -> Result { trace!( "Accepted Topic {:?}. Payload len: {:?}", @@ -375,9 +379,9 @@ impl FleetProvisioner { }) } - fn handle_error( + fn handle_error( format: PayloadFormat, - mut message: Message<'_, M, SUBS>, + mut message: Message<'_, SUBS>, ) -> Result<(), Error> { error!(">> {:?}", message.topic_name()); diff --git a/tests/ota.rs b/tests/ota.rs index 60bdb3b..ac82f18 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -11,6 +11,7 @@ use common::network::TlsNetwork; use embassy_futures::select; use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; use embassy_time::Duration; +use embedded_mqtt::transport::embedded_nal::NalTransport; use embedded_mqtt::{ Config, DomainBroker, IpBroker, Message, Publish, QoS, RetainHandling, State, Subscribe, SubscribeTopic, @@ -91,7 +92,9 @@ async fn test_mqtt_ota() { let (thing_name, identity) = credentials::identity(); let hostname = credentials::HOSTNAME.unwrap(); - let network = make_static!(TlsNetwork::new(hostname.to_owned(), identity)); + + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); // Create the MQTT stack let broker = @@ -99,10 +102,10 @@ async fn test_mqtt_ota() { let config = Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); - let state = make_static!(State::::new()); - let (mut stack, client) = embedded_mqtt::new(state, config, network); + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); - let client = make_static!(client); let mut file_handler = FileHandler::new("tests/assets/ota_file".to_owned()); let ota_fut = async { @@ -129,14 +132,20 @@ async fn test_mqtt_ota() { ])) .await?; - Updater::check_for_job(client).await?; + Updater::check_for_job(&client).await?; let config = ota::config::Config::default(); while let Some(message) = jobs_subscription.next().await { if let Some(mut file_ctx) = handle_ota(message, &config) { // We have an OTA job, leeeets go! 
- Updater::perform_ota(client, client, file_ctx.clone(), &mut file_handler, &config) - .await?; + Updater::perform_ota( + &client, + &client, + file_ctx.clone(), + &mut file_handler, + &config, + ) + .await?; assert_eq!(file_handler.plateform_state, FileHandlerState::Swap); @@ -148,7 +157,8 @@ async fn test_mqtt_ota() { heapless::String::try_from("active").unwrap(), ) .unwrap(); - Updater::perform_ota(client, client, file_ctx, &mut file_handler, &config).await?; + Updater::perform_ota(&client, &client, file_ctx, &mut file_handler, &config) + .await?; return Ok(()); } @@ -157,9 +167,11 @@ async fn test_mqtt_ota() { Ok::<_, ota::error::OtaError>(()) }; + let mut transport = NalTransport::new(network); + match embassy_time::with_timeout( embassy_time::Duration::from_secs(25), - select::select(stack.run(), ota_fut), + select::select(stack.run(&mut transport), ota_fut), ) .await .unwrap() diff --git a/tests/provisioning.rs b/tests/provisioning.rs index bdc6d86..804d2c5 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -10,7 +10,10 @@ use common::network::TlsNetwork; use ecdsa::Signature; use embassy_futures::select; use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embedded_mqtt::{Config, DomainBroker, IpBroker, Publish, State, Subscribe, SubscribeTopic}; +use embedded_mqtt::{ + transport::embedded_nal::NalTransport, Config, DomainBroker, IpBroker, Publish, State, + Subscribe, SubscribeTopic, +}; use p256::{ecdsa::signature::Signer, NistP256}; use rustot::provisioning::{ topics::Topic, CredentialHandler, Credentials, Error, FleetProvisioner, @@ -73,7 +76,8 @@ async fn test_provisioning() { let template_name = std::env::var("TEMPLATE_NAME").unwrap_or_else(|_| "duoProvisioningTemplate".to_string()); - let network = make_static!(TlsNetwork::new(hostname.to_owned(), claim_identity)); + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), claim_identity)); // Create the MQTT stack let broker = @@ -81,10 +85,9 @@ async fn test_provisioning() { let config = Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); - let state = make_static!(State::::new()); - let (mut stack, client) = embedded_mqtt::new(state, config, network); - - let client = make_static!(client); + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); let signing_key = credentials::signing_key(); let signature: Signature = signing_key.sign(thing_name.as_bytes()); @@ -99,22 +102,24 @@ async fn test_provisioning() { #[cfg(not(feature = "provision_cbor"))] let provision_fut = FleetProvisioner::provision::( - client, + &client, &template_name, Some(parameters), &mut credential_handler, ); #[cfg(feature = "provision_cbor")] let provision_fut = FleetProvisioner::provision_cbor::( - client, + &client, &template_name, Some(parameters), &mut credential_handler, ); + let mut transport = NalTransport::new(network); + let device_config = match embassy_time::with_timeout( embassy_time::Duration::from_secs(15), - select::select(stack.run(), provision_fut), + select::select(stack.run(&mut transport), provision_fut), ) .await .unwrap() From 20ece31feb2804310f89d05831fcb248dbf0a07c Mon Sep 17 00:00:00 2001 From: Mathias Date: Tue, 11 Jun 2024 10:13:12 +0200 Subject: [PATCH 16/36] Bump embedded-mqtt, and use DeferredPayload --- Cargo.toml | 5 +- src/ota/control_interface/mqtt.rs | 7 +- src/ota/data_interface/mqtt.rs | 38 +++++---- 
src/ota/error.rs | 4 + src/ota/mod.rs | 137 ++++++++++++++++-------------- src/ota/pal.rs | 25 ++---- src/provisioning/mod.rs | 77 ++++++++++------- 7 files changed, 156 insertions(+), 137 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6d3d5f0..7619b16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,15 +29,14 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "f8d4ee8", features = [ +embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "dbf8af0", features = [ "defmt", ] } -# embedded-mqtt = { path = "../embedded-mqtt" } futures = { version = "0.3.28", default-features = false } embassy-time = { version = "0.3" } -embassy-sync = "0.5" +embassy-sync = "0.6" embassy-futures = "0.1" log = { version = "0.4", default-features = false, optional = true } diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 495064a..4259db2 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -84,8 +84,11 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface } // Downgrade progress updates to QOS 0 to avoid overloading MQTT - // buffers during active streaming - if status == JobStatus::InProgress { + // buffers during active streaming. But make sure to always send and await ack for first update and last update + if status == JobStatus::InProgress + && file_ctx.blocks_remaining != 0 + && received_blocks != 0 + { qos = QoS::AtMostOnce; } } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 78caed4..1f32995 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -4,7 +4,8 @@ use core::str::FromStr; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - MqttClient, Properties, Publish, RetainHandling, Subscribe, SubscribeTopic, Subscription, + DeferredPayload, EncodingError, MqttClient, Properties, Publish, RetainHandling, Subscribe, + SubscribeTopic, Subscription, }; use futures::StreamExt; @@ -163,22 +164,25 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB ) -> Result<(), OtaError> { file_ctx.request_block_remaining = file_ctx.bitmap.len() as u32; - // FIXME: Serialize directly into the publish payload through `DeferredPublish` API - let buf = &mut [0u8; 32]; - let len = cbor::to_slice( - &cbor::GetStreamRequest { - // Arbitrary client token sent in the stream "GET" message - client_token: None, - stream_version: None, - file_id: file_ctx.fileid, - block_size: config.block_size, - block_offset: Some(file_ctx.block_offset), - block_bitmap: Some(&file_ctx.bitmap), - number_of_blocks: None, + let payload = DeferredPayload::new( + |buf| { + cbor::to_slice( + &cbor::GetStreamRequest { + // Arbitrary client token sent in the stream "GET" message + client_token: None, + stream_version: None, + file_id: file_ctx.fileid, + block_size: config.block_size, + block_offset: Some(file_ctx.block_offset), + block_bitmap: Some(&file_ctx.bitmap), + number_of_blocks: None, + }, + buf, + ) + .map_err(|e| EncodingError::BufferSize) }, - buf, - ) - .map_err(|_| OtaError::Encoding)?; + 32, + ); debug!( "Requesting more file blocks. 
Remaining: {}", @@ -193,7 +197,7 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB topic_name: OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>(self.client_id())? .as_str(), - payload: &buf[..len], + payload, properties: Properties::Slice(&[]), }) .await?; diff --git a/src/ota/error.rs b/src/ota/error.rs index 3dd7101..8fdf6e6 100644 --- a/src/ota/error.rs +++ b/src/ota/error.rs @@ -15,6 +15,10 @@ pub enum OtaError { ZeroFileSize, Overflow, InvalidFile, + Write( + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] + embedded_storage_async::nor_flash::NorFlashErrorKind, + ), Mqtt(embedded_mqtt::Error), Encoding, Pal, diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 04fe2d3..c962b54 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -44,6 +44,7 @@ use core::{ #[cfg(feature = "ota_mqtt_data")] pub use data_interface::mqtt::{Encoding, Topic}; +use embedded_storage_async::nor_flash::{NorFlash, NorFlashError as _}; use crate::{ jobs::data_types::JobStatus, @@ -144,6 +145,9 @@ impl Updater { // (this should also cause the image to be erased), aborting the job // and reset the device. error!("Rejecting new image and rebooting: the job is in the self-test state while the platform is not."); + // loop { + // embassy_time::Timer::after_secs(1).await; + // } Self::set_image_state_with_reason( control, pal, @@ -172,26 +176,6 @@ impl Updater { info!("Job document was accepted. Attempting to begin the update"); - // Create/Open the OTA file on the file system - if let Err(e) = pal.create_file_for_rx(&file_ctx).await { - Self::set_image_state_with_reason( - control, - pal, - &config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::Pal(e)), - ) - .await?; - - pal.close_file(&file_ctx).await?; - return Err(e.into()); - } - - // Prepare the storage layer on receiving a new file - let mut subscription = data.init_file_transfer(&mut file_ctx).await?; - - info!("Initialized file handler! Requesting file blocks"); - let request_momentum = AtomicU8::new(0); // FIXME: @@ -223,6 +207,29 @@ impl Updater { // }; let data_fut = async { + // Create/Open the OTA file on the file system + let block_writer = match pal.create_file_for_rx(&file_ctx).await { + Ok(block_writer) => block_writer, + Err(e) => { + Self::set_image_state_with_reason( + control, + pal, + &config, + &mut file_ctx, + ImageState::Aborted(ImageStateReason::Pal(e)), + ) + .await?; + + pal.close_file(&file_ctx).await?; + return Err(e.into()); + } + }; + + info!("Initialized file handler! Requesting file blocks"); + + // Prepare the storage layer on receiving a new file + let mut subscription = data.init_file_transfer(&mut file_ctx).await?; + data.request_file_block(&mut file_ctx, &config).await?; info!("Awaiting file blocks!"); @@ -232,51 +239,48 @@ impl Updater { // Decode the file block received match Self::ingest_data_block( data, - pal, + block_writer, &config, &mut file_ctx, payload.deref_mut(), ) .await { - Ok(true) => { - // File is completed! Update progress accordingly. 
- match pal.close_file(&file_ctx).await { - Err(e) => { - control - .update_job_status( - &mut file_ctx, - &config, - JobStatus::Failed, - JobStatusReason::Pal(0), - ) - .await?; - - return Err(e.into()); - } - Ok(_) => { - let (status, reason, event) = if let Some(0) = file_ctx.file_type { - ( - JobStatus::InProgress, - JobStatusReason::SigCheckPassed, - pal::OtaEvent::Activate, - ) - } else { - ( - JobStatus::Succeeded, - JobStatusReason::Accepted, - pal::OtaEvent::UpdateComplete, - ) - }; - - control - .update_job_status(&mut file_ctx, &config, status, reason) - .await?; - - return Ok(event); - } + Ok(true) => match pal.close_file(&file_ctx).await { + Err(e) => { + control + .update_job_status( + &mut file_ctx, + &config, + JobStatus::Failed, + JobStatusReason::Pal(0), + ) + .await?; + + return Err(e.into()); } - } + Ok(_) => { + let (status, reason, event) = if let Some(0) = file_ctx.file_type { + ( + JobStatus::InProgress, + JobStatusReason::SigCheckPassed, + pal::OtaEvent::Activate, + ) + } else { + ( + JobStatus::Succeeded, + JobStatusReason::Accepted, + pal::OtaEvent::UpdateComplete, + ) + }; + + control + .update_job_status(&mut file_ctx, &config, status, reason) + .await?; + + return Ok(event); + } + }, Ok(false) => { debug!("Ingested one block!"); // Reset the momentum counter since we received a good block @@ -355,9 +359,9 @@ impl Updater { Ok(()) } - async fn ingest_data_block<'a, D: DataInterface, PAL: pal::OtaPal>( + async fn ingest_data_block<'a, D: DataInterface>( data: &D, - pal: &mut PAL, + block_writer: &mut impl NorFlash, config: &config::Config, file_ctx: &mut FileContext, payload: &mut [u8], @@ -383,12 +387,13 @@ impl Updater { block.block_id, file_ctx.blocks_remaining ); - pal.write_block( - file_ctx, - block.block_id * config.block_size, - block.block_payload, - ) - .await?; + block_writer + .write( + (block.block_id * config.block_size) as u32, + block.block_payload, + ) + .await + .map_err(|e| error::OtaError::Write(e.kind()))?; let block_offset = file_ctx.block_offset; file_ctx diff --git a/src/ota/pal.rs b/src/ota/pal.rs index 06c6d64..7764e50 100644 --- a/src/ota/pal.rs +++ b/src/ota/pal.rs @@ -1,4 +1,6 @@ //! Platform abstraction trait for OTA updates +use embedded_storage_async::nor_flash::NorFlash; + use super::encoding::FileContext; #[derive(Debug, Clone, Copy)] @@ -68,6 +70,8 @@ pub enum OtaEvent { /// Platform abstraction layer for OTA jobs pub trait OtaPal { + type BlockWriter: NorFlash; + /// OTA abort. /// /// The user may register a callback function when initializing the OTA @@ -98,7 +102,10 @@ pub trait OtaPal { /// is created. /// /// - `file`: [`FileContext`] File description of the job being aborted - async fn create_file_for_rx(&mut self, file: &FileContext) -> Result<(), OtaPalError>; + async fn create_file_for_rx( + &mut self, + file: &FileContext, + ) -> Result<&mut Self::BlockWriter, OtaPalError>; /// Get the state of the OTA update image. /// @@ -155,22 +162,6 @@ pub trait OtaPal { /// error code. async fn close_file(&mut self, file: &FileContext) -> Result<(), OtaPalError>; - /// Write a block of data to the specified file at the given offset. - /// - /// - `file`: [`FileContext`] File description of the job being aborted. - /// - `block_offset`: Byte offset to write to from the beginning of the - /// file. - /// - `block_payload`: Byte array of data to write. - /// - /// **return** The number of bytes written on a success, or a negative error - /// code from the platform abstraction layer. 
- async fn write_block( - &mut self, - file: &mut FileContext, - block_offset: usize, - block_payload: &[u8], - ) -> Result; - /// OTA update complete. /// /// The user may register a callback function when initializing the OTA diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 2ed8349..5a209a0 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -6,7 +6,8 @@ use core::future::Future; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - Message, Publish, QoS, RetainHandling, Subscribe, SubscribeTopic, Subscription, + DeferredPayload, EncodingError, Message, Publish, QoS, RetainHandling, Subscribe, + SubscribeTopic, Subscription, }; use futures::StreamExt; use serde::Serialize; @@ -200,24 +201,26 @@ impl FleetProvisioner { parameters, }; - // FIXME: Serialize directly into the publish payload through `DeferredPublish` API - let payload = &mut [0u8; 1024]; - - let payload_len = match payload_format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - let mut serializer = - serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload)); - register_request.serialize(&mut serializer)?; - serializer.into_inner().bytes_written() - } - PayloadFormat::Json => serde_json_core::to_slice(®ister_request, payload)?, - }; - - drop(message); - drop(create_subscription); + let payload = DeferredPayload::new( + |buf| { + Ok(match payload_format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => { + let mut serializer = + serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(buf)); + register_request + .serialize(&mut serializer) + .map_err(|_| EncodingError::BufferSize)?; + serializer.into_inner().bytes_written() + } + PayloadFormat::Json => serde_json_core::to_slice(®ister_request, buf) + .map_err(|_| EncodingError::BufferSize)?, + }) + }, + 1024, + ); - debug!("Starting RegisterThing {:?}", payload_len); + debug!("Starting RegisterThing"); let mut register_subscription = mqtt .subscribe::<1>(Subscribe::new(&[SubscribeTopic { @@ -243,7 +246,7 @@ impl FleetProvisioner { topic_name: Topic::RegisterThing(template_name, payload_format) .format::<69>()? 
.as_str(), - payload: &payload[..payload_len], + payload, properties: embedded_mqtt::Properties::Slice(&[]), }) .await @@ -252,6 +255,9 @@ impl FleetProvisioner { Error::Mqtt })?; + drop(message); + drop(create_subscription); + let mut message = register_subscription .next() .await @@ -289,18 +295,25 @@ impl FleetProvisioner { }; // FIXME: Serialize directly into the publish payload through `DeferredPublish` API - let payload = &mut [0u8; 1024]; - - let payload_len = match payload_format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - let mut serializer = - serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload)); - request.serialize(&mut serializer)?; - serializer.into_inner().bytes_written() - } - PayloadFormat::Json => serde_json_core::to_slice(&request, payload)?, - }; + let payload = DeferredPayload::new( + |buf| { + Ok(match payload_format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => { + let mut serializer = serde_cbor::ser::Serializer::new( + serde_cbor::ser::SliceWrite::new(buf), + ); + request + .serialize(&mut serializer) + .map_err(|_| EncodingError::BufferSize)?; + serializer.into_inner().bytes_written() + } + PayloadFormat::Json => serde_json_core::to_slice(&request, buf) + .map_err(|_| EncodingError::BufferSize)?, + }) + }, + 1024, + ); let subscription = mqtt .subscribe::<1>(Subscribe::new(&[SubscribeTopic { @@ -323,7 +336,7 @@ impl FleetProvisioner { topic_name: Topic::CreateCertificateFromCsr(payload_format) .format::<38>()? .as_str(), - payload: &payload[..payload_len], + payload, properties: embedded_mqtt::Properties::Slice(&[]), }) .await From 34dc52f8cbf88c872194362e10da61bef0d088f0 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 20 Jun 2024 11:00:21 +0200 Subject: [PATCH 17/36] Temporarilly subscribe to only accepted topic --- src/provisioning/mod.rs | 33 ++++++++++++++++----------------- src/provisioning/topics.rs | 4 ++-- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 5a209a0..b650a1c 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -62,7 +62,7 @@ impl FleetProvisioner { } pub async fn provision_csr<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, csr: &str, @@ -224,8 +224,8 @@ impl FleetProvisioner { let mut register_subscription = mqtt .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::RegisterThingAny(template_name, payload_format) - .format::<128>()? + topic_path: Topic::RegisterThingAccepted(template_name, payload_format) + .format::<150>()? .as_str(), maximum_qos: QoS::AtLeastOnce, no_local: false, @@ -290,11 +290,23 @@ impl FleetProvisioner { payload_format: PayloadFormat, ) -> Result, Error> { if let Some(csr) = csr { + let subscription = mqtt + .subscribe::<1>(Subscribe::new(&[SubscribeTopic { + topic_path: Topic::CreateCertificateFromCsrAccepted(payload_format) + .format::<47>()? 
+ .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }])) + .await + .map_err(|_| Error::Mqtt)?; + let request = CreateCertificateFromCsrRequest { certificate_signing_request: csr, }; - // FIXME: Serialize directly into the publish payload through `DeferredPublish` API let payload = DeferredPayload::new( |buf| { Ok(match payload_format { @@ -315,19 +327,6 @@ impl FleetProvisioner { 1024, ); - let subscription = mqtt - .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::CreateCertificateFromCsrAny(payload_format) - .format::<40>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) - .await - .map_err(|_| Error::Mqtt)?; - mqtt.publish(Publish { dup: false, qos: QoS::AtLeastOnce, diff --git a/src/provisioning/topics.rs b/src/provisioning/topics.rs index bd3e7af..9256c55 100644 --- a/src/provisioning/topics.rs +++ b/src/provisioning/topics.rs @@ -179,7 +179,7 @@ impl<'a> Topic<'a> { } Topic::RegisterThingAny(template_name, payload_format) => { topic_path.write_fmt(format_args!( - "{}/{}/provision/{}/+", + "{}/{}/provision/{}/#", Self::PROVISIONING_PREFIX, template_name, payload_format, @@ -209,7 +209,7 @@ impl<'a> Topic<'a> { )), Topic::CreateKeysAndCertificateAny(payload_format) => topic_path.write_fmt( - format_args!("{}/create/{}/+", Self::CERT_PREFIX, payload_format), + format_args!("{}/create/{}/#", Self::CERT_PREFIX, payload_format), ), Topic::CreateKeysAndCertificateAccepted(payload_format) => topic_path.write_fmt( format_args!("{}/create/{}/accepted", Self::CERT_PREFIX, payload_format), From 604ca46ba6c1de7d33d75f09cfe611a90815cb21 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 18 Jul 2024 10:15:26 +0200 Subject: [PATCH 18/36] Bump embedded-mqtt --- Cargo.toml | 4 +- rust-toolchain.toml | 6 +- src/fmt.rs | 85 ++++++++++----- src/jobs/mod.rs | 4 +- src/jobs/update.rs | 36 ++---- src/ota/control_interface/mqtt.rs | 22 ++-- src/ota/data_interface/mod.rs | 2 +- src/ota/data_interface/mqtt.rs | 6 +- src/ota/error.rs | 2 +- src/ota/mod.rs | 176 +++++++++++++++--------------- src/provisioning/mod.rs | 4 +- src/shadows/error.rs | 122 ++++++++++----------- 12 files changed, 243 insertions(+), 226 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7619b16..c27aacc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,9 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "dbf8af0", features = [ - "defmt", -] } +embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "d766137" } futures = { version = "0.3.28", default-features = false } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 17dc494..1368141 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,8 +1,8 @@ [toolchain] -channel = "1.78" -components = [ "rust-src", "rustfmt", "llvm-tools", "clippy" ] +channel = "1.79" +components = ["rust-src", "rustfmt", "llvm-tools"] targets = [ "x86_64-unknown-linux-gnu", + "thumbv6m-none-eabi", "thumbv7em-none-eabihf", - "thumbv6m-none-eabi" ] diff --git a/src/fmt.rs b/src/fmt.rs index c06793e..35b929f 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -1,31 +1,12 @@ -// MIT License - -// 
Copyright (c) 2020 Dario Nieuwenhuis - -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: - -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. - -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - #![macro_use] -#![allow(unused_macros)] +#![allow(unused)] + +use core::fmt::{Debug, Display, LowerHex}; #[cfg(all(feature = "defmt", feature = "log"))] compile_error!("You may not enable both `defmt` and `log` features."); +#[collapse_debuginfo(yes)] macro_rules! assert { ($($x:tt)*) => { { @@ -37,6 +18,7 @@ macro_rules! assert { }; } +#[collapse_debuginfo(yes)] macro_rules! assert_eq { ($($x:tt)*) => { { @@ -48,6 +30,7 @@ macro_rules! assert_eq { }; } +#[collapse_debuginfo(yes)] macro_rules! assert_ne { ($($x:tt)*) => { { @@ -59,6 +42,7 @@ macro_rules! assert_ne { }; } +#[collapse_debuginfo(yes)] macro_rules! debug_assert { ($($x:tt)*) => { { @@ -70,6 +54,7 @@ macro_rules! debug_assert { }; } +#[collapse_debuginfo(yes)] macro_rules! debug_assert_eq { ($($x:tt)*) => { { @@ -81,6 +66,7 @@ macro_rules! debug_assert_eq { }; } +#[collapse_debuginfo(yes)] macro_rules! debug_assert_ne { ($($x:tt)*) => { { @@ -92,6 +78,7 @@ macro_rules! debug_assert_ne { }; } +#[collapse_debuginfo(yes)] macro_rules! todo { ($($x:tt)*) => { { @@ -103,17 +90,23 @@ macro_rules! todo { }; } +#[cfg(not(feature = "defmt"))] +#[collapse_debuginfo(yes)] macro_rules! unreachable { ($($x:tt)*) => { - { - #[cfg(not(feature = "defmt"))] - ::core::unreachable!($($x)*); - #[cfg(feature = "defmt")] - ::defmt::unreachable!($($x)*); - } + ::core::unreachable!($($x)*) }; } +#[cfg(feature = "defmt")] +#[collapse_debuginfo(yes)] +macro_rules! unreachable { + ($($x:tt)*) => { + ::defmt::unreachable!($($x)*) + }; +} + +#[collapse_debuginfo(yes)] macro_rules! panic { ($($x:tt)*) => { { @@ -125,6 +118,7 @@ macro_rules! panic { }; } +#[collapse_debuginfo(yes)] macro_rules! trace { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -138,6 +132,7 @@ macro_rules! trace { }; } +#[collapse_debuginfo(yes)] macro_rules! debug { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -151,6 +146,7 @@ macro_rules! debug { }; } +#[collapse_debuginfo(yes)] macro_rules! info { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -164,6 +160,7 @@ macro_rules! info { }; } +#[collapse_debuginfo(yes)] macro_rules! warn { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -177,6 +174,7 @@ macro_rules! warn { }; } +#[collapse_debuginfo(yes)] macro_rules! error { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -191,6 +189,7 @@ macro_rules! error { } #[cfg(feature = "defmt")] +#[collapse_debuginfo(yes)] macro_rules! 
unwrap { ($($x:tt)*) => { ::defmt::unwrap!($($x)*) @@ -198,6 +197,7 @@ macro_rules! unwrap { } #[cfg(not(feature = "defmt"))] +#[collapse_debuginfo(yes)] macro_rules! unwrap { ($arg:expr) => { match $crate::fmt::Try::into_result($arg) { @@ -245,3 +245,30 @@ impl Try for Result { self } } + +pub(crate) struct Bytes<'a>(pub &'a [u8]); + +impl<'a> Debug for Bytes<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:#02x?}", self.0) + } +} + +impl<'a> Display for Bytes<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:#02x?}", self.0) + } +} + +impl<'a> LowerHex for Bytes<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:#02x?}", self.0) + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Bytes<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "{:02x}", self.0) + } +} diff --git a/src/jobs/mod.rs b/src/jobs/mod.rs index 246d94b..8f35c71 100644 --- a/src/jobs/mod.rs +++ b/src/jobs/mod.rs @@ -262,7 +262,7 @@ impl Jobs { Describe::new() } - pub fn update(job_id: &str, status: JobStatus) -> Update { - Update::new(job_id, status) + pub fn update<'a>(status: JobStatus) -> Update<'a> { + Update::new(status) } } diff --git a/src/jobs/update.rs b/src/jobs/update.rs index db34d90..3875bd0 100644 --- a/src/jobs/update.rs +++ b/src/jobs/update.rs @@ -1,8 +1,6 @@ use serde::Serialize; -use crate::jobs::{ - data_types::JobStatus, JobTopic, MAX_CLIENT_TOKEN_LEN, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN, -}; +use crate::jobs::{data_types::JobStatus, MAX_CLIENT_TOKEN_LEN}; use super::{JobError, StatusDetailsOwned}; @@ -69,7 +67,6 @@ pub struct UpdateJobExecutionRequest<'a> { } pub struct Update<'a> { - job_id: &'a str, status: JobStatus, client_token: Option<&'a str>, status_details: Option<&'a StatusDetailsOwned>, @@ -81,11 +78,8 @@ pub struct Update<'a> { } impl<'a> Update<'a> { - pub fn new(job_id: &'a str, status: JobStatus) -> Self { - assert!(job_id.len() < MAX_JOB_ID_LEN); - + pub fn new(status: JobStatus) -> Self { Self { - job_id, status, status_details: None, include_job_document: false, @@ -148,17 +142,7 @@ impl<'a> Update<'a> { } } - pub fn topic_payload( - self, - client_id: &str, - buf: &mut [u8], - ) -> Result< - ( - heapless::String<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>, - usize, - ), - JobError, - > { + pub fn payload(self, buf: &mut [u8]) -> Result { let payload_len = serde_json_core::to_slice( &UpdateJobExecutionRequest { execution_number: self.execution_number, @@ -174,15 +158,14 @@ impl<'a> Update<'a> { ) .map_err(|_| JobError::Encoding)?; - Ok(( - JobTopic::Update(self.job_id).format(client_id)?, - payload_len, - )) + Ok(payload_len) } } #[cfg(test)] mod test { + use crate::jobs::JobTopic; + use super::*; use serde_json_core::to_string; @@ -207,14 +190,17 @@ mod test { #[test] fn topic_payload() { let mut buf = [0u8; 512]; - let (topic, payload_len) = Update::new("test_job_id", JobStatus::Failed) + let topic = JobTopic::Update("test_job_id") + .format::<64>("test_client") + .unwrap(); + let payload_len = Update::new(JobStatus::Failed) .client_token("test_client:token_update") .step_timeout_in_minutes(50) .execution_number(5) .expected_version(2) .include_job_document() .include_job_execution_state() - .topic_payload("test_client", &mut buf) + .payload(&mut buf) .unwrap(); assert_eq!(&buf[..payload_len], 
br#"{"executionNumber":5,"expectedVersion":2,"includeJobDocument":true,"includeJobExecutionState":true,"status":"FAILED","stepTimeoutInMinutes":50,"clientToken":"test_client:token_update"}"#); diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 4259db2..094ce16 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,11 +1,11 @@ use core::fmt::Write; use embassy_sync::blocking_mutex::raw::RawMutex; -use embedded_mqtt::{Publish, QoS}; +use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS}; use super::ControlInterface; use crate::jobs::data_types::JobStatus; -use crate::jobs::Jobs; +use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; use crate::ota::config::Config; use crate::ota::encoding::json::JobStatusReason; use crate::ota::encoding::FileContext; @@ -93,11 +93,17 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface } } - // FIXME: Serialize directly into the publish payload through `DeferredPublish` API - let mut buf = [0u8; 512]; - let (topic, payload_len) = Jobs::update(file_ctx.job_name.as_str(), status) - .status_details(&file_ctx.status_details) - .topic_payload(self.client_id(), &mut buf)?; + let topic = JobTopic::Update(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?; + let payload = DeferredPayload::new( + |buf| { + Jobs::update(status) + .status_details(&file_ctx.status_details) + .payload(buf) + .map_err(|_| EncodingError::BufferSize) + }, + 512, + ); self.publish(Publish { dup: false, @@ -105,7 +111,7 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface retain: false, pid: None, topic_name: &topic, - payload: &buf[..payload_len], + payload, properties: embedded_mqtt::Properties::Slice(&[]), }) .await?; diff --git a/src/ota/data_interface/mod.rs b/src/ota/data_interface/mod.rs index 3723c69..cfb9cee 100644 --- a/src/ota/data_interface/mod.rs +++ b/src/ota/data_interface/mod.rs @@ -45,7 +45,7 @@ impl<'a> FileBlock<'a> { } pub trait BlockTransfer { - async fn next_block(&mut self) -> Result, OtaError>; + async fn next_block(&mut self) -> Result>, OtaError>; } pub trait DataInterface { diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 1f32995..f40f31a 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -125,8 +125,8 @@ impl<'a> OtaTopic<'a> { } impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> { - async fn next_block(&mut self) -> Result, OtaError> { - Ok(self.next().await.ok_or(OtaError::Encoding)?) + async fn next_block(&mut self) -> Result>, OtaError> { + Ok(self.next().await) } } @@ -179,7 +179,7 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB }, buf, ) - .map_err(|e| EncodingError::BufferSize) + .map_err(|_| EncodingError::BufferSize) }, 32, ); diff --git a/src/ota/error.rs b/src/ota/error.rs index 8fdf6e6..8c5744b 100644 --- a/src/ota/error.rs +++ b/src/ota/error.rs @@ -27,7 +27,7 @@ pub enum OtaError { impl OtaError { pub fn is_retryable(&self) -> bool { - matches!(self, Self::Encoding) + matches!(self, Self::Encoding | Self::Timeout) } } diff --git a/src/ota/mod.rs b/src/ota/mod.rs index c962b54..338320c 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -227,113 +227,113 @@ impl Updater { info!("Initialized file handler! 
Requesting file blocks"); - // Prepare the storage layer on receiving a new file - let mut subscription = data.init_file_transfer(&mut file_ctx).await?; + loop { + // Prepare the storage layer on receiving a new file + let mut subscription = data.init_file_transfer(&mut file_ctx).await?; - data.request_file_block(&mut file_ctx, &config).await?; + data.request_file_block(&mut file_ctx, &config).await?; - info!("Awaiting file blocks!"); + info!("Awaiting file blocks!"); - while let Ok(mut payload) = subscription.next_block().await { - debug!("process_data_handler"); - // Decode the file block received - match Self::ingest_data_block( - data, - block_writer, - &config, - &mut file_ctx, - payload.deref_mut(), - ) - .await - { - Ok(true) => match pal.close_file(&file_ctx).await { - Err(e) => { + while let Some(mut payload) = subscription.next_block().await? { + debug!("process_data_handler"); + // Decode the file block received + match Self::ingest_data_block( + data, + block_writer, + &config, + &mut file_ctx, + payload.deref_mut(), + ) + .await + { + Ok(true) => match pal.close_file(&file_ctx).await { + Err(e) => { + control + .update_job_status( + &mut file_ctx, + &config, + JobStatus::Failed, + JobStatusReason::Pal(0), + ) + .await?; + + return Err(e.into()); + } + Ok(_) => { + let (status, reason, event) = if let Some(0) = file_ctx.file_type { + ( + JobStatus::InProgress, + JobStatusReason::SigCheckPassed, + pal::OtaEvent::Activate, + ) + } else { + ( + JobStatus::Succeeded, + JobStatusReason::Accepted, + pal::OtaEvent::UpdateComplete, + ) + }; + + control + .update_job_status(&mut file_ctx, &config, status, reason) + .await?; + + return Ok(event); + } + }, + Ok(false) => { + debug!("Ingested one block!"); + // Reset the momentum counter since we received a good block + request_momentum.store(0, Ordering::Relaxed); + + // We're actively receiving a file so update the job status as + // needed control .update_job_status( &mut file_ctx, &config, - JobStatus::Failed, - JobStatusReason::Pal(0), + JobStatus::InProgress, + JobStatusReason::Receiving, ) .await?; - return Err(e.into()); - } - Ok(_) => { - let (status, reason, event) = if let Some(0) = file_ctx.file_type { - ( - JobStatus::InProgress, - JobStatusReason::SigCheckPassed, - pal::OtaEvent::Activate, - ) + if file_ctx.request_block_remaining > 1 { + file_ctx.request_block_remaining -= 1; } else { - ( - JobStatus::Succeeded, - JobStatusReason::Accepted, - pal::OtaEvent::UpdateComplete, - ) - }; + data.request_file_block(&mut file_ctx, &config).await?; + } + } + Err(e) if e.is_retryable() => { + warn!("Failed to ingest data block, Error is retryable! ingest_data_block returned error {:?}", e); + } + Err(e) => { + error!("Failed to ingest data block, rejecting image: ingest_data_block returned error {:?}", e); + + // Call the platform specific code to reject the image + // TODO: This should never write to current image flags?! 
+ // pal.set_platform_image_state(ImageState::Rejected( + // ImageStateReason::FailedIngest, + // )) + // .await?; + // TODO: Pal reason control - .update_job_status(&mut file_ctx, &config, status, reason) + .update_job_status( + &mut file_ctx, + &config, + JobStatus::Failed, + JobStatusReason::Pal(0), + ) .await?; - return Ok(event); - } - }, - Ok(false) => { - debug!("Ingested one block!"); - // Reset the momentum counter since we received a good block - request_momentum.store(0, Ordering::Relaxed); - - // We're actively receiving a file so update the job status as - // needed - control - .update_job_status( - &mut file_ctx, - &config, - JobStatus::InProgress, - JobStatusReason::Receiving, - ) - .await?; - - if file_ctx.request_block_remaining > 1 { - file_ctx.request_block_remaining -= 1; - } else { - data.request_file_block(&mut file_ctx, &config).await?; + pal.complete_callback(pal::OtaEvent::Fail).await?; + info!("Application callback! OtaEvent::Fail"); + return Err(e); } } - Err(e) if e.is_retryable() => { - warn!("Failed to ingest data block, Error is retryable! ingest_data_block returned error {:?}", e); - } - Err(e) => { - error!("Failed to ingest data block, rejecting image: ingest_data_block returned error {:?}", e); - - // Call the platform specific code to reject the image - // TODO: This should never write to current image flags?! - // pal.set_platform_image_state(ImageState::Rejected( - // ImageStateReason::FailedIngest, - // )) - // .await?; - - // TODO: Pal reason - control - .update_job_status( - &mut file_ctx, - &config, - JobStatus::Failed, - JobStatusReason::Pal(0), - ) - .await?; - - pal.complete_callback(pal::OtaEvent::Fail).await?; - info!("Application callback! OtaEvent::Fail"); - return Err(e); - } } } - - Err(error::OtaError::Mqtt(embedded_mqtt::Error::EOF)) }; // let (momentum_res, data_res) = embassy_futures::join::join(momentum_fut, data_fut).await; diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index b650a1c..57e1ef5 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -234,7 +234,7 @@ impl FleetProvisioner { }])) .await .map_err(|e| { - error!("Failed subscription to RegisterThingAny! {}", e); + error!("Failed subscription to RegisterThingAny! {:?}", e); Error::Mqtt })?; @@ -251,7 +251,7 @@ impl FleetProvisioner { }) .await .map_err(|e| { - error!("Failed publish to RegisterThing! {}", e); + error!("Failed publish to RegisterThing! 
{:?}", e); Error::Mqtt })?; diff --git a/src/shadows/error.rs b/src/shadows/error.rs index f7cd84b..f08ff4d 100644 --- a/src/shadows/error.rs +++ b/src/shadows/error.rs @@ -98,66 +98,66 @@ impl<'a> TryFrom> for ShadowError { } } -impl Display for ShadowError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::InvalidJson => write!(f, "Invalid JSON"), - Self::MissingState => write!(f, "Missing required node: state"), - Self::MalformedState => write!(f, "State node must be an object"), - Self::MalformedDesired => write!(f, "Desired node must be an object"), - Self::MalformedReported => write!(f, "Reported node must be an object"), - Self::InvalidVersion => write!(f, "Invalid version"), - Self::InvalidClientToken => write!(f, "Invalid clientToken"), - Self::JsonTooDeep => { - write!(f, "JSON contains too many levels of nesting; maximum is 6") - } - Self::InvalidStateNode => write!(f, "State contains an invalid node"), - Self::Unauthorized => write!(f, "Unauthorized"), - Self::Forbidden => write!(f, "Forbidden"), - Self::NotFound => write!(f, "Thing not found"), - Self::NoNamedShadow(shadow_name) => { - write!(f, "No shadow exists with name: {}", shadow_name) - } - Self::VersionConflict => write!(f, "Version conflict"), - Self::PayloadTooLarge => write!(f, "The payload exceeds the maximum size allowed"), - Self::UnsupportedEncoding => write!( - f, - "Unsupported documented encoding; supported encoding is UTF-8" - ), - Self::TooManyRequests => write!(f, "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection"), - Self::InternalServerError => write!(f, "Internal service failure"), - } - } -} +// impl Display for ShadowError { +// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { +// match self { +// Self::InvalidJson => write!(f, "Invalid JSON"), +// Self::MissingState => write!(f, "Missing required node: state"), +// Self::MalformedState => write!(f, "State node must be an object"), +// Self::MalformedDesired => write!(f, "Desired node must be an object"), +// Self::MalformedReported => write!(f, "Reported node must be an object"), +// Self::InvalidVersion => write!(f, "Invalid version"), +// Self::InvalidClientToken => write!(f, "Invalid clientToken"), +// Self::JsonTooDeep => { +// write!(f, "JSON contains too many levels of nesting; maximum is 6") +// } +// Self::InvalidStateNode => write!(f, "State contains an invalid node"), +// Self::Unauthorized => write!(f, "Unauthorized"), +// Self::Forbidden => write!(f, "Forbidden"), +// Self::NotFound => write!(f, "Thing not found"), +// Self::NoNamedShadow(shadow_name) => { +// write!(f, "No shadow exists with name: {}", shadow_name) +// } +// Self::VersionConflict => write!(f, "Version conflict"), +// Self::PayloadTooLarge => write!(f, "The payload exceeds the maximum size allowed"), +// Self::UnsupportedEncoding => write!( +// f, +// "Unsupported documented encoding; supported encoding is UTF-8" +// ), +// Self::TooManyRequests => write!(f, "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection"), +// Self::InternalServerError => write!(f, "Internal service failure"), +// } +// } +// } -// TODO: This seems like an extremely brittle way of doing this??! -impl FromStr for ShadowError { - type Err = (); +// // TODO: This seems like an extremely brittle way of doing this??! 
+// impl FromStr for ShadowError { +// type Err = (); - fn from_str(s: &str) -> Result { - Ok(match s.trim() { - "Invalid JSON" => Self::InvalidJson, - "Missing required node: state" => Self::MissingState, - "State node must be an object" => Self::MalformedState, - "Desired node must be an object" => Self::MalformedDesired, - "Reported node must be an object" => Self::MalformedReported, - "Invalid version" => Self::InvalidVersion, - "Invalid clientToken" => Self::InvalidClientToken, - "JSON contains too many levels of nesting; maximum is 6" => Self::JsonTooDeep, - "State contains an invalid node" => Self::InvalidStateNode, - "Unauthorized" => Self::Unauthorized, - "Forbidden" => Self::Forbidden, - "Thing not found" => Self::NotFound, - // TODO: - "No shadow exists with name: " => Self::NoNamedShadow(String::new()), - "Version conflict" => Self::VersionConflict, - "The payload exceeds the maximum size allowed" => Self::PayloadTooLarge, - "Unsupported documented encoding; supported encoding is UTF-8" => { - Self::UnsupportedEncoding - } - "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection" => Self::TooManyRequests, - "Internal service failure" => Self::InternalServerError, - _ => return Err(()), - }) - } -} +// fn from_str(s: &str) -> Result { +// Ok(match s.trim() { +// "Invalid JSON" => Self::InvalidJson, +// "Missing required node: state" => Self::MissingState, +// "State node must be an object" => Self::MalformedState, +// "Desired node must be an object" => Self::MalformedDesired, +// "Reported node must be an object" => Self::MalformedReported, +// "Invalid version" => Self::InvalidVersion, +// "Invalid clientToken" => Self::InvalidClientToken, +// "JSON contains too many levels of nesting; maximum is 6" => Self::JsonTooDeep, +// "State contains an invalid node" => Self::InvalidStateNode, +// "Unauthorized" => Self::Unauthorized, +// "Forbidden" => Self::Forbidden, +// "Thing not found" => Self::NotFound, +// // TODO: +// "No shadow exists with name: " => Self::NoNamedShadow(String::new()), +// "Version conflict" => Self::VersionConflict, +// "The payload exceeds the maximum size allowed" => Self::PayloadTooLarge, +// "Unsupported documented encoding; supported encoding is UTF-8" => { +// Self::UnsupportedEncoding +// } +// "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection" => Self::TooManyRequests, +// "Internal service failure" => Self::InternalServerError, +// _ => return Err(()), +// }) +// } +// } From d6d4af3256a6638f30ff564545be8b0ec3999ad7 Mon Sep 17 00:00:00 2001 From: Mathias Date: Fri, 19 Jul 2024 11:08:21 +0200 Subject: [PATCH 19/36] Fix provisioning topic resulting in status code 143, by subscribing to individual accepted and rejected topic for now --- src/provisioning/error.rs | 12 +++++- src/provisioning/mod.rs | 86 ++++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 40 deletions(-) diff --git a/src/provisioning/error.rs b/src/provisioning/error.rs index deb5725..67f4696 100644 --- a/src/provisioning/error.rs +++ b/src/provisioning/error.rs @@ -1,10 +1,12 @@ #[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub enum Error { Overflow, InvalidPayload, InvalidState, - Mqtt, - DeserializeJson(serde_json_core::de::Error), + Mqtt(embedded_mqtt::Error), + DeserializeJson(#[cfg_attr(feature = "defmt", defmt(Debug2Format))] serde_json_core::de::Error), DeserializeCbor, 
CertificateStorage, Response(u16), @@ -27,3 +29,9 @@ impl From for Error { Self::DeserializeCbor } } + +impl From for Error { + fn from(e: embedded_mqtt::Error) -> Self { + Self::Mqtt(e) + } +} diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 57e1ef5..b334fca 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -232,11 +232,7 @@ impl FleetProvisioner { retain_as_published: false, retain_handling: RetainHandling::SendAtSubscribeTime, }])) - .await - .map_err(|e| { - error!("Failed subscription to RegisterThingAny! {:?}", e); - Error::Mqtt - })?; + .await?; mqtt.publish(Publish { dup: false, @@ -249,11 +245,7 @@ impl FleetProvisioner { payload, properties: embedded_mqtt::Properties::Slice(&[]), }) - .await - .map_err(|e| { - error!("Failed publish to RegisterThing! {:?}", e); - Error::Mqtt - })?; + .await?; drop(message); drop(create_subscription); @@ -288,20 +280,30 @@ impl FleetProvisioner { mqtt: &'b embedded_mqtt::MqttClient<'a, M, SUBS>, csr: Option<&str>, payload_format: PayloadFormat, - ) -> Result, Error> { + ) -> Result, Error> { if let Some(csr) = csr { let subscription = mqtt - .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::CreateCertificateFromCsrAccepted(payload_format) - .format::<47>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) - .await - .map_err(|_| Error::Mqtt)?; + .subscribe(Subscribe::new(&[ + SubscribeTopic { + topic_path: Topic::CreateCertificateFromCsrRejected(payload_format) + .format::<47>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }, + SubscribeTopic { + topic_path: Topic::CreateCertificateFromCsrAccepted(payload_format) + .format::<47>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }, + ])) + .await?; let request = CreateCertificateFromCsrRequest { certificate_signing_request: csr, @@ -333,28 +335,37 @@ impl FleetProvisioner { retain: false, pid: None, topic_name: Topic::CreateCertificateFromCsr(payload_format) - .format::<38>()? + .format::<40>()? .as_str(), payload, properties: embedded_mqtt::Properties::Slice(&[]), }) - .await - .map_err(|_| Error::Mqtt)?; + .await?; Ok(subscription) } else { let subscription = mqtt - .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::CreateKeysAndCertificateAny(payload_format) - .format::<31>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) - .await - .map_err(|_| Error::Mqtt)?; + .subscribe(Subscribe::new(&[ + SubscribeTopic { + topic_path: Topic::CreateKeysAndCertificateAccepted(payload_format) + .format::<38>()? + .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }, + SubscribeTopic { + topic_path: Topic::CreateKeysAndCertificateRejected(payload_format) + .format::<38>()? 
+ .as_str(), + maximum_qos: QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }, + ])) + .await?; mqtt.publish(Publish { dup: false, @@ -367,8 +378,7 @@ impl FleetProvisioner { payload: b"", properties: embedded_mqtt::Properties::Slice(&[]), }) - .await - .map_err(|_| Error::Mqtt)?; + .await?; Ok(subscription) } From f42460845bfbab7c610174413251560edc18cd8a Mon Sep 17 00:00:00 2001 From: Mathias Date: Fri, 19 Jul 2024 12:57:08 +0200 Subject: [PATCH 20/36] Reduce the overhead of of CreateCertificateFromCsr buffer size --- src/provisioning/mod.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index b334fca..dbeefe5 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -10,14 +10,13 @@ use embedded_mqtt::{ SubscribeTopic, Subscription, }; use futures::StreamExt; -use serde::Serialize; -use serde::{de::DeserializeOwned, Deserialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; pub use error::Error; -use self::data_types::CreateCertificateFromCsrRequest; use self::{ data_types::{ + CreateCertificateFromCsrRequest, CreateCertificateFromCsrResponse, CreateKeysAndCertificateResponse, ErrorResponse, RegisterThingRequest, RegisterThingResponse, }, @@ -137,8 +136,6 @@ impl FleetProvisioner { where C: DeserializeOwned, { - use crate::provisioning::data_types::CreateCertificateFromCsrResponse; - let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; let mut message = create_subscription @@ -326,7 +323,7 @@ impl FleetProvisioner { .map_err(|_| EncodingError::BufferSize)?, }) }, - 1024, + csr.len() + 32, ); mqtt.publish(Publish { From f924e1e92b67e9edae102b48cc38b27d1cc86700 Mon Sep 17 00:00:00 2001 From: Mathias Koch Date: Wed, 25 Sep 2024 15:52:34 +0200 Subject: [PATCH 21/36] Feature/async shadows (#57) * Wip on rewriting shadows to async * Further work on async shadows. 
Still working on compile errors * Fix: Async shadow (#60) * fix asyunc shadow * renaming of handle message and some linting * shadows error fix and handle delta should wait for connected * fmt * Add const generic SUBS to shadows * Fix/async shadow (#61) * fix asyunc shadow * renaming of handle message and some linting * shadows error fix and handle delta should wait for connected * fmt * subscribe to get shadow and do not overwrite desired state * Get shadow should deserialize patchState * wait for accepted and rejected for delete and update as well * Make sure OTA job documents can be deserialized with no codesigning properties in the document (#62) * Dont blindly copy serde attrs in ShadowPatch derive, but rather introduce patch attr that specifies attrs to copy * Add skip_serializing_if none to all patchstate fields * Shadows: Check client token on all request/response pairs * Create initial shadow state, if dao read fails during getShadow operation * remove some client token checks * Fix not holding delta message across report call * handle delta on get shadow * Bump embedded-mqtt * Fix all tests * Allow reporting non-persisted shadows directly, through a report fn * Bump embedded-mqtt * Enhancement(async): Mutex shadow to borrow as immutable (#63) * Use mutex to borrow shadow as immutable * remove .git in embedded-mqtt dependency --------- Co-authored-by: Kenneth Knudsen Co-authored-by: Kenneth Knudsen <98805797+KennethKnudsen97@users.noreply.github.com> --- Cargo.toml | 17 +- documentation/stack.drawio | 8 +- rust-toolchain.toml | 2 +- shadow_derive/src/lib.rs | 26 +- src/jobs/data_types.rs | 15 +- src/lib.rs | 4 +- src/ota/control_interface/mqtt.rs | 116 +++- src/ota/data_interface/mqtt.rs | 51 +- src/ota/encoding/json.rs | 45 +- src/ota/encoding/mod.rs | 8 +- src/ota/error.rs | 5 +- src/provisioning/mod.rs | 228 ++++---- src/shadows/README.md | 2 +- src/shadows/dao.rs | 60 +-- src/shadows/data_types.rs | 12 +- src/shadows/error.rs | 84 +-- src/shadows/mod.rs | 845 +++++++++++++++--------------- src/shadows/topics.rs | 49 +- tests/common/file_handler.rs | 100 ++-- tests/common/network.rs | 4 +- tests/ota.rs | 62 +-- tests/provisioning.rs | 18 +- tests/shadows.rs | 135 +++-- 23 files changed, 979 insertions(+), 917 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c27aacc..42f4d3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,13 @@ members = ["shadow_derive"] [package] name = "rustot" version = "0.5.0" -authors = ["Mathias Koch "] +authors = ["Factbird team "] description = "AWS IoT" readme = "README.md" keywords = ["iot", "no-std"] categories = ["embedded", "no-std"] license = "MIT OR Apache-2.0" -repository = "https://github.com/BlackbirdHQ/rustot" +repository = "https://github.com/FactbirdHQ/rustot" edition = "2021" documentation = "https://docs.rs/rustot" exclude = ["/documentation"] @@ -29,7 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "d766137" } +embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "d2b7c02" } futures = { version = "0.3.28", default-features = false } @@ -46,6 +46,7 @@ embedded-nal-async = "0.7" env_logger = "0.11" sha2 = "0.10.1" static_cell = { version = "2", features = ["nightly"] } + tokio = { version = "1.33", default-features = false, features = [ "macros", "rt", @@ 
-73,5 +74,13 @@ ota_http_data = [] std = ["serde/std", "serde_cbor?/std"] -defmt = ["dep:defmt", "heapless/defmt-03", "embedded-mqtt/defmt"] +defmt = [ + "dep:defmt", + "heapless/defmt-03", + "embedded-mqtt/defmt", + "embassy-time/defmt", +] log = ["dep:log", "embedded-mqtt/log"] + +# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +# embedded-mqtt = { path = "../embedded-mqtt" } diff --git a/documentation/stack.drawio b/documentation/stack.drawio index cdf864a..fd6bb68 100644 --- a/documentation/stack.drawio +++ b/documentation/stack.drawio @@ -1,13 +1,13 @@ - + - + - + @@ -22,7 +22,7 @@ - + diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 1368141..b6369b9 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.79" +channel = "nightly-2024-07-17" components = ["rust-src", "rustfmt", "llvm-tools"] targets = [ "x86_64-unknown-linux-gnu", diff --git a/shadow_derive/src/lib.rs b/shadow_derive/src/lib.rs index 7838e87..09cdef0 100644 --- a/shadow_derive/src/lib.rs +++ b/shadow_derive/src/lib.rs @@ -11,9 +11,9 @@ use syn::DeriveInput; use syn::Generics; use syn::Ident; use syn::Result; -use syn::{parenthesized, Attribute, Error, Field, LitStr}; +use syn::{parenthesized, Error, Field, LitStr}; -#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field))] +#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field, patch))] pub fn shadow_state(input: TokenStream) -> TokenStream { match parse_macro_input!(input as ParseInput) { ParseInput::Struct(input) => { @@ -32,7 +32,7 @@ pub fn shadow_state(input: TokenStream) -> TokenStream { } } -#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, serde))] +#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, patch))] pub fn shadow_patch(input: TokenStream) -> TokenStream { TokenStream::from(match parse_macro_input!(input as ParseInput) { ParseInput::Struct(input) => generate_shadow_patch_struct(&input), @@ -56,7 +56,7 @@ struct StructParseInput { pub ident: Ident, pub generics: Generics, pub shadow_fields: Vec, - pub copy_attrs: Vec, + pub copy_attrs: Vec, pub shadow_name: Option, } @@ -67,8 +67,6 @@ impl Parse for ParseInput { let mut shadow_name = None; let mut copy_attrs = vec![]; - let attrs_to_copy = ["serde"]; - // Parse valid container attributes for attr in derive_input.attrs { if attr.path.is_ident("shadow") { @@ -78,12 +76,14 @@ impl Parse for ParseInput { content.parse() } shadow_name = Some(shadow_arg.parse2(attr.tokens)?); - } else if attrs_to_copy - .iter() - .find(|a| attr.path.is_ident(a)) - .is_some() - { - copy_attrs.push(attr); + } else if attr.path.is_ident("patch") { + fn patch_arg(input: ParseStream) -> Result { + let content; + parenthesized!(content in input); + content.parse() + } + let args = patch_arg.parse2(attr.tokens)?; + copy_attrs.push(quote! { #[ #args ]}) } } @@ -161,7 +161,7 @@ fn create_optional_fields(fields: &Vec) -> Vec Some(if type_name_string.starts_with("Option<") { quote! { #(#attrs)* pub #field_name: Option::PatchState>> } } else { - quote! { #(#attrs)* pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> } + quote! 
{ #(#attrs)* #[serde(skip_serializing_if = "Option::is_none")] pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> } }) } }) diff --git a/src/jobs/data_types.rs b/src/jobs/data_types.rs index 5910469..36449dd 100644 --- a/src/jobs/data_types.rs +++ b/src/jobs/data_types.rs @@ -22,7 +22,8 @@ pub enum JobStatus { Removed, } -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum ErrorCode { /// The request was sent to a topic in the AWS IoT Jobs namespace that does /// not map to any API. @@ -89,7 +90,7 @@ pub struct GetPendingJobExecutionsResponse<'a> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Contains data about a job execution. @@ -211,7 +212,7 @@ pub struct StartNextPendingJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Topic (accepted): $aws/things/{thingName}/jobs/{jobId}/update/accepted \ @@ -232,7 +233,7 @@ pub struct UpdateJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Sent whenever a job execution is added to or removed from the list of @@ -289,7 +290,7 @@ pub struct Jobs { /// service operation. #[derive(Debug, PartialEq, Deserialize)] pub struct ErrorResponse<'a> { - code: ErrorCode, + pub code: ErrorCode, /// An error message string. message: &'a str, /// A client token used to correlate requests and responses. Enter an @@ -394,7 +395,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: None, timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); @@ -433,7 +434,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: Some(queued_jobs), timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); } diff --git a/src/lib.rs b/src/lib.rs index b121586..ca160a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ #![cfg_attr(not(any(test, feature = "std")), no_std)] #![allow(async_fn_in_trait)] +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] // This mod MUST go first, so that the others see its macros. 
pub(crate) mod fmt; @@ -8,6 +10,6 @@ pub mod jobs; #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] pub mod ota; pub mod provisioning; -// pub mod shadows; +pub mod shadows; pub use serde_cbor; diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 094ce16..4456428 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,18 +1,21 @@ use core::fmt::Write; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; -use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS}; +use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic}; +use futures::StreamExt as _; use super::ControlInterface; -use crate::jobs::data_types::JobStatus; -use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; +use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse}; +use crate::jobs::{JobError, JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; use crate::ota::config::Config; use crate::ota::encoding::json::JobStatusReason; -use crate::ota::encoding::FileContext; +use crate::ota::encoding::{self, FileContext}; use crate::ota::error::OtaError; -impl<'a, M: RawMutex, const SUBS: usize> ControlInterface - for embedded_mqtt::MqttClient<'a, M, SUBS> +impl<'a, M: RawMutex, const SUBS: usize> ControlInterface for embedded_mqtt::MqttClient<'a, M, SUBS> +where + BitsImpl<{ SUBS }>: Bits, { /// Check for next available OTA job from the job service by publishing a /// "get next job" message to the job service. @@ -21,15 +24,12 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface let mut buf = [0u8; 512]; let (topic, payload_len) = Jobs::describe().topic_payload(self.client_id(), &mut buf)?; - self.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: &topic, - payload: &buf[..payload_len], - properties: embedded_mqtt::Properties::Slice(&[]), - }) + self.publish( + Publish::builder() + .topic_name(&topic) + .payload(&buf[..payload_len]) + .build(), + ) .await?; Ok(()) @@ -69,7 +69,7 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface } // Don't override the progress on succeeded, nor on self-test - // active. (Cases where progess counter is lost due to device + // active. (Cases where progress counter is lost due to device // restarts) if status != JobStatus::Succeeded && reason != JobStatusReason::SelfTestActive { let mut progress = heapless::String::new(); @@ -93,11 +93,39 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface } } + let mut sub = self + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + JobTopic::UpdateAccepted(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + self.client_id(), + )? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + JobTopic::UpdateRejected(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + self.client_id(), + )? 
+ .as_str(), + ) + .build(), + ]) + .build(), + ) + .await?; + let topic = JobTopic::Update(file_ctx.job_name.as_str()) .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?; let payload = DeferredPayload::new( |buf| { Jobs::update(status) + .client_token(self.client_id()) .status_details(&file_ctx.status_details) .payload(buf) .map_err(|_| EncodingError::BufferSize) @@ -105,17 +133,53 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface 512, ); - self.publish(Publish { - dup: false, - qos, - retain: false, - pid: None, - topic_name: &topic, - payload, - properties: embedded_mqtt::Properties::Slice(&[]), - }) + self.publish( + Publish::builder() + .qos(qos) + .topic_name(&topic) + .payload(payload) + .build(), + ) .await?; - Ok(()) + loop { + let message = sub.next().await.ok_or(JobError::Encoding)?; + + // Check if topic is GetAccepted + match crate::jobs::Topic::from_str(message.topic_name()) { + Some(crate::jobs::Topic::UpdateAccepted(_)) => { + // Check client token + let (response, _) = serde_json_core::from_slice::< + UpdateJobExecutionResponse>, + >(message.payload()) + .map_err(|_| JobError::Encoding)?; + + if response.client_token != Some(self.client_id()) { + error!( + "Unexpected client token received: {}, expected: {}", + response.client_token.unwrap_or("None"), + self.client_id() + ); + continue; + } + + return Ok(()); + } + Some(crate::jobs::Topic::UpdateRejected(_)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| JobError::Encoding)?; + + if error_response.client_token != Some(self.client_id()) { + continue; + } + + return Err(OtaError::UpdateRejected(error_response.code)); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + } + } + } } } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index f40f31a..17bbf48 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -2,10 +2,10 @@ use core::fmt::{Display, Write}; use core::ops::DerefMut; use core::str::FromStr; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - DeferredPayload, EncodingError, MqttClient, Properties, Publish, RetainHandling, Subscribe, - SubscribeTopic, Subscription, + DeferredPayload, EncodingError, MqttClient, Publish, Subscribe, SubscribeTopic, Subscription, }; use futures::StreamExt; @@ -124,13 +124,19 @@ impl<'a> OtaTopic<'a> { } } -impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> { +impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> +where + BitsImpl<{ SUBS }>: Bits, +{ async fn next_block(&mut self) -> Result>, OtaError> { Ok(self.next().await) } } -impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> { +impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> +where + BitsImpl<{ SUBS }>: Bits, +{ const PROTOCOL: Protocol = Protocol::Mqtt; type ActiveTransfer<'t> = Subscription<'a, 't, M, SUBS, 1> where Self: 't; @@ -143,17 +149,15 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB let topic_path = OtaTopic::Data(Encoding::Cbor, file_ctx.stream_name.as_str()) .format::<256>(self.client_id())?; - let topic = SubscribeTopic { - topic_path: topic_path.as_str(), - maximum_qos: embedded_mqtt::QoS::AtMostOnce, - no_local: false, - retain_as_published: false, - retain_handling: 
RetainHandling::SendAtSubscribeTime, - }; + let topics = [SubscribeTopic::builder() + .topic_path(topic_path.as_str()) + .build()]; debug!("Subscribing to: [{:?}]", &topic_path); - Ok(self.subscribe::<1>(Subscribe::new(&[topic])).await?) + Ok(self + .subscribe::<1>(Subscribe::builder().topics(&topics).build()) + .await?) } /// Request file block by publishing to the get stream topic @@ -189,17 +193,18 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB file_ctx.request_block_remaining ); - self.publish(Publish { - dup: false, - qos: embedded_mqtt::QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) - .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>(self.client_id())? - .as_str(), - payload, - properties: Properties::Slice(&[]), - }) + self.publish( + Publish::builder() + .topic_name( + OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) + .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>( + self.client_id(), + )? + .as_str(), + ) + .payload(payload) + .build(), + ) .await?; Ok(()) diff --git a/src/ota/encoding/json.rs b/src/ota/encoding/json.rs index 45ea3f2..c258942 100644 --- a/src/ota/encoding/json.rs +++ b/src/ota/encoding/json.rs @@ -32,7 +32,8 @@ pub struct FileDescription<'a> { #[serde(rename = "fileid")] pub fileid: u8, #[serde(rename = "certfile")] - pub certfile: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub certfile: Option<&'a str>, #[serde(rename = "update_data_url")] #[serde(skip_serializing_if = "Option::is_none")] pub update_data_url: Option<&'a str>, @@ -59,20 +60,26 @@ pub struct FileDescription<'a> { } impl<'a> FileDescription<'a> { - pub fn signature(&self) -> Signature { + pub fn signature(&self) -> Option { if let Some(sig) = self.sha1_rsa { - return Signature::Sha1Rsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha1Rsa(heapless::String::try_from(sig).unwrap())); } if let Some(sig) = self.sha256_rsa { - return Signature::Sha256Rsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha256Rsa( + heapless::String::try_from(sig).unwrap(), + )); } if let Some(sig) = self.sha1_ecdsa { - return Signature::Sha1Ecdsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha1Ecdsa( + heapless::String::try_from(sig).unwrap(), + )); } if let Some(sig) = self.sha256_ecdsa { - return Signature::Sha256Ecdsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha256Ecdsa( + heapless::String::try_from(sig).unwrap(), + )); } - unreachable!() + None } } @@ -147,4 +154,28 @@ mod tests { ); } } + + #[test] + fn deserializ() { + let data = r#"{ + "protocols": [ + "MQTT" + ], + "streamname": "AFR_OTA-d11032e9-38d5-4dca-8c7c-1e6f24533ede", + "files": [ + { + "filepath": "3.8.4", + "filesize": 537600, + "fileid": 0, + "certfile": null, + "fileType": 0, + "update_data_url": null, + "auth_scheme": null, + "sig--": null + } + ] + }"#; + + serde_json_core::from_str::(&data).unwrap(); + } } diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index bc68c67..7c88700 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -60,10 +60,10 @@ pub struct FileContext { pub filepath: heapless::String<64>, pub filesize: usize, pub fileid: u8, - pub certfile: heapless::String<64>, + pub certfile: Option>, pub update_data_url: Option>, pub auth_scheme: Option>, - pub signature: Signature, + pub signature: Option, pub file_type: Option, pub protocols: heapless::Vec, 
@@ -110,7 +110,9 @@ impl FileContext { filesize: file_desc.filesize, protocols: job_data.ota_document.protocols, fileid: file_desc.fileid, - certfile: heapless::String::try_from(file_desc.certfile).unwrap(), + certfile: file_desc + .certfile + .map(|cert| heapless::String::try_from(cert).unwrap()), update_data_url: file_desc .update_data_url .map(|s| heapless::String::try_from(s).unwrap()), diff --git a/src/ota/error.rs b/src/ota/error.rs index 8c5744b..119cc20 100644 --- a/src/ota/error.rs +++ b/src/ota/error.rs @@ -1,4 +1,4 @@ -use crate::jobs::JobError; +use crate::jobs::{data_types::ErrorCode, JobError}; use super::pal::OtaPalError; @@ -6,7 +6,6 @@ use super::pal::OtaPalError; #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum OtaError { NoActiveJob, - SignalEventFailed, Momentum, MomentumAbort, InvalidInterface, @@ -14,7 +13,9 @@ pub enum OtaError { BlockOutOfRange, ZeroFileSize, Overflow, + UnexpectedTopic, InvalidFile, + UpdateRejected(ErrorCode), Write( #[cfg_attr(feature = "defmt", defmt(Debug2Format))] embedded_storage_async::nor_flash::NorFlashErrorKind, diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index dbeefe5..d86e2a7 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -4,10 +4,11 @@ pub mod topics; use core::future::Future; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - DeferredPayload, EncodingError, Message, Publish, QoS, RetainHandling, Subscribe, - SubscribeTopic, Subscription, + BufferProvider, DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, + Subscription, }; use futures::StreamExt; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -41,12 +42,13 @@ pub struct FleetProvisioner; impl FleetProvisioner { pub async fn provision<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -68,6 +70,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -83,12 +86,13 @@ impl FleetProvisioner { #[cfg(feature = "provision_cbor")] pub async fn provision_cbor<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -111,6 +115,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -134,10 +139,12 @@ impl FleetProvisioner { payload_format: PayloadFormat, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { - let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; + use embedded_mqtt::SliceBufferProvider; + let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; let mut message = create_subscription .next() .await @@ -145,10 +152,11 @@ impl FleetProvisioner { let ownership_token = match Topic::from_str(message.topic_name()) { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { - 
let response = Self::deserialize::( - format, - &mut message, - )?; + let response = Self::deserialize::< + CreateKeysAndCertificateResponse, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; credential_handler .store_credentials(Credentials { @@ -162,10 +170,11 @@ impl FleetProvisioner { } Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - let response = Self::deserialize::( - format, - &mut message, - )?; + let response = Self::deserialize::< + CreateCertificateFromCsrResponse, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; credential_handler .store_credentials(Credentials { @@ -220,28 +229,29 @@ impl FleetProvisioner { debug!("Starting RegisterThing"); let mut register_subscription = mqtt - .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::RegisterThingAccepted(template_name, payload_format) - .format::<150>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) + .subscribe::<1>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + Topic::RegisterThingAccepted(template_name, payload_format) + .format::<150>()? + .as_str(), + ) + .build()]) + .build(), + ) .await?; - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::RegisterThing(template_name, payload_format) - .format::<69>()? - .as_str(), - payload, - properties: embedded_mqtt::Properties::Slice(&[]), - }) + mqtt.publish( + Publish::builder() + .topic_name( + Topic::RegisterThing(template_name, payload_format) + .format::<69>()? + .as_str(), + ) + .payload(payload) + .build(), + ) .await?; drop(message); @@ -254,8 +264,11 @@ impl FleetProvisioner { match Topic::from_str(message.topic_name()) { Some(Topic::RegisterThingAccepted(_, format)) => { - let response = - Self::deserialize::, SUBS>(format, &mut message)?; + let response = Self::deserialize::< + RegisterThingResponse<'_, C>, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; Ok(response.device_configuration) } @@ -277,29 +290,32 @@ impl FleetProvisioner { mqtt: &'b embedded_mqtt::MqttClient<'a, M, SUBS>, csr: Option<&str>, payload_format: PayloadFormat, - ) -> Result, Error> { + ) -> Result, Error> + where + BitsImpl<{ SUBS }>: Bits, + { if let Some(csr) = csr { let subscription = mqtt - .subscribe(Subscribe::new(&[ - SubscribeTopic { - topic_path: Topic::CreateCertificateFromCsrRejected(payload_format) - .format::<47>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - SubscribeTopic { - topic_path: Topic::CreateCertificateFromCsrAccepted(payload_format) - .format::<47>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - ])) + .subscribe( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::CreateCertificateFromCsrRejected(payload_format) + .format::<47>()? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::CreateCertificateFromCsrAccepted(payload_format) + .format::<47>()? 
+ .as_str(), + ) + .build(), + ]) + .build(), + ) .await?; let request = CreateCertificateFromCsrRequest { @@ -326,65 +342,66 @@ impl FleetProvisioner { csr.len() + 32, ); - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::CreateCertificateFromCsr(payload_format) - .format::<40>()? - .as_str(), - payload, - properties: embedded_mqtt::Properties::Slice(&[]), - }) + mqtt.publish( + Publish::builder() + .topic_name( + Topic::CreateCertificateFromCsr(payload_format) + .format::<40>()? + .as_str(), + ) + .payload(payload) + .build(), + ) .await?; Ok(subscription) } else { let subscription = mqtt - .subscribe(Subscribe::new(&[ - SubscribeTopic { - topic_path: Topic::CreateKeysAndCertificateAccepted(payload_format) - .format::<38>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - SubscribeTopic { - topic_path: Topic::CreateKeysAndCertificateRejected(payload_format) - .format::<38>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - ])) + .subscribe( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::CreateKeysAndCertificateAccepted(payload_format) + .format::<38>()? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::CreateKeysAndCertificateRejected(payload_format) + .format::<38>()? + .as_str(), + ) + .build(), + ]) + .build(), + ) .await?; - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::CreateKeysAndCertificate(payload_format) - .format::<29>()? - .as_str(), - payload: b"", - properties: embedded_mqtt::Properties::Slice(&[]), - }) + mqtt.publish( + Publish::builder() + .topic_name( + Topic::CreateKeysAndCertificate(payload_format) + .format::<29>()? + .as_str(), + ) + .payload(b"") + .build(), + ) .await?; Ok(subscription) } } - fn deserialize<'a, R: Deserialize<'a>, const SUBS: usize>( + fn deserialize<'a, R: Deserialize<'a>, B: BufferProvider, const SUBS: usize>( payload_format: PayloadFormat, - message: &'a mut Message<'_, SUBS>, - ) -> Result { + message: &'a mut Message<'_, B, SUBS>, + ) -> Result + where + BitsImpl<{ SUBS }>: Bits, + { trace!( "Accepted Topic {:?}. Payload len: {:?}", payload_format, @@ -398,10 +415,13 @@ impl FleetProvisioner { }) } - fn handle_error( + fn handle_error( format: PayloadFormat, - mut message: Message<'_, SUBS>, - ) -> Result<(), Error> { + mut message: Message<'_, B, SUBS>, + ) -> Result<(), Error> + where + BitsImpl<{ SUBS }>: Bits, + { error!(">> {:?}", message.topic_name()); let response = match format { diff --git a/src/shadows/README.md b/src/shadows/README.md index a1ec0b0..9ea3132 100644 --- a/src/shadows/README.md +++ b/src/shadows/README.md @@ -8,4 +8,4 @@ You can find an example of how to use this crate for iot shadow states in the `t pfx identity files can be created from a set of device certificate and private key using OpenSSL as: `openssl pkcs12 -export -out identity.pfx -inkey private.pem.key -in certificate.pem.crt -certfile root-ca.pem` -The example functions as a CI integration test, that is run against `Blackbirds` integration account on every PR. This test will run through a statemachine of shadow delete, updates and gets from both device & cloud side with assertions in between. 
+The example functions as a CI integration test, that is run against Factbirds integration account on every PR. This test will run through a statemachine of shadow delete, updates and gets from both device & cloud side with assertions in between. diff --git a/src/shadows/dao.rs b/src/shadows/dao.rs index 875c5c2..1435cd6 100644 --- a/src/shadows/dao.rs +++ b/src/shadows/dao.rs @@ -2,53 +2,23 @@ use serde::{de::DeserializeOwned, Serialize}; use super::{Error, ShadowState}; -pub trait ShadowDAO { - fn read(&mut self) -> Result; - fn write(&mut self, state: &S) -> Result<(), Error>; +pub trait ShadowDAO { + async fn read(&mut self) -> Result; + async fn write(&mut self, state: &S) -> Result<(), Error>; } -impl ShadowDAO for () { - fn read(&mut self) -> Result { - Err(Error::NoPersistance) - } - - fn write(&mut self, _state: &S) -> Result<(), Error> { - Err(Error::NoPersistance) - } -} - -pub struct EmbeddedStorageDAO(T); - -impl From for EmbeddedStorageDAO -where - T: embedded_storage::Storage, -{ - fn from(v: T) -> Self { - Self::new(v) - } -} - -impl EmbeddedStorageDAO -where - T: embedded_storage::Storage, -{ - pub fn new(storage: T) -> Self { - Self(storage) - } -} - -const U32_SIZE: usize = core::mem::size_of::(); +const U32_SIZE: usize = 4; -impl ShadowDAO for EmbeddedStorageDAO +impl ShadowDAO for T where S: ShadowState + DeserializeOwned, - T: embedded_storage::Storage, + T: embedded_storage_async::nor_flash::NorFlash, [(); S::MAX_PAYLOAD_SIZE + U32_SIZE]:, { - fn read(&mut self) -> Result { + async fn read(&mut self) -> Result { let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; - self.0.read(OFFSET, buf).map_err(|_| Error::DaoRead)?; + self.read(0, buf).await.map_err(|_| Error::DaoRead)?; match buf[..U32_SIZE].try_into() { Ok(len_bytes) => { @@ -68,8 +38,8 @@ where } } - fn write(&mut self, state: &S) -> Result<(), Error> { - assert!(S::MAX_PAYLOAD_SIZE <= self.0.capacity() - OFFSET as usize); + async fn write(&mut self, state: &S) -> Result<(), Error> { + assert!(S::MAX_PAYLOAD_SIZE <= self.capacity()); let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; @@ -88,11 +58,11 @@ where buf[..U32_SIZE].copy_from_slice(&(len as u32).to_le_bytes()); - self.0 - .write(OFFSET, &buf[..len + U32_SIZE]) + self.write(0, &buf[..len + U32_SIZE]) + .await .map_err(|_| Error::DaoWrite)?; - debug!("Wrote {} bytes to DAO @ {}", len + U32_SIZE, OFFSET); + debug!("Wrote {} bytes to DAO", len + U32_SIZE); Ok(()) } @@ -128,7 +98,7 @@ where T: std::io::Write + std::io::Read, [(); S::MAX_PAYLOAD_SIZE]:, { - fn read(&mut self) -> Result { + async fn read(&mut self) -> Result { let bytes = &mut [0u8; S::MAX_PAYLOAD_SIZE]; self.0.read(bytes).map_err(|_| Error::DaoRead)?; @@ -136,7 +106,7 @@ where Ok(shadow) } - fn write(&mut self, state: &S) -> Result<(), Error> { + async fn write(&mut self, state: &S) -> Result<(), Error> { let bytes = serde_json_core::to_vec::<_, { S::MAX_PAYLOAD_SIZE }>(state) .map_err(|_| Error::Overflow)?; diff --git a/src/shadows/data_types.rs b/src/shadows/data_types.rs index 7a25453..5449726 100644 --- a/src/shadows/data_types.rs +++ b/src/shadows/data_types.rs @@ -34,16 +34,22 @@ impl From for Patch { #[derive(Debug, Serialize, Deserialize)] pub struct State { + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "desired")] pub desired: Option, + + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "reported")] pub reported: Option, } #[derive(Debug, Serialize, Deserialize)] pub struct DeltaState { + #[serde(skip_serializing_if = "Option::is_none")] 
#[serde(rename = "desired")] pub desired: Option, + + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "reported")] pub reported: Option, #[serde(rename = "delta")] @@ -172,7 +178,7 @@ mod tests { exp_map .0 .insert( - heapless::String::from("1"), + heapless::String::try_from("1").unwrap(), Patch::Set(Test { field: true }), ) .unwrap(); @@ -189,7 +195,7 @@ mod tests { exp_map .0 .insert( - heapless::String::from("1"), + heapless::String::try_from("1").unwrap(), Patch::Set(Test { field: true }), ) .unwrap(); @@ -215,7 +221,7 @@ mod tests { let mut exp_map = TestMap(heapless::LinearMap::default()); exp_map .0 - .insert(heapless::String::from("1"), Patch::Unset) + .insert(heapless::String::try_from("1").unwrap(), Patch::Unset) .unwrap(); let (patch, _) = serde_json_core::from_str::(payload).unwrap(); diff --git a/src/shadows/error.rs b/src/shadows/error.rs index f08ff4d..54bd0b1 100644 --- a/src/shadows/error.rs +++ b/src/shadows/error.rs @@ -1,9 +1,4 @@ use core::convert::TryFrom; -use core::fmt::Display; -use core::str::FromStr; - -use heapless::String; -use mqttrust::MqttError; use super::data_types::ErrorResponse; @@ -11,21 +6,15 @@ use super::data_types::ErrorResponse; #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Error { Overflow, - NoPersistance, + NoPersistence, DaoRead, DaoWrite, InvalidPayload, WrongShadowName, - Mqtt(MqttError), + MqttError(embedded_mqtt::Error), ShadowError(ShadowError), } -impl From for Error { - fn from(e: MqttError) -> Self { - Self::Mqtt(e) - } -} - impl From for Error { fn from(e: ShadowError) -> Self { Self::ShadowError(e) @@ -47,7 +36,6 @@ pub enum ShadowError { Unauthorized, Forbidden, NotFound, - NoNamedShadow(String<64>), VersionConflict, PayloadTooLarge, UnsupportedEncoding, @@ -70,7 +58,7 @@ impl ShadowError { ShadowError::Unauthorized => 401, ShadowError::Forbidden => 403, - ShadowError::NotFound | ShadowError::NoNamedShadow(_) => 404, + ShadowError::NotFound => 404, ShadowError::VersionConflict => 409, ShadowError::PayloadTooLarge => 413, ShadowError::UnsupportedEncoding => 415, @@ -85,7 +73,7 @@ impl<'a> TryFrom> for ShadowError { fn try_from(e: ErrorResponse<'a>) -> Result { Ok(match e.code { - 400 | 404 => Self::from_str(e.message)?, + 400 | 404 => ShadowError::NotFound, 401 => ShadowError::Unauthorized, 403 => ShadowError::Forbidden, 409 => ShadowError::VersionConflict, @@ -97,67 +85,3 @@ impl<'a> TryFrom> for ShadowError { }) } } - -// impl Display for ShadowError { -// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { -// match self { -// Self::InvalidJson => write!(f, "Invalid JSON"), -// Self::MissingState => write!(f, "Missing required node: state"), -// Self::MalformedState => write!(f, "State node must be an object"), -// Self::MalformedDesired => write!(f, "Desired node must be an object"), -// Self::MalformedReported => write!(f, "Reported node must be an object"), -// Self::InvalidVersion => write!(f, "Invalid version"), -// Self::InvalidClientToken => write!(f, "Invalid clientToken"), -// Self::JsonTooDeep => { -// write!(f, "JSON contains too many levels of nesting; maximum is 6") -// } -// Self::InvalidStateNode => write!(f, "State contains an invalid node"), -// Self::Unauthorized => write!(f, "Unauthorized"), -// Self::Forbidden => write!(f, "Forbidden"), -// Self::NotFound => write!(f, "Thing not found"), -// Self::NoNamedShadow(shadow_name) => { -// write!(f, "No shadow exists with name: {}", shadow_name) -// } -// Self::VersionConflict => write!(f, "Version conflict"), -// 
Self::PayloadTooLarge => write!(f, "The payload exceeds the maximum size allowed"), -// Self::UnsupportedEncoding => write!( -// f, -// "Unsupported documented encoding; supported encoding is UTF-8" -// ), -// Self::TooManyRequests => write!(f, "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection"), -// Self::InternalServerError => write!(f, "Internal service failure"), -// } -// } -// } - -// // TODO: This seems like an extremely brittle way of doing this??! -// impl FromStr for ShadowError { -// type Err = (); - -// fn from_str(s: &str) -> Result { -// Ok(match s.trim() { -// "Invalid JSON" => Self::InvalidJson, -// "Missing required node: state" => Self::MissingState, -// "State node must be an object" => Self::MalformedState, -// "Desired node must be an object" => Self::MalformedDesired, -// "Reported node must be an object" => Self::MalformedReported, -// "Invalid version" => Self::InvalidVersion, -// "Invalid clientToken" => Self::InvalidClientToken, -// "JSON contains too many levels of nesting; maximum is 6" => Self::JsonTooDeep, -// "State contains an invalid node" => Self::InvalidStateNode, -// "Unauthorized" => Self::Unauthorized, -// "Forbidden" => Self::Forbidden, -// "Thing not found" => Self::NotFound, -// // TODO: -// "No shadow exists with name: " => Self::NoNamedShadow(String::new()), -// "Version conflict" => Self::VersionConflict, -// "The payload exceeds the maximum size allowed" => Self::PayloadTooLarge, -// "Unsupported documented encoding; supported encoding is UTF-8" => { -// Self::UnsupportedEncoding -// } -// "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection" => Self::TooManyRequests, -// "Internal service failure" => Self::InternalServerError, -// _ => return Err(()), -// }) -// } -// } diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index 31ce599..2825fd5 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -4,18 +4,23 @@ mod error; mod shadow_diff; pub mod topics; -use core::marker::PhantomData; - -use mqttrust::{Mqtt, QoS}; +use core::{marker::PhantomData, ops::DerefMut}; +use bitmaps::{Bits, BitsImpl}; pub use data_types::Patch; +use embassy_sync::{ + blocking_mutex::raw::{NoopRawMutex, RawMutex}, + mutex::Mutex, +}; +use embedded_mqtt::{DeferredPayload, Publish, Subscribe, SubscribeTopic, ToPayload}; pub use error::Error; -use serde::de::DeserializeOwned; +use futures::StreamExt; +use serde::Serialize; pub use shadow_derive as derive; pub use shadow_diff::ShadowPatch; -use data_types::{AcceptedResponse, DeltaResponse, ErrorResponse}; -use topics::{Direction, Subscribe, Topic, Unsubscribe}; +use data_types::{AcceptedResponse, DeltaResponse, DeltaState, ErrorResponse}; +use topics::Topic; use self::dao::ShadowDAO; @@ -23,315 +28,441 @@ const MAX_TOPIC_LEN: usize = 128; const PARTIAL_REQUEST_OVERHEAD: usize = 64; const CLASSIC_SHADOW: &str = "Classic"; -pub trait ShadowState: ShadowPatch { +pub trait ShadowState: ShadowPatch + Default { const NAME: Option<&'static str>; const MAX_PAYLOAD_SIZE: usize = 512; } -struct ShadowHandler<'a, M: Mqtt, S: ShadowState> +struct ShadowHandler<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - mqtt: &'a M, + mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, + subscription: Mutex>>, _shadow: PhantomData, } -impl<'a, M: Mqtt, S: ShadowState> ShadowHandler<'a, M, 
S> +impl<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> ShadowHandler<'a, 'm, M, S, SUBS> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - Subscribe::<7>::new() - .topic(Topic::GetAccepted, QoS::AtLeastOnce) - .topic(Topic::GetRejected, QoS::AtLeastOnce) - .topic(Topic::DeleteAccepted, QoS::AtLeastOnce) - .topic(Topic::DeleteRejected, QoS::AtLeastOnce) - .topic(Topic::UpdateAccepted, QoS::AtLeastOnce) - .topic(Topic::UpdateRejected, QoS::AtLeastOnce) - .topic(Topic::UpdateDelta, QoS::AtLeastOnce) - .send(self.mqtt, S::NAME)?; + async fn handle_delta(&self) -> Result, Error> { + let mut sub_ref = self.subscription.lock().await; + + let delta_subscription = match sub_ref.deref_mut() { + Some(sub) => sub, + None => { + self.mqtt.wait_connected().await; + + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + topics::Topic::UpdateDelta + .format::<64>(self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build()]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + sub_ref.insert(sub) + } + }; - Ok(()) - } + let delta_message = delta_subscription + .next() + .await + .ok_or(Error::InvalidPayload)?; - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - Unsubscribe::<7>::new() - .topic(Topic::GetAccepted) - .topic(Topic::GetRejected) - .topic(Topic::DeleteAccepted) - .topic(Topic::DeleteRejected) - .topic(Topic::UpdateAccepted) - .topic(Topic::UpdateRejected) - .topic(Topic::UpdateDelta) - .send(self.mqtt, S::NAME)?; + // Update the device's state to match the desired state in the + // message body. + debug!( + "[{:?}] Received shadow delta event.", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); - Ok(()) - } + let (delta, _) = + serde_json_core::from_slice::>(delta_message.payload()) + .map_err(|_| Error::InvalidPayload)?; - /// Helper function to check whether a topic name is relevant for this - /// particular shadow. - pub fn should_handle_topic(&mut self, topic: &str) -> bool { - if let Some((_, thing_name, shadow_name)) = Topic::from_str(topic) { - return thing_name == self.mqtt.client_id() && shadow_name == S::NAME; + if let Some(client) = delta.client_token { + if client.eq(self.mqtt.client_id()) { + return Ok(None); + } } - false + + Ok(delta.state) } /// Internal helper function for applying a delta state to the actual shadow /// state, and update the cloud shadow. - fn change_shadow_value( - &mut self, - state: &mut S, - delta: Option, - update_desired: Option, - ) -> Result<(), Error> { - if let Some(ref delta) = delta { - state.apply_patch(delta.clone()); - } - + async fn report(&self, reported: &R) -> Result<(), Error> { debug!( - "[{:?}] Updating reported shadow value. 
Update_desired: {:?}", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - update_desired + "[{:?}] Updating reported shadow value.", + S::NAME.unwrap_or(CLASSIC_SHADOW), ); - if let Some(update_desired) = update_desired { - let desired = if update_desired { Some(&state) } else { None }; - - let request = data_types::Request { - state: data_types::State { - reported: Some(&state), - desired, - }, - client_token: None, - version: None, - }; - - let payload = serde_json_core::to_vec::< - _, - { S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD }, - >(&request) - .map_err(|_| Error::Overflow)?; - - let update_topic = - Topic::Update.format::(self.mqtt.client_id(), S::NAME)?; - self.mqtt - .publish(update_topic.as_str(), &payload, QoS::AtLeastOnce)?; - } + let request = data_types::Request { + state: data_types::State { + reported: Some(reported), + desired: None, + }, + client_token: Some(self.mqtt.client_id()), + version: None, + }; - Ok(()) + let payload = DeferredPayload::new( + |buf| { + serde_json_core::to_slice(&request, buf) + .map_err(|_| embedded_mqtt::EncodingError::BufferSize) + }, + S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD, + ); + + let mut sub = self.publish_and_subscribe(Topic::Update, payload).await?; + + //*** WAIT RESPONSE ***/ + debug!("Wait for Accepted or Rejected"); + loop { + let message = sub.next().await.ok_or(Error::InvalidPayload)?; + + // Check if topic is GetAccepted + match Topic::from_str(message.topic_name()) { + Some((Topic::UpdateAccepted, _, _)) => { + // Check client token + let (response, _) = serde_json_core::from_slice::< + AcceptedResponse, + >(message.payload()) + .map_err(|_| Error::InvalidPayload)?; + + if response.client_token != Some(self.mqtt.client_id()) { + error!( + "Unexpected client token received: {}, expected: {}", + response.client_token.unwrap_or("None"), + self.mqtt.client_id() + ); + continue; + } + + return Ok(()); + } + Some((Topic::UpdateRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + return Err(Error::WrongShadowName); + } + } + } } /// Initiate a `GetShadow` request, updating the local state from the cloud. 
- pub fn get_shadow(&self) -> Result<(), Error> { - let get_topic = Topic::Get.format::(self.mqtt.client_id(), S::NAME)?; - self.mqtt - .publish(get_topic.as_str(), b"", QoS::AtLeastOnce)?; - Ok(()) + async fn get_shadow(&self) -> Result, Error> { + //Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let mut sub = self.publish_and_subscribe(Topic::Get, b"").await?; + + let get_message = sub.next().await.ok_or(Error::InvalidPayload)?; + + //Check if topic is GetAccepted + //Deserialize message + //Persist shadow and return new shadow + match Topic::from_str(get_message.topic_name()) { + Some((Topic::GetAccepted, _, _)) => { + let (response, _) = serde_json_core::from_slice::>( + get_message.payload(), + ) + .map_err(|_| Error::InvalidPayload)?; + + Ok(response.state) + } + Some((Topic::GetRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(get_message.payload()) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.code == 404 { + debug!( + "[{:?}] Thing has no shadow document. Creating with defaults...", + S::NAME.unwrap_or(CLASSIC_SHADOW) + ); + self.create_shadow().await?; + } + + Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )) + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + Err(Error::WrongShadowName) + } + } + } + + pub async fn delete_shadow(&mut self) -> Result<(), Error> { + // Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let mut sub = self + .publish_and_subscribe(topics::Topic::Delete, b"") + .await?; + + let message = sub.next().await.ok_or(Error::InvalidPayload)?; + + // Check if topic is DeleteAccepted + match Topic::from_str(message.topic_name()) { + Some((Topic::DeleteAccepted, _, _)) => Ok(()), + Some((Topic::DeleteRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )) + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + Err(Error::WrongShadowName) + } + } + } + + pub async fn create_shadow(&self) -> Result, Error> { + debug!( + "[{:?}] Creating initial shadow value.", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + + let state = S::default(); + + let request = data_types::Request { + state: data_types::State { + reported: Some(&state), + desired: Some(&state), + }, + client_token: Some(self.mqtt.client_id()), + version: None, + }; + + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let payload = serde_json_core::to_vec::< + _, + { S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD }, + >(&request) + .map_err(|_| Error::Overflow)?; + + let mut sub = self + .publish_and_subscribe(Topic::Update, payload.as_slice()) + .await?; + loop { + let message = sub.next().await.ok_or(Error::InvalidPayload)?; + + match Topic::from_str(message.topic_name()) { + Some((Topic::UpdateAccepted, _, _)) => { + let (response, _) = serde_json_core::from_slice::< + AcceptedResponse, + >(message.payload()) + .map_err(|_| Error::InvalidPayload)?; + + if response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Ok(response.state); + } + Some((Topic::UpdateRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| 
Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + return Err(Error::WrongShadowName); + } + } + } } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - let delete_topic = Topic::Delete.format::(self.mqtt.client_id(), S::NAME)?; + ///This function will subscribe to accepted and rejected topics and then do a publish. + ///It will only return when something is accepted or rejected + ///Topic is the topic you want to publish to + ///The function will automatically subscribe to the accepted and rejected topic related to the publish topic + async fn publish_and_subscribe( + &self, + topic: topics::Topic, + payload: impl ToPayload, + ) -> Result, Error> { + let (accepted, rejected) = match topic { + Topic::Get => (Topic::GetAccepted, Topic::GetRejected), + Topic::Update => (Topic::UpdateAccepted, Topic::UpdateRejected), + Topic::Delete => (Topic::DeleteAccepted, Topic::DeleteRejected), + _ => return Err(Error::ShadowError(error::ShadowError::Forbidden)), + }; + + //*** SUBSCRIBE ***/ + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + accepted + .format::<64>(self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + rejected + .format::<64>(self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + //*** PUBLISH REQUEST ***/ + let topic_name = topic.format::(self.mqtt.client_id(), S::NAME)?; self.mqtt - .publish(delete_topic.as_str(), b"", QoS::AtLeastOnce)?; - Ok(()) + .publish( + Publish::builder() + .topic_name(topic_name.as_str()) + .payload(payload) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + Ok(sub) } } -pub struct PersistedShadow<'a, S: ShadowState + DeserializeOwned, M: Mqtt, D: ShadowDAO> +pub struct PersistedShadow<'a, 'm, S: ShadowState, M: RawMutex, D: ShadowDAO, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - handler: ShadowHandler<'a, M, S>, - pub(crate) dao: D, + handler: ShadowHandler<'a, 'm, M, S, SUBS>, + pub(crate) dao: Mutex, } -impl<'a, S, M, D> PersistedShadow<'a, S, M, D> +impl<'a, 'm, S, M, D, const SUBS: usize> PersistedShadow<'a, 'm, S, M, D, SUBS> where - S: ShadowState + DeserializeOwned, - M: Mqtt, + BitsImpl<{ SUBS }>: Bits, + S: ShadowState + Default, + M: RawMutex, D: ShadowDAO, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { /// Instantiate a new shadow that will be automatically persisted to NVM /// based on the passed `DAO`. 
- pub fn new( - initial_state: S, - mqtt: &'a M, - mut dao: D, - auto_subscribe: bool, - ) -> Result { - if dao.read().is_err() { - dao.write(&initial_state)?; - } - + pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, dao: D) -> Self { let handler = ShadowHandler { mqtt, + subscription: Mutex::new(None), _shadow: PhantomData, }; - if auto_subscribe { - handler.subscribe()?; - } - Ok(Self { handler, dao }) - } - - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - self.handler.subscribe() - } - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - self.handler.unsubscribe() - } - - /// Helper function to check whether a topic name is relevant for this - /// particular shadow. - pub fn should_handle_topic(&mut self, topic: &str) -> bool { - self.handler.should_handle_topic(topic) + Self { + handler, + dao: Mutex::new(dao), + } } - /// Handle incomming publish messages from the cloud on any topics relevant - /// for this particular shadow. + /// Wait delta will subscribe if not already to Updatedelta and wait for changes /// - /// This function needs to be fed all relevant incoming MQTT payloads in - /// order for the shadow manager to work. - #[must_use] - pub fn handle_message( - &mut self, - topic: &str, - payload: &[u8], - ) -> Result<(S, Option), Error> { - let (topic, thing_name, shadow_name) = - Topic::from_str(topic).ok_or(Error::WrongShadowName)?; - - assert_eq!(thing_name, self.handler.mqtt.client_id()); - assert_eq!(topic.direction(), Direction::Incoming); - - if shadow_name != S::NAME { - return Err(Error::WrongShadowName); - } - - let mut state = self.dao.read()?; - - let delta = match topic { - Topic::GetAccepted => { - // The actions necessary to process the state document in the - // message body. - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(response, _)| { - if let Some(_) = response.state.delta { - debug!( - "[{:?}] Received delta state", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.handler.change_shadow_value( - &mut state, - response.state.delta.clone(), - Some(false), - )?; - } else if let Some(_) = response.state.reported { - self.handler.change_shadow_value( - &mut state, - response.state.reported, - None, - )?; - } - Ok(response.state.delta) - })? + pub async fn wait_delta(&self) -> Result<(S, Option), Error> { + let mut state = match self.dao.lock().await.read().await { + Ok(state) => state, + Err(_) => { + error!("Could not read state from flash writing default"); + self.dao.lock().await.write(&S::default()).await?; + S::default() } - Topic::GetRejected | Topic::UpdateRejected => { - // Respond to the error message in the message body. - if let Ok((error, _)) = serde_json_core::from_slice::(payload) { - if error.code == 404 && matches!(topic, Topic::GetRejected) { - debug!( - "[{:?}] Thing has no shadow document. Creating with defaults...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.report_shadow()?; - } else { - error!( - "{:?} request was rejected. code: {:?} message:'{:?}'", - if matches!(topic, Topic::GetRejected) { - "Get" - } else { - "Update" - }, - error.code, - error.message - ); - } - } - None - } - Topic::UpdateDelta => { - // Update the device's state to match the desired state in the - // message body. 
- debug!( - "[{:?}] Received shadow delta event.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(delta, _)| { - if let Some(_) = delta.state { - debug!( - "[{:?}] Delta reports new desired value. Changing local value...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - } - self.handler.change_shadow_value( - &mut state, - delta.state.clone(), - Some(false), - )?; - Ok(delta.state) - })? - } - Topic::UpdateAccepted => { - // Confirm the updated data in the message body matches the - // device state. - - debug!( - "[{:?}] Finished updating reported shadow value.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - - None - } - _ => None, }; + let delta = self.handler.handle_delta().await?; + // Something has changed as part of handling a message. Persist it // to NVM storage. - if delta.is_some() { - self.dao.write(&state)?; + if let Some(delta) = &delta { + debug!( + "[{:?}] Delta reports new desired value. Changing local value...", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + + state.apply_patch(delta.clone()); + + self.handler.report(&state).await?; + + self.dao.lock().await.write(&state).await?; } Ok((state, delta)) } /// Get an immutable reference to the internal local state. - pub fn try_get(&mut self) -> Result { - self.dao.read() + pub async fn try_get(&mut self) -> Result { + self.dao.lock().await.read().await } /// Initiate a `GetShadow` request, updating the local state from the cloud. - pub fn get_shadow(&self) -> Result<(), Error> { - self.handler.get_shadow() - } + pub async fn get_shadow(&self) -> Result { + let delta_state = self.handler.get_shadow().await?; + + debug!("Persisting new state after get shadow request"); + let mut state = self.dao.lock().await.read().await.unwrap_or_default(); + if let Some(desired) = delta_state.desired { + state.apply_patch(desired); + self.dao.lock().await.write(&state).await?; + if delta_state.delta.is_some() { + self.handler.report(&state).await?; + } + } - /// Initiate an `UpdateShadow` request, reporting the local state to the cloud. - pub fn report_shadow(&mut self) -> Result<(), Error> { - let mut state = self.dao.read()?; - self.handler - .change_shadow_value(&mut state, None, Some(false))?; - Ok(()) + Ok(state) } /// Update the state of the shadow. @@ -340,179 +471,75 @@ where /// and depending on whether the state update is rejected or accepted, it /// will automatically update the local version after response /// - /// The returned `bool` from the update closure will determine wether the + /// The returned `bool` from the update closure will determine whether the /// update is persisted using the `DAO`, or just updated in the cloud. This /// can be handy for activity or status field updates that are not relevant - /// to store persistant on the device, but are required to be part of the + /// to store persistent on the device, but are required to be part of the /// same cloud shadow. 
- pub fn update bool>(&mut self, f: F) -> Result<(), Error> { + pub async fn update(&self, f: F) -> Result<(), Error> { let mut desired = S::PatchState::default(); - let mut state = self.dao.read()?; - let should_persist = f(&state, &mut desired); + let mut state = self.dao.lock().await.read().await?; + f(&state, &mut desired); - self.handler - .change_shadow_value(&mut state, Some(desired), Some(false))?; + self.handler.report(&desired).await?; - if should_persist { - self.dao.write(&state)?; - } + state.apply_patch(desired); + + // Always persist + self.dao.lock().await.write(&state).await?; Ok(()) } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - self.handler.delete_shadow() + pub async fn delete_shadow(&mut self) -> Result<(), Error> { + self.handler.delete_shadow().await?; + self.dao.lock().await.write(&S::default()).await?; + Ok(()) } } -pub struct Shadow<'a, S: ShadowState, M: Mqtt> +pub struct Shadow<'a, 'm, S: ShadowState, M: RawMutex, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { state: S, - handler: ShadowHandler<'a, M, S>, + handler: ShadowHandler<'a, 'm, M, S, SUBS>, } -impl<'a, S, M> Shadow<'a, S, M> +impl<'a, 'm, S, M, const SUBS: usize> Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState, - M: Mqtt, + M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { /// Instantiate a new non-persisted shadow - pub fn new(state: S, mqtt: &'a M, auto_subscribe: bool) -> Result { + pub fn new(state: S, mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>) -> Self { let handler = ShadowHandler { mqtt, + subscription: Mutex::new(None), _shadow: PhantomData, }; - if auto_subscribe { - handler.subscribe()?; - } - Ok(Self { handler, state }) - } - - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - self.handler.subscribe() - } - - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - self.handler.unsubscribe() + Self { handler, state } } - /// Handle incomming publish messages from the cloud on any topics relevant + /// Handle incoming publish messages from the cloud on any topics relevant /// for this particular shadow. /// /// This function needs to be fed all relevant incoming MQTT payloads in /// order for the shadow manager to work. - #[must_use] - pub fn handle_message( - &mut self, - topic: &str, - payload: &[u8], - ) -> Result<(&S, Option), Error> { - let (topic, thing_name, shadow_name) = - Topic::from_str(topic).ok_or(Error::WrongShadowName)?; - - assert_eq!(thing_name, self.handler.mqtt.client_id()); - assert_eq!(topic.direction(), Direction::Incoming); - - if shadow_name != S::NAME { - return Err(Error::WrongShadowName); + pub async fn wait_delta(&mut self) -> Result<(&S, Option), Error> { + let delta = self.handler.handle_delta().await?; + if let Some(delta) = &delta { + debug!( + "[{:?}] Delta reports new desired value. Changing local value...", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + self.handler.report(delta).await?; } - let delta = match topic { - Topic::GetAccepted => { - // The actions necessary to process the state document in the - // message body. 
- serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(response, _)| { - if let Some(_) = response.state.delta { - debug!( - "[{:?}] Received delta state", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.handler.change_shadow_value( - &mut self.state, - response.state.delta.clone(), - Some(false), - )?; - } else if let Some(_) = response.state.reported { - self.handler.change_shadow_value( - &mut self.state, - response.state.reported, - None, - )?; - } - Ok(response.state.delta) - })? - } - Topic::GetRejected | Topic::UpdateRejected => { - // Respond to the error message in the message body. - if let Ok((error, _)) = serde_json_core::from_slice::(payload) { - if error.code == 404 && matches!(topic, Topic::GetRejected) { - debug!( - "[{:?}] Thing has no shadow document. Creating with defaults...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.report_shadow()?; - } else { - error!( - "{:?} request was rejected. code: {:?} message:'{:?}'", - if matches!(topic, Topic::GetRejected) { - "Get" - } else { - "Update" - }, - error.code, - error.message - ); - } - } - None - } - Topic::UpdateDelta => { - // Update the device's state to match the desired state in the - // message body. - debug!( - "[{:?}] Received shadow delta event.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(delta, _)| { - if let Some(_) = delta.state { - debug!( - "[{:?}] Delta reports new desired value. Changing local value...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - } - self.handler.change_shadow_value( - &mut self.state, - delta.state.clone(), - Some(false), - )?; - Ok(delta.state) - })? - } - Topic::UpdateAccepted => { - // Confirm the updated data in the message body matches the - // device state. - - debug!( - "[{:?}] Finished updating reported shadow value.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - - None - } - _ => None, - }; - - Ok((self.get(), delta)) + Ok((&self.state, delta)) } /// Get an immutable reference to the internal local state. @@ -520,10 +547,9 @@ where &self.state } - /// Initiate an `UpdateShadow` request, reporting the local state to the cloud. - pub fn report_shadow(&mut self) -> Result<(), Error> { - self.handler - .change_shadow_value(&mut self.state, None, Some(false))?; + /// Report the state of the shadow. + pub async fn report(&mut self) -> Result<(), Error> { + self.handler.report(&self.state).await?; Ok(()) } @@ -532,47 +558,59 @@ where /// This function will update the desired state of the shadow in the cloud, /// and depending on whether the state update is rejected or accepted, it /// will automatically update the local version after response - pub fn update(&mut self, f: F) -> Result<(), Error> { + pub async fn update(&mut self, f: F) -> Result<(), Error> { let mut desired = S::PatchState::default(); f(&self.state, &mut desired); - self.handler - .change_shadow_value(&mut self.state, Some(desired), Some(false))?; + self.handler.report(&desired).await?; + + self.state.apply_patch(desired); Ok(()) } /// Initiate a `GetShadow` request, updating the local state from the cloud. 
- pub fn get_shadow(&self) -> Result<(), Error> { - self.handler.get_shadow() + pub async fn get_shadow(&mut self) -> Result<&S, Error> { + let delta_state = self.handler.get_shadow().await?; + + if let Some(desired) = delta_state.desired { + self.state.apply_patch(desired); + if delta_state.delta.is_some() { + self.handler.report(&self.state).await?; + } + } + + Ok(&self.state) } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - self.handler.delete_shadow() + pub async fn delete_shadow(&mut self) -> Result<(), Error> { + self.handler.delete_shadow().await } } -impl<'a, S, M> core::fmt::Debug for Shadow<'a, S, M> +impl<'a, 'm, S, M, const SUBS: usize> core::fmt::Debug for Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState + core::fmt::Debug, - M: Mqtt, + M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!( f, "[{:?}] = {:?}", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), + S::NAME.unwrap_or(CLASSIC_SHADOW), self.get() ) } } #[cfg(feature = "defmt")] -impl<'a, S, M> defmt::Format for Shadow<'a, S, M> +impl<'a, 'm, S, M, const SUBS: usize> defmt::Format for Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState + defmt::Format, - M: Mqtt, + M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { fn format(&self, fmt: defmt::Formatter) { @@ -585,17 +623,6 @@ where } } -impl<'a, S, M> Drop for Shadow<'a, S, M> -where - S: ShadowState, - M: Mqtt, - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, -{ - fn drop(&mut self) { - self.unsubscribe().ok(); - } -} - // #[cfg(test)] // mod tests { // use super::*; diff --git a/src/shadows/topics.rs b/src/shadows/topics.rs index c73e35a..34642d0 100644 --- a/src/shadows/topics.rs +++ b/src/shadows/topics.rs @@ -2,8 +2,8 @@ use core::fmt::Write; +use embedded_mqtt::QoS; use heapless::String; -use mqttrust::{Mqtt, QoS, SubscribeTopic}; use crate::jobs::MAX_THING_NAME_LEN; @@ -33,6 +33,7 @@ pub enum Topic { UpdateRejected, DeleteAccepted, DeleteRejected, + Any, } impl Topic { @@ -188,6 +189,14 @@ impl Topic { name_prefix, shadow_name )), + Self::Any => topic_path.write_fmt(format_args!( + "{}/{}/{}{}{}/#", + Self::PREFIX, + thing_name, + Self::SHADOW, + name_prefix, + shadow_name + )), } .map_err(|_| Error::Overflow)?; @@ -233,29 +242,6 @@ impl Subscribe { .map(|(topic, qos)| Ok((Topic::from(*topic).format(thing_name, shadow_name)?, *qos))) .collect() } - - pub fn send(self, mqtt: &M, shadow_name: Option<&'static str>) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics(mqtt.client_id(), shadow_name)?; - - let topics: heapless::Vec<_, N> = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - qos: *qos, - }) - .collect(); - - debug!("Subscribing!"); - - for t in topics.chunks(5) { - mqtt.subscribe(t)?; - } - Ok(()) - } } #[derive(Default)] @@ -295,19 +281,4 @@ impl Unsubscribe { .map(|topic| Topic::from(*topic).format(thing_name, shadow_name)) .collect() } - - pub fn send(self, mqtt: &M, shadow_name: Option<&'static str>) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics(mqtt.client_id(), shadow_name)?; - let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - - for t in topics.chunks(5) { - mqtt.unsubscribe(t)?; - } - - Ok(()) - } } diff --git a/tests/common/file_handler.rs b/tests/common/file_handler.rs index 
942a082..d90e995 100644 --- a/tests/common/file_handler.rs +++ b/tests/common/file_handler.rs @@ -1,14 +1,13 @@ use core::ops::Deref; -use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embassy_sync::mutex::Mutex; +use embedded_storage_async::nor_flash::{ErrorType, NorFlash, ReadNorFlash}; use rustot::ota::{ - self, + encoding::json, pal::{OtaPal, OtaPalError, PalImageState}, }; use sha2::{Digest, Sha256}; use std::{ - fs::File, - io::{Cursor, Read, Write}, + convert::Infallible, + io::{Cursor, Write}, }; #[derive(Debug, PartialEq, Eq)] @@ -17,8 +16,44 @@ pub enum State { Boot, } +pub struct BlockFile { + filebuf: Cursor>, +} + +impl NorFlash for BlockFile { + const WRITE_SIZE: usize = 1; + + const ERASE_SIZE: usize = 1; + + async fn erase(&mut self, _from: u32, _to: u32) -> Result<(), Self::Error> { + Ok(()) + } + + async fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> { + self.filebuf.set_position(offset as u64); + self.filebuf.write_all(bytes).unwrap(); + Ok(()) + } +} + +impl ReadNorFlash for BlockFile { + const READ_SIZE: usize = 1; + + async fn read(&mut self, _offset: u32, _bytes: &mut [u8]) -> Result<(), Self::Error> { + todo!() + } + + fn capacity(&self) -> usize { + self.filebuf.get_ref().capacity() + } +} + +impl ErrorType for BlockFile { + type Error = Infallible; +} + pub struct FileHandler { - filebuf: Option>>, + filebuf: Option, compare_file_path: String, pub plateform_state: State, } @@ -34,6 +69,8 @@ impl FileHandler { } impl OtaPal for FileHandler { + type BlockWriter = BlockFile; + async fn abort( &mut self, _file: &rustot::ota::encoding::FileContext, @@ -44,9 +81,10 @@ impl OtaPal for FileHandler { async fn create_file_for_rx( &mut self, file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { - self.filebuf = Some(Cursor::new(Vec::with_capacity(file.filesize))); - Ok(()) + ) -> Result<&mut Self::BlockWriter, OtaPalError> { + Ok(self.filebuf.get_or_insert(BlockFile { + filebuf: Cursor::new(Vec::with_capacity(file.filesize)), + })) } async fn get_platform_image_state(&mut self) -> Result { @@ -78,12 +116,12 @@ impl OtaPal for FileHandler { if let Some(ref mut buf) = &mut self.filebuf { log::debug!( "Closing completed file. Len: {}/{} -> {}", - buf.get_ref().len(), + buf.filebuf.get_ref().len(), file.filesize, file.filepath.as_str() ); - let mut expected_data = std::fs::read(self.compare_file_path.as_str()).unwrap(); + let expected_data = std::fs::read(self.compare_file_path.as_str()).unwrap(); let mut expected_hasher = ::new(); expected_hasher.update(&expected_data); let expected_hash = expected_hasher.finalize(); @@ -93,27 +131,19 @@ impl OtaPal for FileHandler { self.compare_file_path, file.filepath.as_str() ); - assert_eq!(buf.get_ref().len(), file.filesize); + assert_eq!(buf.filebuf.get_ref().len(), file.filesize); let mut hasher = ::new(); - hasher.update(&buf.get_ref()); + hasher.update(&buf.filebuf.get_ref()); assert_eq!(hasher.finalize().deref(), expected_hash.deref()); // Check file signature - match &file.signature { - ota::encoding::json::Signature::Sha1Rsa(_) => { - panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Rsa(_) => { - panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha1Ecdsa(_) => { - panic!("Unexpected signature format: Sha1Ecdsa. 
Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Ecdsa(sig) => { - assert_eq!(sig.as_str(), "This is my custom signature\\n") - } - } + let signature = match file.signature.as_ref() { + Some(json::Signature::Sha256Ecdsa(ref s)) => s.as_str(), + sig => panic!("Unexpected signature format! {:?}", sig), + }; + + assert_eq!(signature, "This is my custom signature\\n"); self.plateform_state = State::Swap; @@ -122,20 +152,4 @@ impl OtaPal for FileHandler { Err(OtaPalError::BadFileHandle) } } - - async fn write_block( - &mut self, - _file: &mut rustot::ota::encoding::FileContext, - block_offset: usize, - block_payload: &[u8], - ) -> Result { - if let Some(ref mut buf) = &mut self.filebuf { - buf.set_position(block_offset as u64); - buf.write(block_payload) - .map_err(|_e| OtaPalError::FileWriteFailed)?; - Ok(block_payload.len()) - } else { - Err(OtaPalError::BadFileHandle) - } - } } diff --git a/tests/common/network.rs b/tests/common/network.rs index dfbe27c..0cfe3db 100644 --- a/tests/common/network.rs +++ b/tests/common/network.rs @@ -40,7 +40,7 @@ impl Dns for Network { host: &str, addr_type: AddrType, ) -> Result { - for ip in tokio::net::lookup_host(host).await? { + for ip in tokio::net::lookup_host(format!("{}:0", host)).await? { match (&addr_type, ip) { (AddrType::IPv4 | AddrType::Either, SocketAddr::V4(ip)) => { return Ok(IpAddr::V4(Ipv4Addr::from(ip.ip().octets()))) @@ -114,7 +114,7 @@ impl Dns for TlsNetwork { addr_type: AddrType, ) -> Result { log::info!("Looking up {}", host); - for ip in tokio::net::lookup_host(host).await? { + for ip in tokio::net::lookup_host(format!("{}:0", host)).await? { log::info!("Found IP {}", ip); match (&addr_type, ip) { diff --git a/tests/ota.rs b/tests/ota.rs index ac82f18..da8c020 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -3,28 +3,21 @@ mod common; -use std::{net::ToSocketAddrs, process}; - use common::credentials; use common::file_handler::{FileHandler, State as FileHandlerState}; use common::network::TlsNetwork; use embassy_futures::select; -use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; -use embassy_time::Duration; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embedded_mqtt::transport::embedded_nal::NalTransport; -use embedded_mqtt::{ - Config, DomainBroker, IpBroker, Message, Publish, QoS, RetainHandling, State, Subscribe, - SubscribeTopic, -}; +use embedded_mqtt::{Config, DomainBroker, Message, State, Subscribe, SubscribeTopic}; use futures::StreamExt; -use serde::{Deserialize, Serialize}; -use static_cell::make_static; +use serde::Deserialize; +use static_cell::StaticCell; use rustot::{ jobs::{ self, data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, - JobTopic, StatusDetails, }, ota::{ self, @@ -49,7 +42,7 @@ impl<'a> Jobs<'a> { } fn handle_ota<'a, const SUBS: usize>( - message: Message<'a, NoopRawMutex, SUBS>, + message: Message<'a, SUBS>, config: &ota::config::Config, ) -> Option { let job = match jobs::Topic::from_str(message.topic_name()) { @@ -99,8 +92,7 @@ async fn test_mqtt_ota() { // Create the MQTT stack let broker = DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); - let config = - Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); + let config = Config::new(thing_name).keepalive_interval(embassy_time::Duration::from_secs(50)); static STATE: StaticCell> = StaticCell::new(); let state = STATE.init(State::::new()); @@ -110,26 +102,26 @@ async fn test_mqtt_ota() { let ota_fut = async { let mut 
jobs_subscription = client - .subscribe::<2>(Subscribe::new(&[ - SubscribeTopic { - topic_path: jobs::JobTopic::NotifyNext - .format::<64>(thing_name)? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - SubscribeTopic { - topic_path: jobs::JobTopic::DescribeAccepted("$next") - .format::<64>(thing_name)? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - ])) + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + jobs::JobTopic::NotifyNext + .format::<64>(thing_name)? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + jobs::JobTopic::DescribeAccepted("$next") + .format::<64>(thing_name)? + .as_str(), + ) + .build(), + ]) + .build(), + ) .await?; Updater::check_for_job(&client).await?; @@ -167,7 +159,7 @@ async fn test_mqtt_ota() { Ok::<_, ota::error::OtaError>(()) }; - let mut transport = NalTransport::new(network); + let mut transport = NalTransport::new(network, broker); match embassy_time::with_timeout( embassy_time::Duration::from_secs(25), diff --git a/tests/provisioning.rs b/tests/provisioning.rs index 804d2c5..e9eafd5 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -3,23 +3,16 @@ mod common; -use std::{net::ToSocketAddrs, process}; - use common::credentials; use common::network::TlsNetwork; use ecdsa::Signature; use embassy_futures::select; use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embedded_mqtt::{ - transport::embedded_nal::NalTransport, Config, DomainBroker, IpBroker, Publish, State, - Subscribe, SubscribeTopic, -}; +use embedded_mqtt::{transport::embedded_nal::NalTransport, Config, DomainBroker, State}; use p256::{ecdsa::signature::Signer, NistP256}; -use rustot::provisioning::{ - topics::Topic, CredentialHandler, Credentials, Error, FleetProvisioner, -}; +use rustot::provisioning::{CredentialHandler, Credentials, Error, FleetProvisioner}; use serde::{Deserialize, Serialize}; -use static_cell::make_static; +use static_cell::StaticCell; pub struct OwnedCredentials { pub certificate_id: String, @@ -82,8 +75,7 @@ async fn test_provisioning() { // Create the MQTT stack let broker = DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); - let config = - Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); + let config = Config::new(thing_name).keepalive_interval(embassy_time::Duration::from_secs(50)); static STATE: StaticCell> = StaticCell::new(); let state = STATE.init(State::::new()); @@ -115,7 +107,7 @@ async fn test_provisioning() { &mut credential_handler, ); - let mut transport = NalTransport::new(network); + let mut transport = NalTransport::new(network, broker); let device_config = match embassy_time::with_timeout( embassy_time::Duration::from_secs(15), diff --git a/tests/shadows.rs b/tests/shadows.rs index cbd979e..af788dd 100644 --- a/tests/shadows.rs +++ b/tests/shadows.rs @@ -26,7 +26,17 @@ // use core::fmt::Write; -// use common::{clock::SysClock, credentials, network::Network}; +// use common::{ +// clock::SysClock, +// credentials, +// network::{Network, TlsNetwork}, +// }; +// use embassy_futures::select; +// use embassy_sync::blocking_mutex::raw::NoopRawMutex; +// use embedded_mqtt::{ +// transport::embedded_nal::{self, NalTransport}, +// DomainBroker, Properties, Publish, QoS, State, +// }; 
// use embedded_nal::Ipv4Addr; // use mqttrust::Mqtt; // use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification}; @@ -37,9 +47,7 @@ // use serde::{de::DeserializeOwned, Deserialize, Serialize}; // use smlang::statemachine; - -// const Q_SIZE: usize = 1024 * 6; -// static mut Q: BBBuffer = BBBuffer::new(); +// use static_cell::StaticCell; // #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] // pub struct ConfigId(pub u8); @@ -284,7 +292,7 @@ // pub fn spin( // &mut self, // notification: Notification, -// mqtt_client: &mqttrust_core::Client<'static, 'static, Q_SIZE>, +// mqtt_client: &embedded_mqtt::MqttClient<'a, M, 1>, // ) -> bool { // log::info!("State: {:?}", self.state()); // match (self.state(), notification) { @@ -294,15 +302,19 @@ // (&States::DeleteShadow, Notification::Suback(_)) => { // mqtt_client // .publish( -// &Topic::Update -// .format::<128>( -// mqtt_client.client_id(), -// ::NAME, +// Publish::builder() +// .topic_name( +// &Topic::Update +// .format::<128>( +// mqtt_client.client_id(), +// ::NAME, +// ) +// .unwrap(), // ) -// .unwrap(), -// b"{\"state\":{\"desired\":null,\"reported\":null}}", -// mqttrust::QoS::AtLeastOnce, +// .payload(b"{\"state\":{\"desired\":null,\"reported\":null}}") +// .build(), // ) +// .await // .unwrap(); // self.process_event(Events::Get).unwrap(); @@ -387,14 +399,17 @@ // mqtt_client // .publish( -// &Topic::Update -// .format::<128>( -// mqtt_client.client_id(), -// ::NAME, +// Publish::builder() +// .topic_name( +// &Topic::Update +// .format::<128>( +// mqtt_client.client_id(), +// ::NAME, +// ) +// .unwrap(), // ) -// .unwrap(), -// payload.as_bytes(), -// mqttrust::QoS::AtLeastOnce, +// .payload(payload.as_bytes()) +// .build(), // ) // .unwrap(); // self.process_event(Events::Ack).unwrap(); @@ -453,50 +468,66 @@ // } // } -// #[test] -// fn test_shadows() { +// #[tokio::test(flavor = "current_thread")] +// async fn test_shadows() { // env_logger::init(); -// let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - // log::info!("Starting shadows test..."); -// let hostname = credentials::HOSTNAME.unwrap(); // let (thing_name, identity) = credentials::identity(); -// let connector = TlsConnector::builder() -// .identity(identity) -// .add_root_certificate(credentials::root_ca()) -// .build() -// .unwrap(); +// let hostname = credentials::HOSTNAME.unwrap(); -// let mut network = Network::new_tls(connector, std::string::String::from(hostname)); +// static NETWORK: StaticCell = StaticCell::new(); +// let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); -// let mut mqtt_eventloop = EventLoop::new( -// c, -// SysClock::new(), -// MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), -// ); +// // Create the MQTT stack +// let broker = +// DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), &network).unwrap(); +// let config = embedded_mqtt::Config::new(thing_name) +// .keepalive_interval(embassy_time::Duration::from_secs(50)); -// let mqtt_client = mqttrust_core::Client::new(p, thing_name); +// let mut state = State::::new(); +// let (mut stack, client) = embedded_mqtt::new(&mut state, config); -// let mut test_state = StateMachine::new(TestContext { -// shadow: Shadow::new(WifiConfig::default(), &mqtt_client, true).unwrap(), -// update_cnt: 0, -// }); +// let mqtt_client = client; -// loop { -// if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { -// log::info!("Successfully connected to broker"); -// } +// let 
shadow = Shadow::new(WifiConfig::default(), &mqtt_client).unwrap(); -// match mqtt_eventloop.yield_event(&mut network) { -// Ok(notification) => { -// if test_state.spin(notification, &mqtt_client) { -// break; -// } -// } -// Err(_) => {} +// // loop { +// // if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { +// // log::info!("Successfully connected to broker"); +// // } + +// // match mqtt_eventloop.yield_event(&mut network) { +// // Ok(notification) => { +// // if test_state.spin(notification, &mqtt_client) { +// // break; +// // } +// // } +// // Err(_) => {} +// // } +// // } + +// // cloud_updater(mqtt_client); + +// let shadows_fut = async { +// shadow.next_update().await; +// todo!() +// }; + +// let mut transport = NalTransport::new(network, broker); + +// match embassy_time::with_timeout( +// embassy_time::Duration::from_secs(25), +// select::select(stack.run(&mut transport), shadows_fut), +// ) +// .await +// .unwrap() +// { +// select::Either::First(_) => { +// unreachable!() // } -// } +// select::Either::Second(result) => result.unwrap(), +// }; // } From 1763f6c38441ca7bb7ae9bb2c0d60b02fbeac668 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 26 Sep 2024 13:40:55 +0200 Subject: [PATCH 22/36] Add report fn to persisted shadows --- src/shadows/mod.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index 2825fd5..c508661 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -224,7 +224,7 @@ where } } - pub async fn delete_shadow(&mut self) -> Result<(), Error> { + pub async fn delete_shadow(&self) -> Result<(), Error> { // Wait for mqtt to connect self.mqtt.wait_connected().await; @@ -444,7 +444,7 @@ where } /// Get an immutable reference to the internal local state. - pub async fn try_get(&mut self) -> Result { + pub async fn try_get(&self) -> Result { self.dao.lock().await.read().await } @@ -465,6 +465,14 @@ where Ok(state) } + /// Report the state of the shadow. + pub async fn report(&self) -> Result<(), Error> { + let state = self.dao.lock().await.read().await?; + + self.handler.report(&state).await?; + Ok(()) + } + /// Update the state of the shadow. 
/// /// This function will update the desired state of the shadow in the cloud, @@ -491,7 +499,7 @@ where Ok(()) } - pub async fn delete_shadow(&mut self) -> Result<(), Error> { + pub async fn delete_shadow(&self) -> Result<(), Error> { self.handler.delete_shadow().await?; self.dao.lock().await.write(&S::default()).await?; Ok(()) From 967c0b5dcc1bde1454ef2bf1a4646f3514330f8a Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 3 Oct 2024 10:00:17 +0200 Subject: [PATCH 23/36] OTA tests pass again after rewrite --- Cargo.toml | 5 +- src/ota/config.rs | 2 +- src/ota/control_interface/mod.rs | 4 +- src/ota/control_interface/mqtt.rs | 37 +- src/ota/data_interface/http.rs | 14 +- src/ota/data_interface/mod.rs | 13 +- src/ota/data_interface/mqtt.rs | 23 +- src/ota/encoding/mod.rs | 2 +- src/ota/mod.rs | 718 ++++++++++++++++------------ tests/common/credentials.rs | 3 + tests/{ota.rs => ota_mqtt.rs} | 81 ++-- tests/provisioning.rs | 5 +- tests/shadows.rs | 771 +++++++++--------------------- 13 files changed, 746 insertions(+), 932 deletions(-) rename tests/{ota.rs => ota_mqtt.rs} (71%) diff --git a/Cargo.toml b/Cargo.toml index 42f4d3a..4c3242b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ embedded-nal-async = "0.7" env_logger = "0.11" sha2 = "0.10.1" static_cell = { version = "2", features = ["nightly"] } +log = { version = "0.4" } tokio = { version = "1.33", default-features = false, features = [ "macros", @@ -82,5 +83,5 @@ defmt = [ ] log = ["dep:log", "embedded-mqtt/log"] -# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] -# embedded-mqtt = { path = "../embedded-mqtt" } +[patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +embedded-mqtt = { path = "../embedded-mqtt" } diff --git a/src/ota/config.rs b/src/ota/config.rs index 39fe685..ef5cd28 100644 --- a/src/ota/config.rs +++ b/src/ota/config.rs @@ -11,7 +11,7 @@ pub struct Config { impl Default for Config { fn default() -> Self { Self { - block_size: 256, + block_size: 1024, max_request_momentum: 3, request_wait: Duration::from_secs(8), status_update_frequency: 24, diff --git a/src/ota/control_interface/mod.rs b/src/ota/control_interface/mod.rs index 8962ea2..51e7a7c 100644 --- a/src/ota/control_interface/mod.rs +++ b/src/ota/control_interface/mod.rs @@ -4,6 +4,7 @@ use super::{ config::Config, encoding::{json::JobStatusReason, FileContext}, error::OtaError, + ProgressState, }; pub mod mqtt; @@ -13,7 +14,8 @@ pub trait ControlInterface { async fn request_job(&self) -> Result<(), OtaError>; async fn update_job_status( &self, - file_ctx: &mut FileContext, + file_ctx: &FileContext, + progress: &mut ProgressState, config: &Config, status: JobStatus, reason: JobStatusReason, diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 4456428..8cd9cf2 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -12,6 +12,7 @@ use crate::ota::config::Config; use crate::ota::encoding::json::JobStatusReason; use crate::ota::encoding::{self, FileContext}; use crate::ota::error::OtaError; +use crate::ota::ProgressState; impl<'a, M: RawMutex, const SUBS: usize> ControlInterface for embedded_mqtt::MqttClient<'a, M, SUBS> where @@ -35,16 +36,21 @@ where Ok(()) } - /// Update the job status on the service side with progress or completion - /// info + /// Update the job status on the service side. + /// + /// Returns a Result indicating success or an error, + /// along with an Option containing the updated status details + /// if they were modified. 
async fn update_job_status( &self, - file_ctx: &mut FileContext, + file_ctx: &FileContext, + progress_state: &mut ProgressState, config: &Config, status: JobStatus, reason: JobStatusReason, ) -> Result<(), OtaError> { - file_ctx + // Update the status details within this function. + progress_state .status_details .insert( heapless::String::try_from("self_test").unwrap(), @@ -54,16 +60,14 @@ where let mut qos = QoS::AtLeastOnce; - if let (JobStatus::InProgress, _) | (JobStatus::Succeeded, _) = (status, reason) { - let total_blocks = - ((file_ctx.filesize + config.block_size - 1) / config.block_size) as u32; - let received_blocks = total_blocks - file_ctx.blocks_remaining as u32; + if let JobStatus::InProgress | JobStatus::Succeeded = status { + let received_blocks = progress_state.total_blocks - progress_state.blocks_remaining; // Output a status update once in a while. Always update first and // last status - if file_ctx.blocks_remaining != 0 + if progress_state.blocks_remaining != 0 && received_blocks != 0 - && received_blocks % config.status_update_frequency != 0 + && received_blocks % config.status_update_frequency as usize != 0 { return Ok(()); } @@ -74,10 +78,13 @@ where if status != JobStatus::Succeeded && reason != JobStatusReason::SelfTestActive { let mut progress = heapless::String::new(); progress - .write_fmt(format_args!("{}/{}", received_blocks, total_blocks)) + .write_fmt(format_args!( + "{}/{}", + received_blocks, progress_state.total_blocks + )) .map_err(|_| OtaError::Overflow)?; - file_ctx + progress_state .status_details .insert(heapless::String::try_from("progress").unwrap(), progress) .map_err(|_| OtaError::Overflow)?; @@ -86,7 +93,7 @@ where // Downgrade progress updates to QOS 0 to avoid overloading MQTT // buffers during active streaming. But make sure to always send and await ack for first update and last update if status == JobStatus::InProgress - && file_ctx.blocks_remaining != 0 + && progress_state.blocks_remaining != 0 && received_blocks != 0 { qos = QoS::AtMostOnce; @@ -126,13 +133,15 @@ where |buf| { Jobs::update(status) .client_token(self.client_id()) - .status_details(&file_ctx.status_details) + .status_details(&progress_state.status_details) .payload(buf) .map_err(|_| EncodingError::BufferSize) }, 512, ); + warn!("Updating job status! 
{:?}", status); + self.publish( Publish::builder() .qos(qos) diff --git a/src/ota/data_interface/http.rs b/src/ota/data_interface/http.rs index c1ba639..4b53ac9 100644 --- a/src/ota/data_interface/http.rs +++ b/src/ota/data_interface/http.rs @@ -20,23 +20,11 @@ impl DataInterface for HttpInterface { Ok(()) } - fn request_file_block( + fn request_file_blocks( &self, _file_ctx: &mut FileContext, _config: &Config, ) -> Result<(), OtaError> { Ok(()) } - - fn decode_file_block<'b>( - &self, - _file_ctx: &mut FileContext, - _payload: &'b mut [u8], - ) -> Result, OtaError> { - unimplemented!() - } - - fn cleanup(&self, _file_ctx: &mut FileContext, _config: &Config) -> Result<(), OtaError> { - Ok(()) - } } diff --git a/src/ota/data_interface/mod.rs b/src/ota/data_interface/mod.rs index cfb9cee..ebc943e 100644 --- a/src/ota/data_interface/mod.rs +++ b/src/ota/data_interface/mod.rs @@ -9,7 +9,7 @@ use serde::Deserialize; use crate::ota::config::Config; -use super::{encoding::FileContext, error::OtaError}; +use super::{encoding::FileContext, error::OtaError, ProgressState}; #[derive(Debug, Clone, PartialEq, Deserialize)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -60,15 +60,12 @@ pub trait DataInterface { file_ctx: &FileContext, ) -> Result, OtaError>; - async fn request_file_block( + async fn request_file_blocks( &self, - file_ctx: &mut FileContext, + file_ctx: &FileContext, + progress_state: &mut ProgressState, config: &Config, ) -> Result<(), OtaError>; - fn decode_file_block<'a>( - &self, - file_ctx: &FileContext, - payload: &'a mut [u8], - ) -> Result, OtaError>; + fn decode_file_block<'a>(&self, payload: &'a mut [u8]) -> Result, OtaError>; } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 17bbf48..1bbaae6 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -10,6 +10,7 @@ use embedded_mqtt::{ use futures::StreamExt; use crate::ota::error::OtaError; +use crate::ota::ProgressState; use crate::{ jobs::{MAX_STREAM_ID_LEN, MAX_THING_NAME_LEN}, ota::{ @@ -161,12 +162,13 @@ where } /// Request file block by publishing to the get stream topic - async fn request_file_block( + async fn request_file_blocks( &self, - file_ctx: &mut FileContext, + file_ctx: &FileContext, + progress_state: &mut ProgressState, config: &Config, ) -> Result<(), OtaError> { - file_ctx.request_block_remaining = file_ctx.bitmap.len() as u32; + progress_state.request_block_remaining = progress_state.bitmap.len() as u32; let payload = DeferredPayload::new( |buf| { @@ -177,9 +179,9 @@ where stream_version: None, file_id: file_ctx.fileid, block_size: config.block_size, - block_offset: Some(file_ctx.block_offset), - block_bitmap: Some(&file_ctx.bitmap), - number_of_blocks: None, + block_offset: Some(progress_state.block_offset), + block_bitmap: Some(&progress_state.bitmap), + number_of_blocks: Some(progress_state.request_block_remaining), }, buf, ) @@ -190,7 +192,7 @@ where debug!( "Requesting more file blocks. Remaining: {}", - file_ctx.request_block_remaining + progress_state.request_block_remaining ); self.publish( @@ -202,6 +204,7 @@ where )? 
.as_str(), ) + // .qos(embedded_mqtt::QoS::AtMostOnce) .payload(payload) .build(), ) @@ -211,11 +214,7 @@ where } /// Decode a cbor encoded fileblock received from streaming service - fn decode_file_block<'c>( - &self, - _file_ctx: &FileContext, - payload: &'c mut [u8], - ) -> Result, OtaError> { + fn decode_file_block<'c>(&self, payload: &'c mut [u8]) -> Result, OtaError> { Ok( serde_cbor::de::from_mut_slice::(payload) .map_err(|_| OtaError::Encoding)? diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index 7c88700..b0349ed 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -14,7 +14,7 @@ use super::data_interface::Protocol; use super::error::OtaError; use super::JobEventData; -#[derive(Clone, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct Bitmap(bitmaps::Bitmap<32>); impl Bitmap { diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 338320c..cba36c9 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -1,35 +1,3 @@ -//! ## Over-the-air (OTA) flashing of firmware -//! -//! AWS IoT OTA works by using AWS IoT Jobs to manage firmware transfer and -//! status reporting of OTA. -//! -//! The OTA Jobs API makes use of the following special MQTT Topics: -//! - $aws/things/{thing_name}/jobs/$next/get/accepted -//! - $aws/things/{thing_name}/jobs/notify-next -//! - $aws/things/{thing_name}/jobs/$next/get -//! - $aws/things/{thing_name}/jobs/{job_id}/update -//! - $aws/things/{thing_name}/streams/{stream_id}/data/cbor -//! - $aws/things/{thing_name}/streams/{stream_id}/get/cbor -//! -//! Most of the data structures for the Jobs API has been copied from Rusoto: -//! -//! -//! ### OTA Flow: -//! 1. Device subscribes to notification topics for AWS IoT jobs and listens for -//! update messages. -//! 2. When an update is available, the OTA agent publishes requests to AWS IoT -//! and receives updates using the HTTP or MQTT protocol, depending on the -//! settings you chose. -//! 3. The OTA agent checks the digital signature of the downloaded files and, -//! if the files are valid, installs the firmware update to the appropriate -//! flash bank. -//! -//! The OTA depends on working, and correctly setup: -//! - Bootloader -//! - MQTT Client -//! - Code sign verification -//! 
- CBOR deserializer - pub mod config; pub mod control_interface; pub mod data_interface; @@ -37,24 +5,22 @@ pub mod encoding; pub mod error; pub mod pal; -use core::{ - ops::DerefMut, - sync::atomic::{AtomicU8, Ordering}, -}; +use core::ops::DerefMut as _; #[cfg(feature = "ota_mqtt_data")] pub use data_interface::mqtt::{Encoding, Topic}; +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex, signal::Signal}; use embedded_storage_async::nor_flash::{NorFlash, NorFlashError as _}; use crate::{ - jobs::data_types::JobStatus, + jobs::{data_types::JobStatus, StatusDetailsOwned}, ota::{data_interface::BlockTransfer, encoding::json::JobStatusReason}, }; use self::{ control_interface::ControlInterface, data_interface::DataInterface, - encoding::FileContext, + encoding::{Bitmap, FileContext}, pal::{ImageState, ImageStateReason}, }; @@ -78,147 +44,49 @@ impl Updater { pub async fn perform_ota<'a, 'b, C: ControlInterface, D: DataInterface>( control: &C, data: &D, - mut file_ctx: FileContext, + file_ctx: FileContext, pal: &mut impl pal::OtaPal, config: &config::Config, ) -> Result<(), error::OtaError> { - // If the job is in self test mode, don't start an OTA update but - // instead do the following: - // - // If the firmware that performed the update was older than the - // currently running firmware, set the image state to "Testing." This is - // the success path. - // - // If it's the same or newer, reject the job since either the firmware - // was not accepted during self test or an incorrect image was sent by - // the OTA operator. - let platform_self_test = pal - .get_platform_image_state() - .await - .map_or(false, |i| i == pal::PalImageState::PendingCommit); - - match (file_ctx.self_test(), platform_self_test) { - (true, true) => { - // Run self-test! - Self::set_image_state_with_reason( - control, - pal, - &config, - &mut file_ctx, - ImageState::Testing(ImageStateReason::VersionCheck), - ) - .await?; - - info!("Beginning self-test"); - - let test_fut = pal.complete_callback(pal::OtaEvent::StartTest); - - match config.self_test_timeout { - Some(timeout) => embassy_time::with_timeout(timeout, test_fut) - .await - .map_err(|_| error::OtaError::Timeout)?, - None => test_fut.await, - }?; - - control - .update_job_status( - &mut file_ctx, - &config, - JobStatus::Succeeded, - JobStatusReason::Accepted, - ) - .await?; - - return Ok(()); - } - (false, false) => {} - (false, true) => { - // Received a job that is not in self-test but platform is, so - // reboot the device to allow roll back to previous image. - error!("Rejecting new image and rebooting: The platform is in the self-test state while the job is not."); - pal.reset_device().await?; - return Err(error::OtaError::ResetFailed); - } - (true, false) => { - // The job is in self test but the platform image state is not so it - // could be an attack on the platform image state. Reject the update - // (this should also cause the image to be erased), aborting the job - // and reset the device. - error!("Rejecting new image and rebooting: the job is in the self-test state while the platform is not."); - // loop { - // embassy_time::Timer::after_secs(1).await; - // } - Self::set_image_state_with_reason( - control, - pal, - &config, - &mut file_ctx, - ImageState::Rejected(ImageStateReason::ImageStateMismatch), - ) - .await?; - pal.reset_device().await?; - return Err(error::OtaError::ResetFailed); - } - } - - if !file_ctx.protocols.contains(&D::PROTOCOL) { - error!("Unable to handle current OTA job with given data interface ({:?}). 
Supported protocols: {:?}. Aborting current update.", D::PROTOCOL, file_ctx.protocols); - Self::set_image_state_with_reason( - control, - pal, - &config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::InvalidDataProtocol), - ) - .await?; - return Err(error::OtaError::InvalidInterface); - } + let progress_state = Mutex::new(ProgressState { + total_blocks: (file_ctx.filesize + config.block_size - 1) / config.block_size, + blocks_remaining: (file_ctx.filesize + config.block_size - 1) / config.block_size, + block_offset: file_ctx.block_offset, + request_block_remaining: file_ctx.bitmap.len() as u32, + bitmap: file_ctx.bitmap.clone(), + file_size: file_ctx.filesize, + request_momentum: 0, + status_details: file_ctx.status_details.clone(), + }); + + // Create the JobUpdater + let mut job_updater = JobUpdater::new(&file_ctx, &progress_state, &config, control); + + match job_updater.initialize::(pal).await? { + Some(()) => {} + None => return Ok(()), + }; info!("Job document was accepted. Attempting to begin the update"); - let request_momentum = AtomicU8::new(0); - - // FIXME: - // let momentum_fut = async { - // while file_ctx.lock().await.blocks_remaining > 0 { - // if request_momentum.load(Ordering::Relaxed) <= config.max_request_momentum { - // // Each request increases the momentum until a response is - // // received. Too much momentum is interpreted as a failure to - // // communicate and will cause us to abort the OTA. - // request_momentum.fetch_add(1, Ordering::Relaxed); + // Spawn the request momentum future + let momentum_fut = Self::handle_momentum(data, &config, &file_ctx, &progress_state); - // // Reset number of blocks requested - // let mut ctx = file_ctx.lock().await; - // ctx.request_block_remaining = ctx.bitmap.len() as u32; - - // // Request data blocks - // data.request_file_block(&ctx, &config).await?; - // } else { - // // Too many requests have been sent without a response or too - // // many failures when trying to publish the request message. - // // Abort. - // return Err(error::OtaError::MomentumAbort); - // } - - // embassy_time::Timer::after(config.request_wait).await; - // } - - // Ok(()) - // }; + // Spawn the status update future + let status_update_fut = job_updater.handle_status_updates(); + // Spawn the data handling future let data_fut = async { // Create/Open the OTA file on the file system - let block_writer = match pal.create_file_for_rx(&file_ctx).await { + let mut block_writer = match pal.create_file_for_rx(&file_ctx).await { Ok(block_writer) => block_writer, Err(e) => { - Self::set_image_state_with_reason( - control, - pal, - &config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::Pal(e)), - ) - .await?; + job_updater + .set_image_state_with_reason( + pal, + ImageState::Aborted(ImageStateReason::Pal(e)), + ) + .await?; pal.close_file(&file_ctx).await?; return Err(e.into()); @@ -227,155 +95,172 @@ impl Updater { info!("Initialized file handler! 
Requesting file blocks"); + // Prepare the storage layer on receiving a new file + let mut subscription = data.init_file_transfer(&file_ctx).await?; + + { + let mut progress = progress_state.lock().await; + data.request_file_blocks(&file_ctx, &mut progress, &config) + .await?; + } + + info!("Awaiting file blocks!"); + loop { - // Prepare the storage layer on receiving a new file - let mut subscription = data.init_file_transfer(&mut file_ctx).await?; - - data.request_file_block(&mut file_ctx, &config).await?; - - info!("Awaiting file blocks!"); - - while let Some(mut payload) = subscription.next_block().await? { - debug!("process_data_handler"); - // Decode the file block received - match Self::ingest_data_block( - data, - block_writer, - &config, - &mut file_ctx, - payload.deref_mut(), - ) - .await - { - Ok(true) => match pal.close_file(&file_ctx).await { - Err(e) => { - control - .update_job_status( - &mut file_ctx, - &config, - JobStatus::Failed, - JobStatusReason::Pal(0), - ) - .await?; - - return Err(e.into()); + // Select over the futures + match subscription.next_block().await { + Ok(Some(mut payload)) => { + warn!("Got block!"); + // Decode the file block received + let mut progress = progress_state.lock().await; + + match Self::ingest_data_block( + data, + &mut block_writer, + &config, + &mut progress, + payload.deref_mut(), + ) + .await + { + Ok(true) => { + // ... (Handle end of file) ... + match pal.close_file(&file_ctx).await { + Err(e) => { + job_updater.signal_update( + JobStatus::Failed, + JobStatusReason::Pal(0), + ); + + return Err(e.into()); + } + Ok(_) if file_ctx.file_type == Some(0) => { + job_updater.signal_update( + JobStatus::InProgress, + JobStatusReason::SigCheckPassed, + ); + return Ok(()); + } + Ok(_) => { + job_updater.signal_update( + JobStatus::Succeeded, + JobStatusReason::Accepted, + ); + return Ok(()); + } + } } - Ok(_) => { - let (status, reason, event) = if let Some(0) = file_ctx.file_type { - ( + Ok(false) => { + // ... (Handle successful block processing) ... + progress.request_momentum = 0; + + // Update the job status to reflect the download progress + if progress.blocks_remaining + % config.status_update_frequency as usize + == 0 + { + job_updater.signal_update( JobStatus::InProgress, - JobStatusReason::SigCheckPassed, - pal::OtaEvent::Activate, - ) + JobStatusReason::Receiving, + ); + } + + if progress.request_block_remaining > 1 { + progress.request_block_remaining -= 1; } else { - ( - JobStatus::Succeeded, - JobStatusReason::Accepted, - pal::OtaEvent::UpdateComplete, - ) - }; - - control - .update_job_status(&mut file_ctx, &config, status, reason) - .await?; - - return Ok(event); + data.request_file_blocks(&file_ctx, &mut progress, &config) + .await?; + + warn!("Done requesting more blocks! {:#?}", progress); + } } - }, - Ok(false) => { - debug!("Ingested one block!"); - // Reset the momentum counter since we received a good block - request_momentum.store(0, Ordering::Relaxed); - - // We're actively receiving a file so update the job status as - // needed - control - .update_job_status( - &mut file_ctx, - &config, - JobStatus::InProgress, - JobStatusReason::Receiving, - ) - .await?; - - if file_ctx.request_block_remaining > 1 { - file_ctx.request_block_remaining -= 1; - } else { - data.request_file_block(&mut file_ctx, &config).await?; + Err(e) if e.is_retryable() => { + // ... (Handle retryable errors) ... + } + Err(e) => { + // ... (Handle fatal errors) ... 
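+                            // Non-retryable ingest error: propagate it so the
+                            // outer future fails the job and fires `OtaEvent::Fail`.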
+ return Err(e); } - } - Err(e) if e.is_retryable() => { - warn!("Failed to ingest data block, Error is retryable! ingest_data_block returned error {:?}", e); - } - Err(e) => { - error!("Failed to ingest data block, rejecting image: ingest_data_block returned error {:?}", e); - - // Call the platform specific code to reject the image - // TODO: This should never write to current image flags?! - // pal.set_platform_image_state(ImageState::Rejected( - // ImageStateReason::FailedIngest, - // )) - // .await?; - - // TODO: Pal reason - control - .update_job_status( - &mut file_ctx, - &config, - JobStatus::Failed, - JobStatusReason::Pal(0), - ) - .await?; - - pal.complete_callback(pal::OtaEvent::Fail).await?; - info!("Application callback! OtaEvent::Fail"); - return Err(e); } } + Ok(None) => { + error!("Stream ended unexpectedly"); + // Handle the case where next_block returns None, + // this might mean the stream has ended unexpectedly. + todo!(); + } + + // Handle status update future results + Err(e) => { + error!("Status update error: {:?}", e); + // Handle the error appropriately. + todo!(); + } } } }; - // let (momentum_res, data_res) = embassy_futures::join::join(momentum_fut, data_fut).await; + let (data_res, _) = embassy_futures::join::join( + data_fut, + embassy_futures::select::select(status_update_fut, momentum_fut), + ) + .await; + + // Cleanup and update the job status accordingly + match data_res { + Ok(()) => { + let event = if let Some(0) = file_ctx.file_type { + pal::OtaEvent::Activate + } else { + pal::OtaEvent::UpdateComplete + }; - let data_res = data_fut.await; + pal.complete_callback(event).await?; - // if let Err(e) = momentum_res { - // // Failed to send data request abort and close file. - // Self::set_image_state_with_reason( - // control, - // pal, - // &config, - // &mut file_ctx, - // ImageState::Aborted(ImageStateReason::MomentumAbort), - // ) - // .await?; + Ok(()) + } + Err(error::OtaError::MomentumAbort) => { + job_updater + .set_image_state_with_reason( + pal, + ImageState::Aborted(ImageStateReason::MomentumAbort), + ) + .await?; - // return Err(e); - // }; + Err(error::OtaError::MomentumAbort) + } + Err(e) => { + // Signal the error status + job_updater + .update_job_status(JobStatus::Failed, JobStatusReason::Pal(0)) + .await?; - pal.complete_callback(data_res?).await?; + pal.complete_callback(pal::OtaEvent::Fail).await?; + info!("Application callback! OtaEvent::Fail"); - Ok(()) + Err(e) + } + } } async fn ingest_data_block<'a, D: DataInterface>( data: &D, block_writer: &mut impl NorFlash, config: &config::Config, - file_ctx: &mut FileContext, + progress: &mut ProgressState, payload: &mut [u8], ) -> Result { - let block = data.decode_file_block(&file_ctx, payload)?; - if block.validate(config.block_size, file_ctx.filesize) { - if block.block_id < file_ctx.block_offset as usize - || !file_ctx + let block = data.decode_file_block(payload)?; + + if block.validate(config.block_size, progress.file_size) { + if block.block_id < progress.block_offset as usize + || !progress .bitmap - .get(block.block_id - file_ctx.block_offset as usize) + .get(block.block_id - progress.block_offset as usize) { info!( "Block {:?} is a DUPLICATE. {:?} blocks remaining.", - block.block_id, file_ctx.blocks_remaining + block.block_id, progress.blocks_remaining ); // Just return same progress as before @@ -384,7 +269,7 @@ impl Updater { info!( "Received block {}. 
{:?} blocks remaining.", - block.block_id, file_ctx.blocks_remaining + block.block_id, progress.blocks_remaining ); block_writer @@ -395,25 +280,25 @@ impl Updater { .await .map_err(|e| error::OtaError::Write(e.kind()))?; - let block_offset = file_ctx.block_offset; - file_ctx + let block_offset = progress.block_offset; + progress .bitmap .set(block.block_id - block_offset as usize, false); - file_ctx.blocks_remaining -= 1; + progress.blocks_remaining -= 1; - if file_ctx.blocks_remaining == 0 { + if progress.blocks_remaining == 0 { info!("Received final expected block of file."); // Return true to indicate end of file. Ok(true) } else { - if file_ctx.bitmap.is_empty() { - file_ctx.block_offset += 31; - file_ctx.bitmap = encoding::Bitmap::new( - file_ctx.filesize, + if progress.bitmap.is_empty() { + progress.block_offset += 31; + progress.bitmap = encoding::Bitmap::new( + progress.file_size, config.block_size, - file_ctx.block_offset, + progress.block_offset, ); } @@ -429,15 +314,193 @@ impl Updater { } } - async fn set_image_state_with_reason<'a, C: ControlInterface, PAL: pal::OtaPal>( - control: &C, - pal: &mut PAL, + async fn handle_momentum( + data: &D, config: &config::Config, - file_ctx: &mut FileContext, + file_ctx: &FileContext, + progress_state: &Mutex, + ) -> Result<(), error::OtaError> { + loop { + embassy_time::Timer::after(config.request_wait).await; + + let mut progress = progress_state.lock().await; + + if progress.blocks_remaining == 0 { + // No more blocks to request + break; + } + + if progress.request_momentum <= config.max_request_momentum { + // Increment momentum + progress.request_momentum += 1; + + warn!("Momentum requesting more blocks!"); + + // Request data blocks + data.request_file_blocks(file_ctx, &mut progress, config) + .await?; + } else { + // Too much momentum, abort + return Err(error::OtaError::MomentumAbort); + } + } + + Ok(()) + } +} + +#[derive(Clone, Debug)] +pub struct ProgressState { + pub total_blocks: usize, + pub blocks_remaining: usize, + pub file_size: usize, + pub block_offset: u32, + pub request_block_remaining: u32, + pub request_momentum: u8, + pub bitmap: Bitmap, + pub status_details: StatusDetailsOwned, +} + +pub struct JobUpdater<'a, C: ControlInterface> { + pub file_ctx: &'a FileContext, + pub progress_state: &'a Mutex, + pub config: &'a config::Config, + pub control: &'a C, + pub status_update_signal: Signal, +} + +impl<'a, C: ControlInterface> JobUpdater<'a, C> { + pub fn new( + file_ctx: &'a FileContext, + progress_state: &'a Mutex, + config: &'a config::Config, + control: &'a C, + ) -> Self { + Self { + file_ctx, + progress_state, + config, + control, + status_update_signal: Signal::::new(), + } + } + + async fn initialize( + &mut self, + pal: &mut PAL, + ) -> Result, error::OtaError> { + // If the job is in self test mode, don't start an OTA update but + // instead do the following: + // + // If the firmware that performed the update was older than the + // currently running firmware, set the image state to "Testing." This is + // the success path. + // + // If it's the same or newer, reject the job since either the firmware + // was not accepted during self test or an incorrect image was sent by + // the OTA operator. + let platform_self_test = pal + .get_platform_image_state() + .await + .map_or(false, |i| i == pal::PalImageState::PendingCommit); + + match (self.file_ctx.self_test(), platform_self_test) { + (true, true) => { + // Run self-test! 
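+                // Mark the image as `Testing`, which also reports the job as
+                // InProgress / SelfTestActive before the self-test callback runs.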
+ self.set_image_state_with_reason( + pal, + ImageState::Testing(ImageStateReason::VersionCheck), + ) + .await?; + + info!("Beginning self-test"); + + let test_fut = pal.complete_callback(pal::OtaEvent::StartTest); + + match self.config.self_test_timeout { + Some(timeout) => embassy_time::with_timeout(timeout, test_fut) + .await + .map_err(|_| error::OtaError::Timeout)?, + None => test_fut.await, + }?; + + let mut progress = self.progress_state.lock().await; + self.control + .update_job_status( + &self.file_ctx, + &mut progress, + self.config, + JobStatus::Succeeded, + JobStatusReason::Accepted, + ) + .await?; + + return Ok(None); + } + (false, false) => {} + (false, true) => { + // Received a job that is not in self-test but platform is, so + // reboot the device to allow roll back to previous image. + error!("Rejecting new image and rebooting: The platform is in the self-test state while the job is not."); + pal.reset_device().await?; + return Err(error::OtaError::ResetFailed); + } + (true, false) => { + // The job is in self test but the platform image state is not so it + // could be an attack on the platform image state. Reject the update + // (this should also cause the image to be erased), aborting the job + // and reset the device. + error!("Rejecting new image and rebooting: the job is in the self-test state while the platform is not."); + self.set_image_state_with_reason( + pal, + ImageState::Rejected(ImageStateReason::ImageStateMismatch), + ) + .await?; + + pal.reset_device().await?; + return Err(error::OtaError::ResetFailed); + } + } + + if !self.file_ctx.protocols.contains(&D::PROTOCOL) { + error!("Unable to handle current OTA job with given data interface ({:?}). Supported protocols: {:?}. Aborting current update.", D::PROTOCOL, self.file_ctx.protocols); + self.set_image_state_with_reason( + pal, + ImageState::Aborted(ImageStateReason::InvalidDataProtocol), + ) + .await?; + return Err(error::OtaError::InvalidInterface); + } + + Ok(Some(())) + } + + async fn handle_status_updates(&self) -> Result<(), error::OtaError> { + loop { + // Wait for a signal from the main loop + let (status, reason) = self.status_update_signal.wait().await; + + // Update the job status based on the signal + warn!("Signaled status update {:?} {:?}", status, reason); + + let mut progress = self.progress_state.lock().await; + self.control + .update_job_status(self.file_ctx, &mut progress, self.config, status, reason) + .await?; + + match status { + JobStatus::Queued | JobStatus::InProgress => {} + _ => return Ok(()), + } + } + } + + async fn set_image_state_with_reason( + &self, + pal: &mut PAL, image_state: ImageState, ) -> Result<(), error::OtaError> { // Call the platform specific code to set the image state - let image_state = match pal.set_platform_image_state(image_state).await { Err(e) if !matches!(image_state, ImageState::Aborted(_)) => { // If the platform image state couldn't be set correctly, force @@ -459,14 +522,17 @@ impl Updater { }; // Now update the image state and job status on server side + let mut progress = self.progress_state.lock().await; + match image_state { ImageState::Testing(_) => { // We discovered we're ready for test mode, put job status // in self_test active - control + self.control .update_job_status( - file_ctx, - config, + &self.file_ctx, + &mut progress, + self.config, JobStatus::InProgress, JobStatusReason::SelfTestActive, ) @@ -475,10 +541,11 @@ impl Updater { ImageState::Accepted => { // Now that we have accepted the firmware update, we can // complete the job - 
control + self.control .update_job_status( - file_ctx, - config, + &self.file_ctx, + &mut progress, + self.config, JobStatus::Succeeded, JobStatusReason::Accepted, ) @@ -488,10 +555,12 @@ impl Updater { // The firmware update was rejected, complete the job as // FAILED (Job service will not allow us to set REJECTED // after the job has been started already). - control + + self.control .update_job_status( - file_ctx, - config, + &self.file_ctx, + &mut progress, + self.config, JobStatus::Failed, JobStatusReason::Rejected, ) @@ -501,10 +570,12 @@ impl Updater { // The firmware update was aborted, complete the job as // FAILED (Job service will not allow us to set REJECTED // after the job has been started already). - control + + self.control .update_job_status( - file_ctx, - config, + &self.file_ctx, + &mut progress, + self.config, JobStatus::Failed, JobStatusReason::Aborted, ) @@ -513,4 +584,23 @@ impl Updater { } Ok(()) } + + // Function to signal the status update future + pub fn signal_update(&self, status: JobStatus, reason: JobStatusReason) { + self.status_update_signal.signal((status, reason)); + } + + // Function to update the job status + pub async fn update_job_status( + &mut self, + status: JobStatus, + reason: JobStatusReason, + ) -> Result<(), error::OtaError> { + let mut progress = self.progress_state.lock().await; + + self.control + .update_job_status(&self.file_ctx, &mut progress, self.config, status, reason) + .await?; + Ok(()) + } } diff --git a/tests/common/credentials.rs b/tests/common/credentials.rs index 6b9c7e3..b1bb872 100644 --- a/tests/common/credentials.rs +++ b/tests/common/credentials.rs @@ -4,6 +4,7 @@ use native_tls::{Certificate, Identity}; use p256::ecdsa::SigningKey; use pkcs8::DecodePrivateKey; +#[allow(dead_code)] pub fn identity() -> (&'static str, Identity) { let thing_name = option_env!("THING_NAME").unwrap_or_else(|| "rustot-test"); let pw = env::var("IDENTITY_PASSWORD").unwrap_or_default(); @@ -13,6 +14,7 @@ pub fn identity() -> (&'static str, Identity) { ) } +#[allow(dead_code)] pub fn claim_identity() -> (&'static str, Identity) { let thing_name = option_env!("THING_NAME").unwrap_or_else(|| "rustot-provision"); let pw = env::var("IDENTITY_PASSWORD").unwrap_or_default(); @@ -27,6 +29,7 @@ pub fn root_ca() -> Certificate { Certificate::from_pem(include_bytes!("../secrets/root-ca.pem")).unwrap() } +#[allow(dead_code)] pub fn signing_key() -> SigningKey { let pw = env::var("IDENTITY_PASSWORD").unwrap_or_default(); SigningKey::from_pkcs8_encrypted_pem(include_str!("../secrets/sign_private.pem"), pw).unwrap() diff --git a/tests/ota.rs b/tests/ota_mqtt.rs similarity index 71% rename from tests/ota.rs rename to tests/ota_mqtt.rs index da8c020..ab4ea8c 100644 --- a/tests/ota.rs +++ b/tests/ota_mqtt.rs @@ -3,13 +3,16 @@ mod common; +use bitmaps::{Bits, BitsImpl}; use common::credentials; use common::file_handler::{FileHandler, State as FileHandlerState}; use common::network::TlsNetwork; use embassy_futures::select; use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embedded_mqtt::transport::embedded_nal::NalTransport; -use embedded_mqtt::{Config, DomainBroker, Message, State, Subscribe, SubscribeTopic}; +use embedded_mqtt::{ + Config, DomainBroker, Message, SliceBufferProvider, State, Subscribe, SubscribeTopic, +}; use futures::StreamExt; use serde::Deserialize; use static_cell::StaticCell; @@ -42,9 +45,12 @@ impl<'a> Jobs<'a> { } fn handle_ota<'a, const SUBS: usize>( - message: Message<'a, SUBS>, + message: Message<'a, SliceBufferProvider<'a>, SUBS>, 
config: &ota::config::Config, -) -> Option { +) -> Option +where + BitsImpl: Bits, +{ let job = match jobs::Topic::from_str(message.topic_name()) { Some(jobs::Topic::NotifyNext) => { let (execution_changed, _) = @@ -92,10 +98,13 @@ async fn test_mqtt_ota() { // Create the MQTT stack let broker = DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); - let config = Config::new(thing_name).keepalive_interval(embassy_time::Duration::from_secs(50)); + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); - static STATE: StaticCell> = StaticCell::new(); - let state = STATE.init(State::::new()); + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); let (mut stack, client) = embedded_mqtt::new(state, config); let mut file_handler = FileHandler::new("tests/assets/ota_file".to_owned()); @@ -127,33 +136,41 @@ async fn test_mqtt_ota() { Updater::check_for_job(&client).await?; let config = ota::config::Config::default(); - while let Some(message) = jobs_subscription.next().await { - if let Some(mut file_ctx) = handle_ota(message, &config) { - // We have an OTA job, leeeets go! - Updater::perform_ota( - &client, - &client, - file_ctx.clone(), - &mut file_handler, - &config, + + let message = jobs_subscription.next().await.unwrap(); + + if let Some(mut file_ctx) = handle_ota(message, &config) { + // Nested subscriptions are a problem for embedded-mqtt, so drop the + // subscription here + drop(jobs_subscription); + + // We have an OTA job, leeeets go! + Updater::perform_ota( + &client, + &client, + file_ctx.clone(), + &mut file_handler, + &config, + ) + .await?; + + assert_eq!(file_handler.plateform_state, FileHandlerState::Swap); + + log::info!("Running OTA handler second time to verify state match..."); + + // Run it twice in this particular integration test, in order to + // simulate image commit after bootloader swap + file_ctx + .status_details + .insert( + heapless::String::try_from("self_test").unwrap(), + heapless::String::try_from("active").unwrap(), ) - .await?; - - assert_eq!(file_handler.plateform_state, FileHandlerState::Swap); - - // Run it twice in this particular integration test, in order to simulate image commit after bootloader swap - file_ctx - .status_details - .insert( - heapless::String::try_from("self_test").unwrap(), - heapless::String::try_from("active").unwrap(), - ) - .unwrap(); - Updater::perform_ota(&client, &client, file_ctx, &mut file_handler, &config) - .await?; - - return Ok(()); - } + .unwrap(); + + Updater::perform_ota(&client, &client, file_ctx, &mut file_handler, &config).await?; + + return Ok(()); } Ok::<_, ota::error::OtaError>(()) diff --git a/tests/provisioning.rs b/tests/provisioning.rs index e9eafd5..44e9570 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -75,7 +75,10 @@ async fn test_provisioning() { // Create the MQTT stack let broker = DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); - let config = Config::new(thing_name).keepalive_interval(embassy_time::Duration::from_secs(50)); + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); static STATE: StaticCell> = StaticCell::new(); let state = STATE.init(State::::new()); diff --git a/tests/shadows.rs b/tests/shadows.rs index af788dd..5ea4d6c 100644 --- a/tests/shadows.rs +++ b/tests/shadows.rs @@ -1,533 
+1,238 @@ -// //! -// //! ## Integration test of `AWS IoT Shadows` -// //! -// //! -// //! This test simulates updates of the shadow state from both device side & -// //! cloud side. Cloud side updates are done by publishing directly to the shadow -// //! topics, and ignoring the resulting update accepted response. Device side -// //! updates are done through the shadow API provided by this crate. -// //! -// //! The test runs through the following update sequence: -// //! 1. Setup clean starting point (`desired = null, reported = null`) -// //! 2. Do a `GetShadow` request to sync empty state -// //! 3. Update to initial shadow state from the device -// //! 4. Assert on the initial state -// //! 5. Update state from device -// //! 6. Assert on shadow state -// //! 7. Update state from cloud -// //! 8. Assert on shadow state -// //! 9. Update state from device -// //! 10. Assert on shadow state -// //! 11. Update state from cloud -// //! 12. Assert on shadow state -// //! - -// mod common; - -// use core::fmt::Write; - -// use common::{ -// clock::SysClock, -// credentials, -// network::{Network, TlsNetwork}, -// }; -// use embassy_futures::select; -// use embassy_sync::blocking_mutex::raw::NoopRawMutex; -// use embedded_mqtt::{ -// transport::embedded_nal::{self, NalTransport}, -// DomainBroker, Properties, Publish, QoS, State, -// }; -// use embedded_nal::Ipv4Addr; -// use mqttrust::Mqtt; -// use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification}; -// use native_tls::TlsConnector; -// use rustot::shadows::{ -// derive::ShadowState, topics::Topic, Patch, Shadow, ShadowPatch, ShadowState, -// }; -// use serde::{de::DeserializeOwned, Deserialize, Serialize}; - -// use smlang::statemachine; -// use static_cell::StaticCell; - -// #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] -// pub struct ConfigId(pub u8); - -// impl Serialize for ConfigId { -// fn serialize(&self, serializer: S) -> Result -// where -// S: serde::Serializer, -// { -// let mut str = heapless::String::<3>::new(); -// write!(str, "{}", self.0).map_err(serde::ser::Error::custom)?; -// serializer.serialize_str(&str) -// } -// } - -// impl<'de> Deserialize<'de> for ConfigId { -// fn deserialize(deserializer: D) -> Result -// where -// D: serde::Deserializer<'de>, -// { -// heapless::String::<3>::deserialize(deserializer)? 
-// .parse() -// .map(ConfigId) -// .map_err(serde::de::Error::custom) -// } -// } - -// impl From for ConfigId { -// fn from(v: u8) -> Self { -// Self(v) -// } -// } - -// #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] -// pub struct NetworkMap(heapless::LinearMap>, N>); - -// impl NetworkMap -// where -// K: Eq, -// { -// pub fn insert(&mut self, k: impl Into, v: V) -> Result<(), ()> { -// self.0.insert(k.into(), Some(Patch::Set(v))).map_err(drop)?; -// Ok(()) -// } - -// pub fn remove(&mut self, k: impl Into) -> Result<(), ()> { -// self.0.insert(k.into(), None).map_err(drop)?; -// Ok(()) -// } -// } - -// impl ShadowPatch for NetworkMap -// where -// K: Clone + Default + Eq + Serialize + DeserializeOwned, -// V: Clone + Default + Serialize + DeserializeOwned, -// { -// type PatchState = NetworkMap; - -// fn apply_patch(&mut self, opt: Self::PatchState) { -// for (id, network) in opt.0.into_iter() { -// match network { -// Some(Patch::Set(v)) => { -// self.insert(id.clone(), v.clone()).ok(); -// } -// None | Some(Patch::Unset) => { -// self.remove(id.clone()).ok(); -// } -// } -// } -// } -// } - -// const MAX_NETWORKS: usize = 5; -// type KnownNetworks = NetworkMap; - -// #[derive(Debug, Clone, Default, Serialize, Deserialize, ShadowState)] -// #[shadow("wifi")] -// pub struct WifiConfig { -// pub enabled: bool, - -// pub known_networks: KnownNetworks, -// } - -// #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] -// pub struct ConnectionOptions { -// pub ssid: heapless::String<64>, -// pub password: Option>, - -// pub ip: Option, -// pub subnet: Option, -// pub gateway: Option, -// } - -// #[derive(Debug, Clone)] -// pub enum UpdateAction { -// Insert(u8, ConnectionOptions), -// Remove(u8), -// Enabled(bool), -// } - -// statemachine! 
{ -// transitions: { -// *Begin + Delete = DeleteShadow, -// DeleteShadow + Get = GetShadow, -// GetShadow + Load / load_initial = LoadShadow(Option), -// LoadShadow(Option) + CheckInitial / check_initial = Check(Option), -// UpdateFromDevice(UpdateAction) + CheckState / check = Check(Option), -// UpdateFromCloud(UpdateAction) + Ack = AckUpdate, -// AckUpdate + CheckState / check_cloud = Check(Option), -// Check(Option) + UpdateStateFromDevice / get_next_device = UpdateFromDevice(UpdateAction), -// Check(Option) + UpdateStateFromCloud / get_next_cloud = UpdateFromCloud(UpdateAction), -// } -// } - -// impl core::fmt::Debug for States { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// match self { -// Self::Begin => write!(f, "Self::Begin"), -// Self::DeleteShadow => write!(f, "Self::DeleteShadow"), -// Self::GetShadow => write!(f, "Self::GetShadow"), -// Self::AckUpdate => write!(f, "Self::AckUpdate"), -// Self::LoadShadow(t) => write!(f, "Self::LoadShadow({:?})", t), -// Self::UpdateFromDevice(t) => write!(f, "Self::UpdateFromDevice({:?})", t), -// Self::UpdateFromCloud(t) => write!(f, "Self::UpdateFromCloud({:?})", t), -// Self::Check(t) => write!(f, "Self::Check({:?})", t), -// } -// } -// } - -// fn asserts(id: usize) -> ConnectionOptions { -// match id { -// 0 => ConnectionOptions { -// ssid: heapless::String::from("MySSID"), -// password: None, -// ip: None, -// subnet: None, -// gateway: None, -// }, -// 1 => ConnectionOptions { -// ssid: heapless::String::from("MyProtectedSSID"), -// password: Some(heapless::String::from("SecretPass")), -// ip: None, -// subnet: None, -// gateway: None, -// }, -// 2 => ConnectionOptions { -// ssid: heapless::String::from("CloudSSID"), -// password: Some(heapless::String::from("SecretCloudPass")), -// ip: Some(Ipv4Addr::new(1, 2, 3, 4)), -// subnet: None, -// gateway: None, -// }, -// _ => panic!("Unknown assert ID"), -// } -// } - -// pub struct TestContext<'a> { -// shadow: Shadow<'a, WifiConfig, mqttrust_core::Client<'static, 'static, Q_SIZE>>, -// update_cnt: u8, -// } - -// impl<'a> StateMachineContext for TestContext<'a> { -// fn check_initial( -// &mut self, -// _last_update_action: &Option, -// ) -> Option { -// self.check(&UpdateAction::Remove(0)) -// } - -// fn check_cloud(&mut self) -> Option { -// self.check(&UpdateAction::Remove(0)) -// } - -// fn check(&mut self, _last_update_action: &UpdateAction) -> Option { -// let mut known_networks = KnownNetworks::default(); - -// match self.update_cnt { -// 0 => { -// // After load_initial -// known_networks.insert(0, asserts(0)).unwrap(); -// known_networks.insert(1, asserts(1)).unwrap(); -// } -// 1 => { -// // After get_next_device -// known_networks.remove(0).unwrap(); -// known_networks.insert(1, asserts(1)).unwrap(); -// } -// 2 => { -// // After get_next_cloud -// known_networks.remove(0).unwrap(); -// known_networks.insert(1, asserts(1)).unwrap(); -// known_networks.insert(2, asserts(2)).unwrap(); -// } -// 3 => { -// // After get_next_device -// known_networks.insert(0, asserts(0)).unwrap(); -// known_networks.insert(1, asserts(1)).unwrap(); -// known_networks.insert(2, asserts(2)).unwrap(); -// } -// 4 => { -// // After get_next_cloud -// known_networks.insert(0, asserts(0)).unwrap(); -// known_networks.insert(1, asserts(1)).unwrap(); -// known_networks.remove(2).unwrap(); -// } -// 5 => return None, -// _ => {} -// } - -// Some(known_networks) -// } - -// fn get_next_device(&mut self, _: &Option) -> UpdateAction { -// self.update_cnt += 1; -// match 
self.update_cnt { -// 1 => UpdateAction::Remove(0), -// 3 => UpdateAction::Insert(0, asserts(0)), -// 5 => UpdateAction::Remove(0), -// _ => panic!("Unexpected update counter in `get_next_device`"), -// } -// } - -// fn get_next_cloud(&mut self, _: &Option) -> UpdateAction { -// self.update_cnt += 1; - -// match self.update_cnt { -// 2 => UpdateAction::Insert(2, asserts(2)), -// 4 => UpdateAction::Remove(2), -// _ => panic!("Unexpected update counter in `get_next_cloud`"), -// } -// } - -// fn load_initial(&mut self) -> Option { -// let mut known_networks = KnownNetworks::default(); -// known_networks.insert(0, asserts(0)).unwrap(); -// known_networks.insert(1, asserts(1)).unwrap(); -// Some(known_networks) -// } -// } - -// impl<'a> StateMachine> { -// pub fn spin( -// &mut self, -// notification: Notification, -// mqtt_client: &embedded_mqtt::MqttClient<'a, M, 1>, -// ) -> bool { -// log::info!("State: {:?}", self.state()); -// match (self.state(), notification) { -// (&States::Begin, Notification::Suback(_)) => { -// self.process_event(Events::Delete).unwrap(); -// } -// (&States::DeleteShadow, Notification::Suback(_)) => { -// mqtt_client -// .publish( -// Publish::builder() -// .topic_name( -// &Topic::Update -// .format::<128>( -// mqtt_client.client_id(), -// ::NAME, -// ) -// .unwrap(), -// ) -// .payload(b"{\"state\":{\"desired\":null,\"reported\":null}}") -// .build(), -// ) -// .await -// .unwrap(); - -// self.process_event(Events::Get).unwrap(); -// } -// (&States::GetShadow, Notification::Publish(publish)) -// if matches!( -// publish.topic_name.as_str(), -// "$aws/things/rustot-test/shadow/name/wifi/update/accepted" -// ) => -// { -// self.context_mut().shadow.get_shadow().unwrap(); -// self.process_event(Events::Load).unwrap(); -// } -// (&States::LoadShadow(ref initial_map), Notification::Publish(publish)) -// if matches!( -// publish.topic_name.as_str(), -// "$aws/things/rustot-test/shadow/name/wifi/get/accepted" -// ) => -// { -// let initial_map = initial_map.clone(); - -// self.context_mut() -// .shadow -// .update(|_current, desired| { -// desired.known_networks = Some(initial_map.unwrap()); -// }) -// .unwrap(); -// self.process_event(Events::CheckInitial).unwrap(); -// } -// (&States::UpdateFromDevice(ref update_action), Notification::Publish(publish)) -// if matches!( -// publish.topic_name.as_str(), -// "$aws/things/rustot-test/shadow/name/wifi/get/accepted" -// ) => -// { -// let action = update_action.clone(); -// self.context_mut() -// .shadow -// .update(|current, desired| match action { -// UpdateAction::Insert(id, options) => { -// let mut desired_map = current.known_networks.clone(); -// desired_map.insert(id, options).unwrap(); -// desired.known_networks = Some(desired_map); -// } -// UpdateAction::Remove(id) => { -// let mut desired_map = current.known_networks.clone(); -// desired_map.remove(id).unwrap(); -// desired.known_networks = Some(desired_map); -// } -// UpdateAction::Enabled(en) => { -// desired.enabled = Some(en); -// } -// }) -// .unwrap(); -// self.process_event(Events::CheckState).unwrap(); -// } -// (&States::UpdateFromCloud(ref update_action), Notification::Publish(publish)) -// if matches!( -// publish.topic_name.as_str(), -// "$aws/things/rustot-test/shadow/name/wifi/get/accepted" -// ) => -// { -// let desired_known_networks = match update_action { -// UpdateAction::Insert(id, options) => format!( -// "\"known_networks\": {{\"{}\":{{\"set\":{}}}}}", -// id, -// serde_json_core::to_string::<_, 256>(options).unwrap() -// ), -// 
UpdateAction::Remove(id) => { -// format!("\"known_networks\": {{\"{}\":\"unset\"}}", id) -// } -// &UpdateAction::Enabled(en) => format!("\"enabled\": {}", en), -// }; - -// let payload = format!( -// "{{\"state\":{{\"desired\":{{{}}}, \"reported\":{}}}}}", -// desired_known_networks, -// serde_json_core::to_string::<_, 512>(self.context().shadow.get()).unwrap() -// ); - -// log::debug!("Update from cloud: {:?}", payload); - -// mqtt_client -// .publish( -// Publish::builder() -// .topic_name( -// &Topic::Update -// .format::<128>( -// mqtt_client.client_id(), -// ::NAME, -// ) -// .unwrap(), -// ) -// .payload(payload.as_bytes()) -// .build(), -// ) -// .unwrap(); -// self.process_event(Events::Ack).unwrap(); -// } -// (&States::AckUpdate, Notification::Publish(publish)) -// if matches!( -// publish.topic_name.as_str(), -// "$aws/things/rustot-test/shadow/name/wifi/update/delta" -// ) => -// { -// self.context_mut() -// .shadow -// .handle_message(&publish.topic_name, &publish.payload) -// .unwrap(); - -// self.process_event(Events::CheckState).unwrap(); -// } -// (&States::Check(ref expected_map), Notification::Publish(publish)) -// if matches!( -// publish.topic_name.as_str(), -// "$aws/things/rustot-test/shadow/name/wifi/update/accepted" -// | "$aws/things/rustot-test/shadow/name/wifi/update/delta" -// ) => -// { -// let expected = expected_map.clone(); -// self.context_mut() -// .shadow -// .handle_message(&publish.topic_name, &publish.payload) -// .unwrap(); - -// match expected { -// Some(expected_map) => { -// assert_eq!(self.context().shadow.get().known_networks, expected_map); -// self.context_mut().shadow.get_shadow().unwrap(); -// let event = if self.context().update_cnt % 2 == 0 { -// Events::UpdateStateFromDevice -// } else { -// Events::UpdateStateFromCloud -// }; -// self.process_event(event).unwrap(); -// } -// None => return true, -// } -// } -// (_, Notification::Publish(publish)) => { -// log::warn!("TOPIC: {}", publish.topic_name); -// self.context_mut() -// .shadow -// .handle_message(&publish.topic_name, &publish.payload) -// .unwrap(); -// } -// _ => {} -// } - -// false -// } -// } - -// #[tokio::test(flavor = "current_thread")] -// async fn test_shadows() { -// env_logger::init(); - -// log::info!("Starting shadows test..."); - -// let (thing_name, identity) = credentials::identity(); - -// let hostname = credentials::HOSTNAME.unwrap(); - -// static NETWORK: StaticCell = StaticCell::new(); -// let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); - -// // Create the MQTT stack -// let broker = -// DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), &network).unwrap(); -// let config = embedded_mqtt::Config::new(thing_name) -// .keepalive_interval(embassy_time::Duration::from_secs(50)); - -// let mut state = State::::new(); -// let (mut stack, client) = embedded_mqtt::new(&mut state, config); - -// let mqtt_client = client; - -// let shadow = Shadow::new(WifiConfig::default(), &mqtt_client).unwrap(); - -// // loop { -// // if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { -// // log::info!("Successfully connected to broker"); -// // } - -// // match mqtt_eventloop.yield_event(&mut network) { -// // Ok(notification) => { -// // if test_state.spin(notification, &mqtt_client) { -// // break; -// // } -// // } -// // Err(_) => {} -// // } -// // } - -// // cloud_updater(mqtt_client); - -// let shadows_fut = async { -// shadow.next_update().await; -// todo!() -// }; - -// let mut transport = 
NalTransport::new(network, broker); - -// match embassy_time::with_timeout( -// embassy_time::Duration::from_secs(25), -// select::select(stack.run(&mut transport), shadows_fut), -// ) -// .await -// .unwrap() -// { -// select::Either::First(_) => { -// unreachable!() -// } -// select::Either::Second(result) => result.unwrap(), -// }; -// } +//! +//! ## Integration test of `AWS IoT Shadows` +//! +//! +//! This test simulates updates of the shadow state from both device side & +//! cloud side. Cloud side updates are done by publishing directly to the shadow +//! topics, and ignoring the resulting update accepted response. Device side +//! updates are done through the shadow API provided by this crate. +//! +//! The test runs through the following update sequence: +//! 1. Setup clean starting point (`desired = null, reported = null`) +//! 2. Do a `GetShadow` request to sync empty state +//! 3. Update to initial shadow state from the device +//! 4. Assert on the initial state +//! 5. Update state from device +//! 6. Assert on shadow state +//! 7. Update state from cloud +//! 8. Assert on shadow state +//! 9. Update state from device +//! 10. Assert on shadow state +//! 11. Update state from cloud +//! 12. Assert on shadow state +//! +#![allow(async_fn_in_trait)] +#![feature(type_alias_impl_trait)] + +mod common; + +use common::credentials; +use common::network::TlsNetwork; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{ + self, transport::embedded_nal::NalTransport, Config, DomainBroker, Publish, QoS, State, + Subscribe, SubscribeTopic, +}; +use futures::StreamExt; +use rustot::shadows::{derive::ShadowState, Shadow}; +use serde::{Deserialize, Serialize}; +use static_cell::StaticCell; + +#[derive(Debug, Default, Serialize, Deserialize, ShadowState, PartialEq)] +#[shadow("state")] +pub struct TestShadow { + foo: u32, + // #[serde(skip_serializing_if = "Option::is_none")] + // bar: Option, +} + +#[tokio::test(flavor = "current_thread")] +async fn test_shadow_update_from_device() { + env_logger::init(); + + const DESIRED_1: &str = r#"{ + "state": { + "desired": { + "foo": 42 + } + }, + "metadata": { + "foo": { + "timestamp": 1672047508 + } + }, + "version": 2, + "timestamp": 1672047508 +}"#; + + let (thing_name, identity) = credentials::identity(); + let hostname = credentials::HOSTNAME.unwrap(); + + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); + + // Create the shadow + let mut shadow = Shadow::::new(TestShadow::default(), &client); + + let mqtt_fut = async { + let mut update_subscription = client + .subscribe::<2>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + rustot::shadows::topics::Topic::Update + .format::<128>(client.client_id(), Some("state")) + .unwrap() + .as_str(), + ) + .build()]) + .build(), + ) + .await + .unwrap(); + + let mut get_subscription = client + .subscribe::<2>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + rustot::shadows::topics::Topic::Get 
+ .format::<128>(client.client_id(), Some("state")) + .unwrap() + .as_str(), + ) + .build()]) + .build(), + ) + .await + .unwrap(); + + // Force a shadow get first, to sync up state + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Get + .format::<128>(client.client_id(), Some("state")) + .unwrap() + .as_str(), + ) + .qos(QoS::AtLeastOnce) + .payload(&[]) + .build(), + ) + .await + .unwrap(); + + // Wait for the device to try to fetch the shadow first. + let _ = get_subscription.next().await; + + // Initial shadow state update + log::info!("Doing initial shadow update"); + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Update + .format::<128>(client.client_id(), Some("state")) + .unwrap() + .as_str(), + ) + .payload(DESIRED_1.as_bytes()) + .qos(QoS::AtLeastOnce) + .build(), + ) + .await + .unwrap(); + + loop { + select::select(update_subscription.next(), async { + // Device-side update 1 + shadow + .update(|_, desired| { + desired.foo = Some(1337); + }) + .await + .unwrap(); + + let current = shadow.get(); + assert_eq!(current.foo, 1337); + let payload = serde_json_core::to_string::<_, 512>(current).unwrap(); + log::info!("ASSERT-DEVICE: {:?}", payload); + + // Cloud-side update 1 + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Update + .format::<128>(client.client_id(), Some("state")) + .unwrap() + .as_str(), + ) + .payload(r#"{"state": {"desired": {"bar": true}}}"#.as_bytes()) + .qos(QoS::AtLeastOnce) + .build(), + ) + .await + .unwrap(); + + // Device-side update 2 + // shadow + // .update(|_state, desired| { + // // desired.bar = Some(false); + // }) + // .await + // .unwrap(); + + // let current = shadow.get(); + // let payload = serde_json_core::to_string::<_, 512>(current).unwrap(); + // log::info!("ASSERT-DEVICE: {}", payload); + // assert_eq!(current.bar, Some(false)); + + // Cloud-side update 2 + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Update + .format::<128>(client.client_id(), Some("state")) + .unwrap() + .as_str(), + ) + .payload(r#"{"state": {"desired": {"foo": 100}}}"#.as_bytes()) + .qos(QoS::AtLeastOnce) + .build(), + ) + .await + .unwrap(); + + let (s, _) = shadow.wait_delta().await.unwrap(); + let payload = serde_json_core::to_string::<_, 512>(s).unwrap(); + log::info!("ASSERT-DEVICE: {:?}", payload); + assert_eq!(s.foo, 100); + }) + .await; + } + }; + + let mut transport = NalTransport::new(network, broker); + let _ = embassy_time::with_timeout( + embassy_time::Duration::from_secs(60), + select::select(stack.run(&mut transport), mqtt_fut), + ) + .await; +} From 16bdbcaaf8675b592d88bdd832d6a0d2dfd487c2 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 3 Oct 2024 11:48:22 +0200 Subject: [PATCH 24/36] Add more docs --- Cargo.toml | 1 + README.md | 54 ++--- scripts/register.sh | 14 +- src/ota/control_interface/mqtt.rs | 2 +- src/ota/mod.rs | 5 - tests/README.md | 89 ++++++++- tests/shadows.rs | 321 +++++++++++++++--------------- 7 files changed, 275 insertions(+), 211 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4c3242b..a48cb0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ env_logger = "0.11" sha2 = "0.10.1" static_cell = { version = "2", features = ["nightly"] } log = { version = "0.4" } +serde_json = "1" tokio = { version = "1.33", default-features = false, features = [ "macros", diff --git a/README.md b/README.md index c0d217a..cd29cd3 100644 --- a/README.md +++ b/README.md 
@@ -1,34 +1,34 @@ -# Rust of things (rustot) - +# Rust of Things (rustot) **Work in progress** -> no_std, no_alloc crate for AWS IoT Devices, implementing Jobs, OTA, Device Defender and IoT Shadows - -This crates strives to implement the sum of: -- [AWS OTA](https://github.com/aws/ota-for-aws-iot-embedded-sdk) -- [AWS Device Defender](https://github.com/aws/Device-Defender-for-AWS-IoT-embedded-sdk) -- [AWS Jobs](https://github.com/aws/Jobs-for-AWS-IoT-embedded-sdk) -- [AWS Device Shadow](https://github.com/aws/Device-Shadow-for-AWS-IoT-embedded-sdk) -- [AWS IoT Fleet Provisioning](https://github.com/aws/Fleet-Provisioning-for-AWS-IoT-embedded-sdk) - +> A `no_std`, `no_alloc` crate for interacting with AWS IoT services on embedded devices. -![Test][test] -[![Code coverage][codecov-badge]][codecov] -![No Std][no-std-badge] -[![Crates.io Version][crates-io-badge]][crates-io] -[![Crates.io Downloads][crates-io-download-badge]][crates-io-download] +This crate aims to provide a pure-Rust implementation of essential AWS IoT features for embedded systems, inspired by the Amazon FreeRTOS AWS IoT Device SDK. -Any contributions will be welcomed! Even if they are just suggestions, bugs or reviews! +## Features -This is a port of the Amazon-FreeRTOS AWS IoT Device SDK (https://github.com/nguyenvuhung/amazon-freertos/tree/master/libraries/freertos_plus/aws/ota), written in pure Rust. +* **OTA Updates:** ([`ota`] module) + * Download and apply firmware updates securely over MQTT or HTTP. + * Supports both CBOR and raw binary firmware formats. +* **Device Shadow:** ([`shadow`] module) + * Synchronize device state with the cloud using AWS IoT Device Shadow service. + * Get, update, and delete device shadows. +* **Jobs:** ([`jobs`] module) + * Receive and execute jobs remotely on your devices. + * Track job status and report progress to AWS IoT. +* **Device Defender:** ([`defender`] module) + * Implement security best practices and detect anomalies on your devices. +* **Fleet Provisioning:** ([`provisioning`] module) + * Securely provision and connect devices to AWS IoT at scale. +* **Lightweight and `no_std`:** Designed specifically for resource-constrained embedded devices. -It is written to work with [mqttrust](https://github.com/BlackbirdHQ/mqttrust), but should work with any other mqtt client, that implements the [Mqtt trait](https://github.com/BlackbirdHQ/mqttrust/blob/master/mqttrust/src/lib.rs) from mqttrust. +## Contributing -## Tests - -> The crate is covered by tests. These tests can be run by `cargo test --tests --all-features`, and are run by the CI on every push to master. +Contributions, suggestions, bug reports, and reviews are highly appreciated! +Please refer to [CONTRIBUTING.md](CONTRIBUTING.md) for more information on how +to contribute. ## License @@ -45,13 +45,3 @@ at your option. Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. 
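+## Example: updating a device shadow
+
+A minimal sketch of the device-side shadow API, adapted from
+`tests/shadows.rs`. The `DeviceConfig` struct, the `"config"` shadow name and
+the client setup are illustrative only; see the integration tests for a
+complete, working setup.
+
+```rust
+use embassy_sync::blocking_mutex::raw::NoopRawMutex;
+use rustot::shadows::{derive::ShadowState, Shadow};
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Default, Serialize, Deserialize, ShadowState, PartialEq)]
+#[shadow("config")]
+pub struct DeviceConfig {
+    enabled: bool,
+}
+
+async fn shadow_task(client: &embedded_mqtt::MqttClient<'_, NoopRawMutex, 2>) {
+    let mut shadow =
+        Shadow::<DeviceConfig, NoopRawMutex, _>::new(DeviceConfig::default(), client);
+
+    // Report a device-side change to the cloud.
+    shadow
+        .update(|_current, desired| {
+            desired.enabled = Some(true);
+        })
+        .await
+        .unwrap();
+
+    // Wait for a delta pushed from the cloud and act on the new state.
+    let (state, _delta) = shadow.wait_delta().await.unwrap();
+    let _ = state.enabled;
+}
+```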
- - -[test]: https://github.com/BlackbirdHQ/rustot/workflows/Test/badge.svg -[no-std-badge]: https://img.shields.io/badge/no__std-yes-blue -[codecov-badge]: https://codecov.io/gh/BlackbirdHQ/rustot/branch/master/graph/badge.svg -[codecov]: https://codecov.io/gh/BlackbirdHQ/rustot -[crates-io]: https://crates.io/crates/rustot -[crates-io-badge]: https://img.shields.io/crates/v/rustot.svg?maxAge=3600 -[crates-io-download]: https://crates.io/crates/rustot -[crates-io-download-badge]: https://img.shields.io/crates/d/rustot.svg?maxAge=3600 diff --git a/scripts/register.sh b/scripts/register.sh index 95cc849..c0e0662 100755 --- a/scripts/register.sh +++ b/scripts/register.sh @@ -1,14 +1,14 @@ #!/usr/bin/env bash -# Registers the device in Blackbird's DynamoDB containing whitelisted devices to +# Registers the device in Factbird's DynamoDB containing whitelisted devices to # be provisioned. -# +# # This script will populate `tests/secrets` with `claim_certificate.pem.crt` & # `claim_private.pem.key`, as well as combine them into `claim_identity.pfx`, -# which is password protected with `env:DEVICE_ADVISOR_PASSWORD` +# which is password protected with `env:IDENTITY_PASSWORD` -if [[ -z "${DEVICE_ADVISOR_PASSWORD}" ]]; then - echo "DEVICE_ADVISOR_PASSWORD environment variable is required!" +if [[ -z "${IDENTITY_PASSWORD}" ]]; then + echo "IDENTITY_PASSWORD environment variable is required!" exit 1 fi @@ -27,6 +27,6 @@ jq -r '.certificatePem' response.json > $SECRETS_DIR/claim_certificate.pem.crt jq -r '.privateKey' response.json > $SECRETS_DIR/claim_private.pem.key rm response.json -openssl pkcs12 -export -out $SECRETS_DIR/claim_identity.pfx -inkey $SECRETS_DIR/claim_private.pem.key -in $SECRETS_DIR/claim_certificate.pem.crt -certfile $SECRETS_DIR/root-ca.pem -passout pass:$DEVICE_ADVISOR_PASSWORD +openssl pkcs12 -export -out $SECRETS_DIR/claim_identity.pfx -inkey $SECRETS_DIR/claim_private.pem.key -in $SECRETS_DIR/claim_certificate.pem.crt -certfile $SECRETS_DIR/root-ca.pem -passout pass:$IDENTITY_PASSWORD rm $SECRETS_DIR/claim_certificate.pem.crt -rm $SECRETS_DIR/claim_private.pem.key \ No newline at end of file +rm $SECRETS_DIR/claim_private.pem.key diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 8cd9cf2..826d5b8 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -140,7 +140,7 @@ where 512, ); - warn!("Updating job status! {:?}", status); + debug!("Updating job status! {:?}", status); self.publish( Publish::builder() diff --git a/src/ota/mod.rs b/src/ota/mod.rs index cba36c9..29ff521 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -110,7 +110,6 @@ impl Updater { // Select over the futures match subscription.next_block().await { Ok(Some(mut payload)) => { - warn!("Got block!"); // Decode the file block received let mut progress = progress_state.lock().await; @@ -170,8 +169,6 @@ impl Updater { } else { data.request_file_blocks(&file_ctx, &mut progress, &config) .await?; - - warn!("Done requesting more blocks! 
{:#?}", progress); } } Err(e) if e.is_retryable() => { @@ -481,8 +478,6 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { let (status, reason) = self.status_update_signal.wait().await; // Update the job status based on the signal - warn!("Signaled status update {:?} {:?}", status, reason); - let mut progress = self.progress_state.lock().await; self.control .update_job_status(self.file_ctx, &mut progress, self.config, status, reason) diff --git a/tests/README.md b/tests/README.md index da1c761..32d9cff 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,14 +1,87 @@ -This folder contains a number of examples that shows how to use this crate. +# AWS IoT Rust Examples -
+This repository contains examples demonstrating how to use the AWS IoT SDK for Rust. These examples are also integrated into our CI pipeline as integration tests. + +## Examples ### AWS IoT Fleet Provisioning (`provisioning.rs`) -
-This example can be run by `RUST_LOG=trace AWS_HOSTNAME=xxxxxxxx-ats.iot.eu-west-1.amazonaws.com cargo r --example provisioning --features log`, assuming you have an `examples/secrets/claim_identity.pfx` file with the claiming credentials. +This example demonstrates how to use the AWS IoT Fleet Provisioning service to provision a device. + +**Requirements:** + +* An AWS account with AWS IoT Core and AWS IoT Fleet Provisioning configured. +* A device certificate and private key. You can generate these using OpenSSL or your preferred method. +* A provisioning template configured in your AWS account. + +**To run the example:** + +1. **Create a PKCS #12 (.pfx) identity file:** + If you haven't already, create a PKCS #12 (.pfx) file containing your device certificate and private key. You can use OpenSSL for this: + + ```bash + openssl pkcs12 -export -out claim_identity.pfx -inkey private.pem.key -in certificate.pem.crt -certfile root-ca.pem + ``` + Replace `private.pem.key`, `certificate.pem.crt`, and `root-ca.pem` with your actual file names. + +2. **Store the Identity File:** + Place the `claim_identity.pfx` file in the `tests/secrets/` directory. + +3. **Set Environment Variables:** + Set the following environment variables: + * `IDENTITY_PASSWORD`: The password you set for the `claim_identity.pfx` file. + * `AWS_HOSTNAME`: Your AWS IoT endpoint. You can find this in the AWS IoT console. + +4. **Run the Test:** + + ```bash + cargo test --test provisioning --features "log,std" + ``` + +### AWS IoT OTA (`ota_mqtt.rs`) + +This example demonstrates how to perform an over-the-air (OTA) firmware update using AWS IoT Jobs. + +**Requirements:** + +* An AWS account with AWS IoT Core and AWS IoT Jobs configured. +* A device certificate and private key. +* A PKCS #12 (.pfx) file containing the device certificate and private key (see previous example for creation instructions). +* An OTA update job created in your AWS account. + +**To run the example:** + +1. **Create an OTA Job:** Create an OTA update job. You can find instructions on how to do this in the AWS IoT documentation or refer to the `scripts/create_ota.sh` script for inspiration. +2. **Store the Identity File:** Ensure the `identity.pfx` file (containing your device certificate and private key) is located in the `tests/secrets/` directory. +3. **Set Environment Variables:** + * `IDENTITY_PASSWORD`: The password for your `identity.pfx` file. + * `AWS_HOSTNAME`: Your AWS IoT endpoint. + +4. **Run the Test:** + + ```bash + cargo test --test ota_mqtt --features "log,std" + ``` + +### AWS IoT Shadows (`shadows.rs`) + +This example demonstrates how to interact with AWS IoT device shadows. Device shadows allow you to store and retrieve the latest state of your devices even when they are offline. + +**Requirements:** + +* An AWS account with AWS IoT Core and AWS IoT Device Shadows configured. +* A device certificate and private key. +* A PKCS #12 (.pfx) file containing the device certificate and private key (see previous examples for creation instructions). + +**To run the example:** + +1. **Store the Identity File:** Ensure the `claim_identity.pfx` file (containing your device certificate and private key) is in the `tests/secrets/` directory. +2. **Set Environment Variables:** + * `IDENTITY_PASSWORD`: The password for your `claim_identity.pfx` file. + * `AWS_HOSTNAME`: Your AWS IoT endpoint. 
-pfx identity files can be created from a set of device certificate and private key using OpenSSL as: `openssl pkcs12 -export -out claim_identity.pfx -inkey private.pem.key -in certificate.pem.crt -certfile root-ca.pem` -
-
+3. **Run the Test:** -### AWS IoT OTA (`ota.rs`) \ No newline at end of file + ```bash + cargo test --test shadows --features "log,std" + ``` diff --git a/tests/shadows.rs b/tests/shadows.rs index 5ea4d6c..86e2d93 100644 --- a/tests/shadows.rs +++ b/tests/shadows.rs @@ -1,4 +1,3 @@ -//! //! ## Integration test of `AWS IoT Shadows` //! //! @@ -21,6 +20,7 @@ //! 11. Update state from cloud //! 12. Assert on shadow state //! + #![allow(async_fn_in_trait)] #![feature(type_alias_impl_trait)] @@ -31,14 +31,17 @@ use common::network::TlsNetwork; use embassy_futures::select; use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embedded_mqtt::{ - self, transport::embedded_nal::NalTransport, Config, DomainBroker, Publish, QoS, State, - Subscribe, SubscribeTopic, + self, transport::embedded_nal::NalTransport, Config, DomainBroker, MqttClient, Publish, QoS, + State, Subscribe, SubscribeTopic, }; use futures::StreamExt; -use rustot::shadows::{derive::ShadowState, Shadow}; +use rustot::shadows::{derive::ShadowState, Shadow, ShadowState}; use serde::{Deserialize, Serialize}; +use serde_json::json; use static_cell::StaticCell; +const MAX_SUBSCRIBERS: usize = 8; + #[derive(Debug, Default, Serialize, Deserialize, ShadowState, PartialEq)] #[shadow("state")] pub struct TestShadow { @@ -47,24 +50,83 @@ pub struct TestShadow { // bar: Option, } +/// Helper function to mimic cloud side updates using MQTT client directly +async fn cloud_update(client: &MqttClient<'static, NoopRawMutex, MAX_SUBSCRIBERS>, payload: &[u8]) { + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Update + .format::<128>(client.client_id(), TestShadow::NAME) + .unwrap() + .as_str(), + ) + .payload(payload) + .qos(QoS::AtLeastOnce) + .build(), + ) + .await + .unwrap(); +} + +/// Helper function to assert on the current shadow state +async fn assert_shadow( + client: &MqttClient<'static, NoopRawMutex, MAX_SUBSCRIBERS>, + expected: serde_json::Value, +) { + let mut get_shadow_sub = client + .subscribe::<1>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + rustot::shadows::topics::Topic::GetAccepted + .format::<128>(client.client_id(), TestShadow::NAME) + .unwrap() + .as_str(), + ) + .build()]) + .build(), + ) + .await + .unwrap(); + + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Get + .format::<128>(client.client_id(), TestShadow::NAME) + .unwrap() + .as_str(), + ) + .payload(b"") + .build(), + ) + .await + .unwrap(); + + let current_shadow = get_shadow_sub.next().await.unwrap(); + + assert_eq!( + serde_json::from_slice::(current_shadow.payload()) + .unwrap() + .get("state") + .unwrap(), + &expected, + ); +} + #[tokio::test(flavor = "current_thread")] async fn test_shadow_update_from_device() { env_logger::init(); const DESIRED_1: &str = r#"{ - "state": { - "desired": { - "foo": 42 - } - }, - "metadata": { - "foo": { - "timestamp": 1672047508 + "state": { + "desired": { + "foo": 42 + } } - }, - "version": 2, - "timestamp": 1672047508 -}"#; + }"#; let (thing_name, identity) = credentials::identity(); let hostname = credentials::HOSTNAME.unwrap(); @@ -80,153 +142,96 @@ async fn test_shadow_update_from_device() { .client_id(thing_name.try_into().unwrap()) .keepalive_interval(embassy_time::Duration::from_secs(50)) .build(); - static STATE: StaticCell> = StaticCell::new(); - let state = STATE.init(State::::new()); + + static STATE: StaticCell> = + StaticCell::new(); + let state = STATE.init(State::new()); let (mut stack, client) = 
embedded_mqtt::new(state, config); // Create the shadow - let mut shadow = Shadow::::new(TestShadow::default(), &client); + let mut shadow = Shadow::::new(TestShadow::default(), &client); + + // let delta_fut = async { + // loop { + // let delta = shadow.wait_delta().await.unwrap(); + // } + // }; let mqtt_fut = async { - let mut update_subscription = client - .subscribe::<2>( - Subscribe::builder() - .topics(&[SubscribeTopic::builder() - .topic_path( - rustot::shadows::topics::Topic::Update - .format::<128>(client.client_id(), Some("state")) - .unwrap() - .as_str(), - ) - .build()]) - .build(), - ) - .await - .unwrap(); - - let mut get_subscription = client - .subscribe::<2>( - Subscribe::builder() - .topics(&[SubscribeTopic::builder() - .topic_path( - rustot::shadows::topics::Topic::Get - .format::<128>(client.client_id(), Some("state")) - .unwrap() - .as_str(), - ) - .build()]) - .build(), - ) - .await - .unwrap(); - - // Force a shadow get first, to sync up state - client - .publish( - Publish::builder() - .topic_name( - rustot::shadows::topics::Topic::Get - .format::<128>(client.client_id(), Some("state")) - .unwrap() - .as_str(), - ) - .qos(QoS::AtLeastOnce) - .payload(&[]) - .build(), - ) - .await - .unwrap(); - - // Wait for the device to try to fetch the shadow first. - let _ = get_subscription.next().await; - - // Initial shadow state update - log::info!("Doing initial shadow update"); - client - .publish( - Publish::builder() - .topic_name( - rustot::shadows::topics::Topic::Update - .format::<128>(client.client_id(), Some("state")) - .unwrap() - .as_str(), - ) - .payload(DESIRED_1.as_bytes()) - .qos(QoS::AtLeastOnce) - .build(), - ) - .await - .unwrap(); - - loop { - select::select(update_subscription.next(), async { - // Device-side update 1 - shadow - .update(|_, desired| { - desired.foo = Some(1337); - }) - .await - .unwrap(); - - let current = shadow.get(); - assert_eq!(current.foo, 1337); - let payload = serde_json_core::to_string::<_, 512>(current).unwrap(); - log::info!("ASSERT-DEVICE: {:?}", payload); - - // Cloud-side update 1 - client - .publish( - Publish::builder() - .topic_name( - rustot::shadows::topics::Topic::Update - .format::<128>(client.client_id(), Some("state")) - .unwrap() - .as_str(), - ) - .payload(r#"{"state": {"desired": {"bar": true}}}"#.as_bytes()) - .qos(QoS::AtLeastOnce) - .build(), - ) - .await - .unwrap(); - - // Device-side update 2 - // shadow - // .update(|_state, desired| { - // // desired.bar = Some(false); - // }) - // .await - // .unwrap(); - - // let current = shadow.get(); - // let payload = serde_json_core::to_string::<_, 512>(current).unwrap(); - // log::info!("ASSERT-DEVICE: {}", payload); - // assert_eq!(current.bar, Some(false)); - - // Cloud-side update 2 - client - .publish( - Publish::builder() - .topic_name( - rustot::shadows::topics::Topic::Update - .format::<128>(client.client_id(), Some("state")) - .unwrap() - .as_str(), - ) - .payload(r#"{"state": {"desired": {"foo": 100}}}"#.as_bytes()) - .qos(QoS::AtLeastOnce) - .build(), - ) - .await - .unwrap(); - - let (s, _) = shadow.wait_delta().await.unwrap(); - let payload = serde_json_core::to_string::<_, 512>(s).unwrap(); - log::info!("ASSERT-DEVICE: {:?}", payload); - assert_eq!(s.foo, 100); - }) - .await; - } + // 1. Setup clean starting point (`desired = null, reported = null`) + cloud_update( + &client, + r#"{"state": {"desired": null, "reported": null} }"#.as_bytes(), + ) + .await; + + // 2. 
Do a `GetShadow` request to sync empty state + let _ = shadow.get_shadow().await.unwrap(); + + // 3. Update to initial shadow state from the device + let _ = shadow.report().await.unwrap(); + + // 4. Assert on the initial state + assert_shadow( + &client, + json!({ + "reported": { + "foo": 0 + } + }), + ) + .await; + + // 5. Update state from device + // 6. Assert on shadow state + // 7. Update state from cloud + cloud_update(&client, DESIRED_1.as_bytes()).await; + + // 8. Assert on shadow state + // 9. Update state from device + + // 10. Assert on shadow state + assert_shadow( + &client, + json!({ + "reported": { + "foo": 0 + }, + "desired": { + "foo": 42 + }, + "delta": { + "foo": 42 + } + }), + ) + .await; + + // 11. Update desired state from cloud + cloud_update( + &client, + r#"{"state": {"desired": {"bar": true}}}"#.as_bytes(), + ) + .await; + + // 12. Assert on shadow state + assert_shadow( + &client, + json!({ + "reported": { + "foo": 0 + }, + "desired": { + "foo": 42, + "bar": true + }, + "delta": { + "foo": 42, + "bar": true + } + }), + ) + .await; }; let mut transport = NalTransport::new(network, broker); From bb34d44020122abcf277d00b21b824d56e54e2be Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 3 Oct 2024 11:56:50 +0200 Subject: [PATCH 25/36] Update CI actions --- .github/workflows/ci.yml | 154 ++++++++++++++++++++++++++++++++------- 1 file changed, 128 insertions(+), 26 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3f9320d..d0464e5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,68 +6,104 @@ on: - master pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: ALL_FEATURES: "ota_mqtt_data,ota_http_data" jobs: - cancel_previous_runs: - name: Cancel previous runs + build: + name: Build runs-on: ubuntu-latest steps: - - uses: styfle/cancel-workflow-action@0.4.1 - with: - access_token: ${{ secrets.GITHUB_TOKEN }} - + - name: Checkout source code + uses: actions/checkout@v4 + + - uses: dsherret/rust-toolchain-file@v1 + + - name: Build (library) + run: cargo build --all --target thumbv7em-none-eabihf + + # - name: Build (examples) + # run: | + # for EXAMPLE in $(ls examples); + # do + # (cd examples/$EXAMPLE && cargo build) + # done + test: - name: Build & Test + name: Test runs-on: ubuntu-latest steps: - name: Checkout source code - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - uses: dsherret/rust-toolchain-file@v1 - - name: Build + + - name: Doc Tests uses: actions-rs/cargo@v1 with: - command: build - args: --all --target thumbv7em-none-eabihf --features ${{ env.ALL_FEATURES }} + command: test + args: --doc --features "std,log" - - name: Test + - name: Unit Tests uses: actions-rs/cargo@v1 with: command: test - args: --lib --features "ota_mqtt_data,log" - + args: --lib --features "std,log" + rustfmt: name: rustfmt runs-on: ubuntu-latest steps: - name: Checkout source code - uses: actions/checkout@v3 + + uses: actions/checkout@v4 - uses: dsherret/rust-toolchain-file@v1 - - name: Rustfmt - run: cargo fmt -- --check + + - name: Run rustfmt (library) + run: cargo fmt --all -- --check --verbose + + # - name: Run rustfmt (examples) + # run: | + # for EXAMPLE in $(ls examples); + # do + # (cd examples/$EXAMPLE && cargo fmt --all -- --check --verbose) + # done clippy: name: clippy runs-on: ubuntu-latest + env: + CLIPPY_PARAMS: -W clippy::all -W clippy::pedantic -W clippy::nursery -W clippy::cargo steps: - name: Checkout 
source code - uses: actions/checkout@v3 + + uses: actions/checkout@v4 - uses: dsherret/rust-toolchain-file@v1 - - name: Run clippy - uses: actions-rs/clippy-check@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - args: -- ${{ env.CLIPPY_PARAMS }} + + - name: Run clippy (library) + run: cargo clippy --features "log" -- ${{ env.CLIPPY_PARAMS }} + + # - name: Run clippy (examples) + # run: | + # for EXAMPLE in $(ls examples); + # do + # (cd examples/$EXAMPLE && cargo clippy -- ${{ env.CLIPPY_PARAMS }}) + # done integration-test: name: Integration Tests runs-on: ubuntu-latest - needs: ['test', 'rustfmt', 'clippy'] + needs: ["build", "test", "rustfmt", "clippy"] steps: - name: Checkout source code - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - uses: dsherret/rust-toolchain-file@v1 + - name: Create OTA Job run: | ./scripts/create_ota.sh @@ -75,6 +111,7 @@ jobs: AWS_DEFAULT_REGION: ${{ secrets.MGMT_AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.MGMT_AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} + - name: Integration Tests uses: actions-rs/cargo@v1 with: @@ -91,4 +128,69 @@ jobs: env: AWS_DEFAULT_REGION: ${{ secrets.MGMT_AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.MGMT_AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} \ No newline at end of file + AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} + + # device_advisor: + # name: AWS IoT Device Advisor + # runs-on: ubuntu-latest + # needs: test + # env: + # AWS_EC2_METADATA_DISABLED: true + # AWS_DEFAULT_REGION: ${{ secrets.MGMT_AWS_DEFAULT_REGION }} + # AWS_ACCESS_KEY_ID: ${{ secrets.MGMT_AWS_ACCESS_KEY_ID }} + # AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} + # SUITE_ID: 1gaev57dq6i5 + # THING_ARN: arn:aws:iot:eu-west-1:411974994697:thing/embedded-mqtt + # steps: + # - name: Checkout source code + # uses: actions/checkout@v4 + + # - uses: dsherret/rust-toolchain-file@v1 + + # - name: Get AWS_HOSTNAME + # id: hostname + # run: | + # hostname=$(aws iotdeviceadvisor get-endpoint --thing-arn ${{ env.THING_ARN }} --output text --query endpoint) + # ret=$? + # echo "::set-output name=AWS_HOSTNAME::$hostname" + # exit $ret + + # - name: Build test binary + # env: + # AWS_HOSTNAME: ${{ steps.hostname.outputs.AWS_HOSTNAME }} + # run: cargo build --features=log --example aws_device_advisor --release + + # - name: Start test suite + # id: test_suite + # run: | + # suite_id=$(aws iotdeviceadvisor start-suite-run --suite-definition-id ${{ env.SUITE_ID }} --suite-run-configuration "primaryDevice={thingArn=${{ env.THING_ARN }}},parallelRun=true" --output text --query suiteRunId) + # ret=$? + # echo "::set-output name=SUITE_RUN_ID::$suite_id" + # exit $ret + + # - name: Execute test binary + # id: binary + # env: + # DEVICE_ADVISOR_PASSWORD: ${{ secrets.DEVICE_ADVISOR_PASSWORD }} + # RUST_LOG: trace + # run: | + # nohup ./target/release/examples/aws_device_advisor > device_advisor_integration.log & + # echo "::set-output name=PID::$!" 
+ + # - name: Monitor test run + # run: | + # chmod +x ./scripts/da_monitor.sh + # echo ${{ env.SUITE_ID }} ${{ steps.test_suite.outputs.SUITE_RUN_ID }} ${{ steps.binary.outputs.PID }} + # ./scripts/da_monitor.sh ${{ env.SUITE_ID }} ${{ steps.test_suite.outputs.SUITE_RUN_ID }} ${{ steps.binary.outputs.PID }} + + # - name: Kill test binary process + # if: ${{ always() }} + # run: kill ${{ steps.binary.outputs.PID }} || true + + # - name: Log binary output + # if: ${{ always() }} + # run: cat device_advisor_integration.log + + # - name: Stop test suite + # if: ${{ failure() }} + # run: aws iotdeviceadvisor stop-suite-run --suite-definition-id ${{ env.SUITE_ID }} --suite-run-id ${{ steps.test_suite.outputs.SUITE_RUN_ID }} From 916296ee4a914c7a68366c7cf6fbc2925165a592 Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 3 Oct 2024 11:59:07 +0200 Subject: [PATCH 26/36] Remove local patch --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a48cb0c..b4c2f01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,5 +84,5 @@ defmt = [ ] log = ["dep:log", "embedded-mqtt/log"] -[patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] -embedded-mqtt = { path = "../embedded-mqtt" } +# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +# embedded-mqtt = { path = "../embedded-mqtt" } From d7129323410e30369b7238ec979f2ce9e041ea7e Mon Sep 17 00:00:00 2001 From: Mathias Date: Fri, 4 Oct 2024 10:55:52 +0200 Subject: [PATCH 27/36] Make sure OTA request momentum is not running prior to subscription --- Cargo.toml | 4 ++-- src/ota/mod.rs | 18 +++++++++++------- src/provisioning/mod.rs | 2 +- tests/ota_mqtt.rs | 5 ++--- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b4c2f01..a48cb0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,5 +84,5 @@ defmt = [ ] log = ["dep:log", "embedded-mqtt/log"] -# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] -# embedded-mqtt = { path = "../embedded-mqtt" } +[patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +embedded-mqtt = { path = "../embedded-mqtt" } diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 29ff521..0a75a2a 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -55,7 +55,7 @@ impl Updater { request_block_remaining: file_ctx.bitmap.len() as u32, bitmap: file_ctx.bitmap.clone(), file_size: file_ctx.filesize, - request_momentum: 0, + request_momentum: None, status_details: file_ctx.status_details.clone(), }); @@ -151,7 +151,7 @@ impl Updater { } Ok(false) => { // ... (Handle successful block processing) ... 
- progress.request_momentum = 0; + progress.request_momentum = Some(0); // Update the job status to reflect the download progress if progress.blocks_remaining @@ -327,15 +327,19 @@ impl Updater { break; } - if progress.request_momentum <= config.max_request_momentum { + let Some(request_momentum) = &mut progress.request_momentum else { + continue; + }; + + if *request_momentum <= config.max_request_momentum { // Increment momentum - progress.request_momentum += 1; + *request_momentum += 1; warn!("Momentum requesting more blocks!"); // Request data blocks - data.request_file_blocks(file_ctx, &mut progress, config) - .await?; + // data.request_file_blocks(file_ctx, &mut progress, config) + // .await?; } else { // Too much momentum, abort return Err(error::OtaError::MomentumAbort); @@ -353,7 +357,7 @@ pub struct ProgressState { pub file_size: usize, pub block_offset: u32, pub request_block_remaining: u32, - pub request_momentum: u8, + pub request_momentum: Option, pub bitmap: Bitmap, pub status_details: StatusDetailsOwned, } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index d86e2a7..32181b8 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -255,7 +255,7 @@ impl FleetProvisioner { .await?; drop(message); - drop(create_subscription); + create_subscription.unsubscribe().await?; let mut message = register_subscription .next() diff --git a/tests/ota_mqtt.rs b/tests/ota_mqtt.rs index ab4ea8c..b38d10e 100644 --- a/tests/ota_mqtt.rs +++ b/tests/ota_mqtt.rs @@ -140,9 +140,8 @@ async fn test_mqtt_ota() { let message = jobs_subscription.next().await.unwrap(); if let Some(mut file_ctx) = handle_ota(message, &config) { - // Nested subscriptions are a problem for embedded-mqtt, so drop the - // subscription here - drop(jobs_subscription); + // Nested subscriptions are a problem for embedded-mqtt, so unsubscribe here + jobs_subscription.unsubscribe().await.unwrap(); // We have an OTA job, leeeets go! Updater::perform_ota( From 01c44a0b4854e1d1e6d006acf99a0222a445cc38 Mon Sep 17 00:00:00 2001 From: Kenneth Knudsen <98805797+KennethKnudsen97@users.noreply.github.com> Date: Thu, 10 Oct 2024 13:38:39 +0200 Subject: [PATCH 28/36] Fix(async): Ensure one sub on topic (#64) * Ensure only one sub on the same topic * wait for mqtt connected in report * Reduce request locks to one --- src/shadows/mod.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index c508661..5d46ac0 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -4,7 +4,7 @@ mod error; mod shadow_diff; pub mod topics; -use core::{marker::PhantomData, ops::DerefMut}; +use core::{marker::PhantomData, ops::DerefMut, sync::atomic}; use bitmaps::{Bits, BitsImpl}; pub use data_types::Patch; @@ -42,6 +42,9 @@ where mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, subscription: Mutex>>, _shadow: PhantomData, + // request_lock is used to ensure that shadow operations such as subscribing, updating, or + // deleting are serialized, preventing multiple concurrent requests to the same MQTT topics. + request_lock: Mutex, } impl<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> ShadowHandler<'a, 'm, M, S, SUBS> @@ -95,6 +98,7 @@ where if let Some(client) = delta.client_token { if client.eq(self.mqtt.client_id()) { + warn!("DELTA CLIENT TOKEN WAS == TO DEVICE CLIENT ID"); return Ok(None); } } @@ -105,6 +109,8 @@ where /// Internal helper function for applying a delta state to the actual shadow /// state, and update the cloud shadow. 
async fn report(&self, reported: &R) -> Result<(), Error> { + let _update_requested_lock = self.request_lock.lock().await; + debug!( "[{:?}] Updating reported shadow value.", S::NAME.unwrap_or(CLASSIC_SHADOW), @@ -127,6 +133,9 @@ where S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD, ); + //Wait for mqtt to connect + self.mqtt.wait_connected().await; + let mut sub = self.publish_and_subscribe(Topic::Update, payload).await?; //*** WAIT RESPONSE ***/ @@ -179,6 +188,8 @@ where /// Initiate a `GetShadow` request, updating the local state from the cloud. async fn get_shadow(&self) -> Result, Error> { + let _get_requested_lock = self.request_lock.lock().await; + //Wait for mqtt to connect self.mqtt.wait_connected().await; @@ -225,6 +236,8 @@ where } pub async fn delete_shadow(&self) -> Result<(), Error> { + let _delete_request = self.request_lock.lock().await; + // Wait for mqtt to connect self.mqtt.wait_connected().await; @@ -256,6 +269,8 @@ where } pub async fn create_shadow(&self) -> Result, Error> { + let _create_requested_lock = self.request_lock.lock().await; + debug!( "[{:?}] Creating initial shadow value.", S::NAME.unwrap_or(CLASSIC_SHADOW), @@ -403,6 +418,7 @@ where mqtt, subscription: Mutex::new(None), _shadow: PhantomData, + request_lock: Mutex::new(()), }; Self { @@ -528,6 +544,7 @@ where mqtt, subscription: Mutex::new(None), _shadow: PhantomData, + request_lock: Mutex::new(()), }; Self { handler, state } } From c3456cd85b0d2e91a7024ffaec43ab7623511ca8 Mon Sep 17 00:00:00 2001 From: Kenneth Knudsen Date: Thu, 10 Oct 2024 14:57:36 +0200 Subject: [PATCH 29/36] bump embedded mqtt --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a48cb0c..f63eb00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "d2b7c02" } +embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "5e28a55da737356a3a9c1597ae4ff123e2481b1b" } futures = { version = "0.3.28", default-features = false } @@ -84,5 +84,5 @@ defmt = [ ] log = ["dep:log", "embedded-mqtt/log"] -[patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] -embedded-mqtt = { path = "../embedded-mqtt" } +# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +# embedded-mqtt = { path = "../embedded-mqtt" } From 2378458ae37d5d8c16081118d5d922ae01becf08 Mon Sep 17 00:00:00 2001 From: Mathias Koch Date: Tue, 29 Oct 2024 13:19:00 +0100 Subject: [PATCH 30/36] Fix OTA & bump dependencies (#67) --- Cargo.toml | 6 +- src/jobs/data_types.rs | 2 +- src/ota/config.rs | 2 +- src/ota/control_interface/mod.rs | 2 - src/ota/control_interface/mqtt.rs | 177 +++++++++++++++--------------- src/ota/data_interface/mqtt.rs | 13 +-- src/ota/encoding/json.rs | 1 + src/ota/mod.rs | 40 ++++--- src/provisioning/mod.rs | 59 ++++------ src/shadows/mod.rs | 39 +++---- tests/ota_mqtt.rs | 12 +- tests/provisioning.rs | 8 +- tests/shadows.rs | 14 +-- 13 files changed, 172 insertions(+), 203 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f63eb00..9980820 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,10 +26,10 @@ bitmaps = { version = "3.1", default-features = false } heapless = { version = "0.8", features = ["serde"] } serde = { version = "1.0", default-features = false, features = ["derive"] } 
serde_cbor = { version = "0.11", default-features = false, optional = true } -serde-json-core = { version = "0.5" } +serde-json-core = { version = "0.6" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "5e28a55da737356a3a9c1597ae4ff123e2481b1b" } +embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "74eb53d" } futures = { version = "0.3.28", default-features = false } @@ -42,7 +42,7 @@ defmt = { version = "0.3", optional = true } [dev-dependencies] native-tls = { version = "0.2" } -embedded-nal-async = "0.7" +embedded-nal-async = "0.8" env_logger = "0.11" sha2 = "0.10.1" static_cell = { version = "2", features = ["nightly"] } diff --git a/src/jobs/data_types.rs b/src/jobs/data_types.rs index 36449dd..fb6100e 100644 --- a/src/jobs/data_types.rs +++ b/src/jobs/data_types.rs @@ -292,7 +292,7 @@ pub struct Jobs { pub struct ErrorResponse<'a> { pub code: ErrorCode, /// An error message string. - message: &'a str, + pub message: &'a str, /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] diff --git a/src/ota/config.rs b/src/ota/config.rs index ef5cd28..d57cd75 100644 --- a/src/ota/config.rs +++ b/src/ota/config.rs @@ -13,7 +13,7 @@ impl Default for Config { Self { block_size: 1024, max_request_momentum: 3, - request_wait: Duration::from_secs(8), + request_wait: Duration::from_secs(5), status_update_frequency: 24, self_test_timeout: None, } diff --git a/src/ota/control_interface/mod.rs b/src/ota/control_interface/mod.rs index 51e7a7c..ffd8b03 100644 --- a/src/ota/control_interface/mod.rs +++ b/src/ota/control_interface/mod.rs @@ -1,7 +1,6 @@ use crate::jobs::data_types::JobStatus; use super::{ - config::Config, encoding::{json::JobStatusReason, FileContext}, error::OtaError, ProgressState, @@ -16,7 +15,6 @@ pub trait ControlInterface { &self, file_ctx: &FileContext, progress: &mut ProgressState, - config: &Config, status: JobStatus, reason: JobStatusReason, ) -> Result<(), OtaError>; diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 826d5b8..601ad68 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,9 +1,8 @@ use core::fmt::Write; -use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; +use embassy_time::with_timeout; use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic}; -use futures::StreamExt as _; use super::ControlInterface; use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse}; @@ -14,10 +13,7 @@ use crate::ota::encoding::{self, FileContext}; use crate::ota::error::OtaError; use crate::ota::ProgressState; -impl<'a, M: RawMutex, const SUBS: usize> ControlInterface for embedded_mqtt::MqttClient<'a, M, SUBS> -where - BitsImpl<{ SUBS }>: Bits, -{ +impl<'a, M: RawMutex> ControlInterface for embedded_mqtt::MqttClient<'a, M> { /// Check for next available OTA job from the job service by publishing a /// "get next job" message to the job service. 
async fn request_job(&self) -> Result<(), OtaError> { @@ -45,7 +41,6 @@ where &self, file_ctx: &FileContext, progress_state: &mut ProgressState, - config: &Config, status: JobStatus, reason: JobStatusReason, ) -> Result<(), OtaError> { @@ -63,15 +58,6 @@ where if let JobStatus::InProgress | JobStatus::Succeeded = status { let received_blocks = progress_state.total_blocks - progress_state.blocks_remaining; - // Output a status update once in a while. Always update first and - // last status - if progress_state.blocks_remaining != 0 - && received_blocks != 0 - && received_blocks % config.status_update_frequency as usize != 0 - { - return Ok(()); - } - // Don't override the progress on succeeded, nor on self-test // active. (Cases where progress counter is lost due to device // restarts) @@ -92,40 +78,40 @@ where // Downgrade progress updates to QOS 0 to avoid overloading MQTT // buffers during active streaming. But make sure to always send and await ack for first update and last update - if status == JobStatus::InProgress - && progress_state.blocks_remaining != 0 - && received_blocks != 0 - { - qos = QoS::AtMostOnce; - } + // if status == JobStatus::InProgress + // && progress_state.blocks_remaining != 0 + // && received_blocks != 0 + // { + // qos = QoS::AtMostOnce; + // } } - let mut sub = self - .subscribe::<2>( - Subscribe::builder() - .topics(&[ - SubscribeTopic::builder() - .topic_path( - JobTopic::UpdateAccepted(file_ctx.job_name.as_str()) - .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( - self.client_id(), - )? - .as_str(), - ) - .build(), - SubscribeTopic::builder() - .topic_path( - JobTopic::UpdateRejected(file_ctx.job_name.as_str()) - .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( - self.client_id(), - )? - .as_str(), - ) - .build(), - ]) - .build(), - ) - .await?; + // let mut sub = self + // .subscribe::<2>( + // Subscribe::builder() + // .topics(&[ + // SubscribeTopic::builder() + // .topic_path( + // JobTopic::UpdateAccepted(file_ctx.job_name.as_str()) + // .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + // self.client_id(), + // )? + // .as_str(), + // ) + // .build(), + // SubscribeTopic::builder() + // .topic_path( + // JobTopic::UpdateRejected(file_ctx.job_name.as_str()) + // .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + // self.client_id(), + // )? 
+ // .as_str(), + // ) + // .build(), + // ]) + // .build(), + // ) + // .await?; let topic = JobTopic::Update(file_ctx.job_name.as_str()) .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?; @@ -151,44 +137,61 @@ where ) .await?; - loop { - let message = sub.next().await.ok_or(JobError::Encoding)?; - - // Check if topic is GetAccepted - match crate::jobs::Topic::from_str(message.topic_name()) { - Some(crate::jobs::Topic::UpdateAccepted(_)) => { - // Check client token - let (response, _) = serde_json_core::from_slice::< - UpdateJobExecutionResponse>, - >(message.payload()) - .map_err(|_| JobError::Encoding)?; - - if response.client_token != Some(self.client_id()) { - error!( - "Unexpected client token received: {}, expected: {}", - response.client_token.unwrap_or("None"), - self.client_id() - ); - continue; - } - - return Ok(()); - } - Some(crate::jobs::Topic::UpdateRejected(_)) => { - let (error_response, _) = - serde_json_core::from_slice::(message.payload()) - .map_err(|_| JobError::Encoding)?; - - if error_response.client_token != Some(self.client_id()) { - continue; - } - - return Err(OtaError::UpdateRejected(error_response.code)); - } - _ => { - error!("Expected Topic name GetRejected or GetAccepted but got something else"); - } - } - } + Ok(()) + + // loop { + // let message = match with_timeout( + // embassy_time::Duration::from_secs(1), + // sub.next_message(), + // ) + // .await + // { + // Ok(res) => res.ok_or(JobError::Encoding)?, + // Err(_) => return Err(OtaError::Timeout), + // }; + + // // Check if topic is GetAccepted + // match crate::jobs::Topic::from_str(message.topic_name()) { + // Some(crate::jobs::Topic::UpdateAccepted(_)) => { + // // Check client token + // let (response, _) = serde_json_core::from_slice::< + // UpdateJobExecutionResponse>, + // >(message.payload()) + // .map_err(|_| JobError::Encoding)?; + + // if response.client_token != Some(self.client_id()) { + // error!( + // "Unexpected client token received: {}, expected: {}", + // response.client_token.unwrap_or("None"), + // self.client_id() + // ); + // continue; + // } + + // return Ok(()); + // } + // Some(crate::jobs::Topic::UpdateRejected(_)) => { + // let (error_response, _) = + // serde_json_core::from_slice::(message.payload()) + // .map_err(|_| JobError::Encoding)?; + + // if error_response.client_token != Some(self.client_id()) { + // error!( + // "Unexpected client token received: {}, expected: {}", + // error_response.client_token.unwrap_or("None"), + // self.client_id() + // ); + // continue; + // } + + // error!("OTA Update rejected: {:?}", error_response.message); + + // return Err(OtaError::UpdateRejected(error_response.code)); + // } + // _ => { + // error!("Expected Topic name GetRejected or GetAccepted but got something else"); + // } + // } + // } } } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 1bbaae6..b6516d1 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -2,7 +2,6 @@ use core::fmt::{Display, Write}; use core::ops::DerefMut; use core::str::FromStr; -use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ DeferredPayload, EncodingError, MqttClient, Publish, Subscribe, SubscribeTopic, Subscription, @@ -125,22 +124,16 @@ impl<'a> OtaTopic<'a> { } } -impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> -where - BitsImpl<{ SUBS }>: Bits, -{ +impl<'a, 'b, M: RawMutex> BlockTransfer for Subscription<'a, 
'b, M, 1> { async fn next_block(&mut self) -> Result>, OtaError> { Ok(self.next().await) } } -impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> -where - BitsImpl<{ SUBS }>: Bits, -{ +impl<'a, M: RawMutex> DataInterface for MqttClient<'a, M> { const PROTOCOL: Protocol = Protocol::Mqtt; - type ActiveTransfer<'t> = Subscription<'a, 't, M, SUBS, 1> where Self: 't; + type ActiveTransfer<'t> = Subscription<'a, 't, M, 1> where Self: 't; /// Init file transfer by subscribing to the OTA data stream topic async fn init_file_transfer( diff --git a/src/ota/encoding/json.rs b/src/ota/encoding/json.rs index c258942..12fae02 100644 --- a/src/ota/encoding/json.rs +++ b/src/ota/encoding/json.rs @@ -139,6 +139,7 @@ mod tests { (JobStatusReason::Accepted, "accepted"), (JobStatusReason::Rejected, "rejected"), (JobStatusReason::Aborted, "aborted"), + (JobStatusReason::Pal(123), "pal err"), ]; for (reason, exp) in reasons { diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 0a75a2a..008dd2c 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -126,10 +126,11 @@ impl Updater { // ... (Handle end of file) ... match pal.close_file(&file_ctx).await { Err(e) => { - job_updater.signal_update( - JobStatus::Failed, - JobStatusReason::Pal(0), - ); + // FIXME: This seems like duplicate status update, as it will also report during cleanup + // job_updater.signal_update( + // JobStatus::Failed, + // JobStatusReason::Pal(0), + // ); return Err(e.into()); } @@ -212,6 +213,11 @@ impl Updater { pal::OtaEvent::UpdateComplete }; + info!( + "OTA Download finished! Running complete callback: {:?}", + event + ); + pal.complete_callback(event).await?; Ok(()) @@ -331,15 +337,19 @@ impl Updater { continue; }; - if *request_momentum <= config.max_request_momentum { - // Increment momentum - *request_momentum += 1; + // Increment momentum + *request_momentum += 1; + if *request_momentum == 1 { + continue; + } + + if *request_momentum <= config.max_request_momentum { warn!("Momentum requesting more blocks!"); // Request data blocks - // data.request_file_blocks(file_ctx, &mut progress, config) - // .await?; + data.request_file_blocks(file_ctx, &mut progress, config) + .await?; } else { // Too much momentum, abort return Err(error::OtaError::MomentumAbort); @@ -351,6 +361,7 @@ impl Updater { } #[derive(Clone, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct ProgressState { pub total_blocks: usize, pub blocks_remaining: usize, @@ -358,7 +369,9 @@ pub struct ProgressState { pub block_offset: u32, pub request_block_remaining: u32, pub request_momentum: Option, + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] pub bitmap: Bitmap, + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] pub status_details: StatusDetailsOwned, } @@ -430,7 +443,6 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { .update_job_status( &self.file_ctx, &mut progress, - self.config, JobStatus::Succeeded, JobStatusReason::Accepted, ) @@ -484,7 +496,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { // Update the job status based on the signal let mut progress = self.progress_state.lock().await; self.control - .update_job_status(self.file_ctx, &mut progress, self.config, status, reason) + .update_job_status(self.file_ctx, &mut progress, status, reason) .await?; match status { @@ -531,7 +543,6 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { .update_job_status( &self.file_ctx, &mut progress, - self.config, JobStatus::InProgress, JobStatusReason::SelfTestActive, ) @@ -544,7 +555,6 @@ impl<'a, C: 
ControlInterface> JobUpdater<'a, C> { .update_job_status( &self.file_ctx, &mut progress, - self.config, JobStatus::Succeeded, JobStatusReason::Accepted, ) @@ -559,7 +569,6 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { .update_job_status( &self.file_ctx, &mut progress, - self.config, JobStatus::Failed, JobStatusReason::Rejected, ) @@ -574,7 +583,6 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { .update_job_status( &self.file_ctx, &mut progress, - self.config, JobStatus::Failed, JobStatusReason::Aborted, ) @@ -598,7 +606,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { let mut progress = self.progress_state.lock().await; self.control - .update_job_status(&self.file_ctx, &mut progress, self.config, status, reason) + .update_job_status(&self.file_ctx, &mut progress, status, reason) .await?; Ok(()) } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 32181b8..0a27f83 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -4,7 +4,6 @@ pub mod topics; use core::future::Future; -use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ BufferProvider, DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, @@ -41,14 +40,13 @@ pub struct Credentials<'a> { pub struct FleetProvisioner; impl FleetProvisioner { - pub async fn provision<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, + pub async fn provision<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where - BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -62,15 +60,14 @@ impl FleetProvisioner { .await } - pub async fn provision_csr<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, + pub async fn provision_csr<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, template_name: &str, parameters: Option, csr: &str, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where - BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -85,14 +82,13 @@ impl FleetProvisioner { } #[cfg(feature = "provision_cbor")] - pub async fn provision_cbor<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, + pub async fn provision_cbor<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where - BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -107,15 +103,14 @@ impl FleetProvisioner { } #[cfg(feature = "provision_cbor")] - pub async fn provision_csr_cbor<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, + pub async fn provision_csr_cbor<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, template_name: &str, parameters: Option, csr: &str, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where - BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -130,8 +125,8 @@ impl FleetProvisioner { } #[cfg(feature = "provision_cbor")] - async fn provision_inner<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, + async fn provision_inner<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, template_name: &str, parameters: Option, csr: Option<&str>, @@ -139,7 +134,6 @@ impl 
FleetProvisioner { payload_format: PayloadFormat, ) -> Result, Error> where - BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { use embedded_mqtt::SliceBufferProvider; @@ -154,8 +148,8 @@ impl FleetProvisioner { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { let response = Self::deserialize::< CreateKeysAndCertificateResponse, + M, SliceBufferProvider<'a>, - SUBS, >(format, &mut message)?; credential_handler @@ -172,8 +166,8 @@ impl FleetProvisioner { Some(Topic::CreateCertificateFromCsrAccepted(format)) => { let response = Self::deserialize::< CreateCertificateFromCsrResponse, + M, SliceBufferProvider<'a>, - SUBS, >(format, &mut message)?; credential_handler @@ -266,8 +260,8 @@ impl FleetProvisioner { Some(Topic::RegisterThingAccepted(_, format)) => { let response = Self::deserialize::< RegisterThingResponse<'_, C>, + M, SliceBufferProvider<'a>, - SUBS, >(format, &mut message)?; Ok(response.device_configuration) @@ -286,14 +280,11 @@ impl FleetProvisioner { } } - async fn begin<'a, 'b, M: RawMutex, const SUBS: usize>( - mqtt: &'b embedded_mqtt::MqttClient<'a, M, SUBS>, + async fn begin<'a, 'b, M: RawMutex>( + mqtt: &'b embedded_mqtt::MqttClient<'a, M>, csr: Option<&str>, payload_format: PayloadFormat, - ) -> Result, Error> - where - BitsImpl<{ SUBS }>: Bits, - { + ) -> Result, Error> { if let Some(csr) = csr { let subscription = mqtt .subscribe( @@ -395,13 +386,10 @@ impl FleetProvisioner { } } - fn deserialize<'a, R: Deserialize<'a>, B: BufferProvider, const SUBS: usize>( + fn deserialize<'a, R: Deserialize<'a>, M: RawMutex, B: BufferProvider>( payload_format: PayloadFormat, - message: &'a mut Message<'_, B, SUBS>, - ) -> Result - where - BitsImpl<{ SUBS }>: Bits, - { + message: &'a mut Message<'_, M, B>, + ) -> Result { trace!( "Accepted Topic {:?}. Payload len: {:?}", payload_format, @@ -415,13 +403,10 @@ impl FleetProvisioner { }) } - fn handle_error( + fn handle_error( format: PayloadFormat, - mut message: Message<'_, B, SUBS>, - ) -> Result<(), Error> - where - BitsImpl<{ SUBS }>: Bits, - { + mut message: Message<'_, M, B>, + ) -> Result<(), Error> { error!(">> {:?}", message.topic_name()); let response = match format { diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index 5d46ac0..f0f3b96 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -6,7 +6,6 @@ pub mod topics; use core::{marker::PhantomData, ops::DerefMut, sync::atomic}; -use bitmaps::{Bits, BitsImpl}; pub use data_types::Patch; use embassy_sync::{ blocking_mutex::raw::{NoopRawMutex, RawMutex}, @@ -34,22 +33,20 @@ pub trait ShadowState: ShadowPatch + Default { const MAX_PAYLOAD_SIZE: usize = 512; } -struct ShadowHandler<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> +struct ShadowHandler<'a, 'm, M: RawMutex, S: ShadowState> where - BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, - subscription: Mutex>>, + mqtt: &'m embedded_mqtt::MqttClient<'a, M>, + subscription: Mutex>>, _shadow: PhantomData, // request_lock is used to ensure that shadow operations such as subscribing, updating, or // deleting are serialized, preventing multiple concurrent requests to the same MQTT topics. 
request_lock: Mutex, } -impl<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> ShadowHandler<'a, 'm, M, S, SUBS> +impl<'a, 'm, M: RawMutex, S: ShadowState> ShadowHandler<'a, 'm, M, S> where - BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { async fn handle_delta(&self) -> Result, Error> { @@ -344,7 +341,7 @@ where &self, topic: topics::Topic, payload: impl ToPayload, - ) -> Result, Error> { + ) -> Result, Error> { let (accepted, rejected) = match topic { Topic::Get => (Topic::GetAccepted, Topic::GetRejected), Topic::Update => (Topic::UpdateAccepted, Topic::UpdateRejected), @@ -394,18 +391,16 @@ where } } -pub struct PersistedShadow<'a, 'm, S: ShadowState, M: RawMutex, D: ShadowDAO, const SUBS: usize> +pub struct PersistedShadow<'a, 'm, S: ShadowState, M: RawMutex, D: ShadowDAO> where - BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - handler: ShadowHandler<'a, 'm, M, S, SUBS>, + handler: ShadowHandler<'a, 'm, M, S>, pub(crate) dao: Mutex, } -impl<'a, 'm, S, M, D, const SUBS: usize> PersistedShadow<'a, 'm, S, M, D, SUBS> +impl<'a, 'm, S, M, D> PersistedShadow<'a, 'm, S, M, D> where - BitsImpl<{ SUBS }>: Bits, S: ShadowState + Default, M: RawMutex, D: ShadowDAO, @@ -413,7 +408,7 @@ where { /// Instantiate a new shadow that will be automatically persisted to NVM /// based on the passed `DAO`. - pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, dao: D) -> Self { + pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M>, dao: D) -> Self { let handler = ShadowHandler { mqtt, subscription: Mutex::new(None), @@ -522,24 +517,22 @@ where } } -pub struct Shadow<'a, 'm, S: ShadowState, M: RawMutex, const SUBS: usize> +pub struct Shadow<'a, 'm, S: ShadowState, M: RawMutex> where - BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { state: S, - handler: ShadowHandler<'a, 'm, M, S, SUBS>, + handler: ShadowHandler<'a, 'm, M, S>, } -impl<'a, 'm, S, M, const SUBS: usize> Shadow<'a, 'm, S, M, SUBS> +impl<'a, 'm, S, M> Shadow<'a, 'm, S, M> where - BitsImpl<{ SUBS }>: Bits, S: ShadowState, M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { /// Instantiate a new non-persisted shadow - pub fn new(state: S, mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>) -> Self { + pub fn new(state: S, mqtt: &'m embedded_mqtt::MqttClient<'a, M>) -> Self { let handler = ShadowHandler { mqtt, subscription: Mutex::new(None), @@ -613,9 +606,8 @@ where } } -impl<'a, 'm, S, M, const SUBS: usize> core::fmt::Debug for Shadow<'a, 'm, S, M, SUBS> +impl<'a, 'm, S, M> core::fmt::Debug for Shadow<'a, 'm, S, M> where - BitsImpl<{ SUBS }>: Bits, S: ShadowState + core::fmt::Debug, M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, @@ -631,9 +623,8 @@ where } #[cfg(feature = "defmt")] -impl<'a, 'm, S, M, const SUBS: usize> defmt::Format for Shadow<'a, 'm, S, M, SUBS> +impl<'a, 'm, S, M> defmt::Format for Shadow<'a, 'm, S, M> where - BitsImpl<{ SUBS }>: Bits, S: ShadowState + defmt::Format, M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, diff --git a/tests/ota_mqtt.rs b/tests/ota_mqtt.rs index b38d10e..ddd7b78 100644 --- a/tests/ota_mqtt.rs +++ b/tests/ota_mqtt.rs @@ -3,7 +3,6 @@ mod common; -use bitmaps::{Bits, BitsImpl}; use common::credentials; use common::file_handler::{FileHandler, State as FileHandlerState}; use common::network::TlsNetwork; @@ -44,13 +43,10 @@ impl<'a> Jobs<'a> { } } -fn handle_ota<'a, const SUBS: usize>( - message: Message<'a, SliceBufferProvider<'a>, 
SUBS>, +fn handle_ota<'a>( + message: Message<'a, NoopRawMutex, SliceBufferProvider<'a>>, config: &ota::config::Config, -) -> Option -where - BitsImpl: Bits, -{ +) -> Option { let job = match jobs::Topic::from_str(message.topic_name()) { Some(jobs::Topic::NotifyNext) => { let (execution_changed, _) = @@ -103,7 +99,7 @@ async fn test_mqtt_ota() { .keepalive_interval(embassy_time::Duration::from_secs(50)) .build(); - static STATE: StaticCell> = StaticCell::new(); + static STATE: StaticCell> = StaticCell::new(); let state = STATE.init(State::new()); let (mut stack, client) = embedded_mqtt::new(state, config); diff --git a/tests/provisioning.rs b/tests/provisioning.rs index 44e9570..8a51f52 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -80,8 +80,8 @@ async fn test_provisioning() { .keepalive_interval(embassy_time::Duration::from_secs(50)) .build(); - static STATE: StaticCell> = StaticCell::new(); - let state = STATE.init(State::::new()); + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); let (mut stack, client) = embedded_mqtt::new(state, config); let signing_key = credentials::signing_key(); @@ -96,14 +96,14 @@ async fn test_provisioning() { let mut credential_handler = CredentialDAO { creds: None }; #[cfg(not(feature = "provision_cbor"))] - let provision_fut = FleetProvisioner::provision::( + let provision_fut = FleetProvisioner::provision::( &client, &template_name, Some(parameters), &mut credential_handler, ); #[cfg(feature = "provision_cbor")] - let provision_fut = FleetProvisioner::provision_cbor::( + let provision_fut = FleetProvisioner::provision_cbor::( &client, &template_name, Some(parameters), diff --git a/tests/shadows.rs b/tests/shadows.rs index 86e2d93..0d9cdb0 100644 --- a/tests/shadows.rs +++ b/tests/shadows.rs @@ -40,8 +40,6 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use static_cell::StaticCell; -const MAX_SUBSCRIBERS: usize = 8; - #[derive(Debug, Default, Serialize, Deserialize, ShadowState, PartialEq)] #[shadow("state")] pub struct TestShadow { @@ -51,7 +49,7 @@ pub struct TestShadow { } /// Helper function to mimic cloud side updates using MQTT client directly -async fn cloud_update(client: &MqttClient<'static, NoopRawMutex, MAX_SUBSCRIBERS>, payload: &[u8]) { +async fn cloud_update(client: &MqttClient<'static, NoopRawMutex>, payload: &[u8]) { client .publish( Publish::builder() @@ -70,10 +68,7 @@ async fn cloud_update(client: &MqttClient<'static, NoopRawMutex, MAX_SUBSCRIBERS } /// Helper function to assert on the current shadow state -async fn assert_shadow( - client: &MqttClient<'static, NoopRawMutex, MAX_SUBSCRIBERS>, - expected: serde_json::Value, -) { +async fn assert_shadow(client: &MqttClient<'static, NoopRawMutex>, expected: serde_json::Value) { let mut get_shadow_sub = client .subscribe::<1>( Subscribe::builder() @@ -143,13 +138,12 @@ async fn test_shadow_update_from_device() { .keepalive_interval(embassy_time::Duration::from_secs(50)) .build(); - static STATE: StaticCell> = - StaticCell::new(); + static STATE: StaticCell> = StaticCell::new(); let state = STATE.init(State::new()); let (mut stack, client) = embedded_mqtt::new(state, config); // Create the shadow - let mut shadow = Shadow::::new(TestShadow::default(), &client); + let mut shadow = Shadow::::new(TestShadow::default(), &client); // let delta_fut = async { // loop { From 798f496fe97a12b2b213d5148e5fcaf624b01d51 Mon Sep 17 00:00:00 2001 From: Mathias Date: Wed, 6 Nov 2024 09:29:23 +0100 Subject: [PATCH 31/36] Change default 
block size, and update frequency --- src/ota/config.rs | 4 ++-- src/ota/mod.rs | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ota/config.rs b/src/ota/config.rs index d57cd75..0d6395c 100644 --- a/src/ota/config.rs +++ b/src/ota/config.rs @@ -11,10 +11,10 @@ pub struct Config { impl Default for Config { fn default() -> Self { Self { - block_size: 1024, + block_size: 256, max_request_momentum: 3, request_wait: Duration::from_secs(5), - status_update_frequency: 24, + status_update_frequency: 96, self_test_timeout: None, } } diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 008dd2c..10b5b15 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -174,6 +174,7 @@ impl Updater { } Err(e) if e.is_retryable() => { // ... (Handle retryable errors) ... + error!("Failed block validation: {:?}! Retrying", e); } Err(e) => { // ... (Handle fatal errors) ... From f942b03be083378252434cb89ef97b9cc7fe58b3 Mon Sep 17 00:00:00 2001 From: Kenneth Knudsen <98805797+KennethKnudsen97@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:55:38 +0100 Subject: [PATCH 32/36] Enhancement: Replace serde_cbor with minicbor (#68) * Replace serde_cbor with minicbor * Add cursor around cbor writer to get write position back --------- Co-authored-by: Mathias --- Cargo.toml | 13 +++++++++---- src/lib.rs | 2 -- src/ota/data_interface/mqtt.rs | 2 +- src/ota/encoding/cbor.rs | 7 ++++--- src/provisioning/error.rs | 4 ++-- src/provisioning/mod.rs | 15 ++++++--------- src/shadows/dao.rs | 15 ++++++--------- 7 files changed, 28 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9980820..862531f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,11 @@ maintenance = { status = "actively-developed" } bitmaps = { version = "3.1", default-features = false } heapless = { version = "0.8", features = ["serde"] } serde = { version = "1.0", default-features = false, features = ["derive"] } -serde_cbor = { version = "0.11", default-features = false, optional = true } + +minicbor = { version = "0.25", optional = true } +minicbor-serde = { version = "0.3.2", optional = true } + + serde-json-core = { version = "0.6" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" @@ -69,12 +73,13 @@ hex = { version = "0.4.3", features = ["alloc"] } [features] default = ["ota_mqtt_data", "provision_cbor"] -provision_cbor = ["serde_cbor"] +provision_cbor = ["dep:minicbor", "dep:minicbor-serde"] + +ota_mqtt_data = ["dep:minicbor", "dep:minicbor-serde"] -ota_mqtt_data = ["serde_cbor"] ota_http_data = [] -std = ["serde/std", "serde_cbor?/std"] +std = ["serde/std", "minicbor-serde?/std"] defmt = [ "dep:defmt", diff --git a/src/lib.rs b/src/lib.rs index ca160a6..f5d758f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,5 +11,3 @@ pub mod jobs; pub mod ota; pub mod provisioning; pub mod shadows; - -pub use serde_cbor; diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index b6516d1..4afc2fb 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -209,7 +209,7 @@ impl<'a, M: RawMutex> DataInterface for MqttClient<'a, M> { /// Decode a cbor encoded fileblock received from streaming service fn decode_file_block<'c>(&self, payload: &'c mut [u8]) -> Result, OtaError> { Ok( - serde_cbor::de::from_mut_slice::(payload) + minicbor_serde::from_slice::(payload) .map_err(|_| OtaError::Encoding)? 
.into(), ) diff --git a/src/ota/encoding/cbor.rs b/src/ota/encoding/cbor.rs index a30f45e..ad0f124 100644 --- a/src/ota/encoding/cbor.rs +++ b/src/ota/encoding/cbor.rs @@ -76,9 +76,10 @@ pub fn to_slice(value: &T, slice: &mut [u8]) -> Result where T: serde::ser::Serialize, { - let mut serializer = serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(slice)); + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(slice)); value.serialize(&mut serializer).map_err(|_| ())?; - Ok(serializer.into_inner().bytes_written()) + Ok(serializer.into_encoder().writer().position()) } impl<'a> From> for FileBlock<'a> { @@ -170,7 +171,7 @@ mod test { 0, 0, 0, 0, 0, 0, 255, ]; - let response: GetStreamResponse = serde_cbor::de::from_mut_slice(payload).unwrap(); + let response: GetStreamResponse = minicbor_serde::from_slice(payload).unwrap(); assert_eq!( response, diff --git a/src/provisioning/error.rs b/src/provisioning/error.rs index 67f4696..20961fb 100644 --- a/src/provisioning/error.rs +++ b/src/provisioning/error.rs @@ -24,8 +24,8 @@ impl From for Error { } } -impl From for Error { - fn from(_e: serde_cbor::Error) -> Self { +impl From for Error { + fn from(_e: minicbor_serde::error::DecodeError) -> Self { Self::DeserializeCbor } } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 0a27f83..3797208 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -206,12 +206,11 @@ impl FleetProvisioner { Ok(match payload_format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - let mut serializer = - serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(buf)); + let mut serializer = minicbor_serde::Serializer::new(buf); register_request .serialize(&mut serializer) .map_err(|_| EncodingError::BufferSize)?; - serializer.into_inner().bytes_written() + serializer.into_encoder().writer().len() } PayloadFormat::Json => serde_json_core::to_slice(®ister_request, buf) .map_err(|_| EncodingError::BufferSize)?, @@ -318,13 +317,11 @@ impl FleetProvisioner { Ok(match payload_format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - let mut serializer = serde_cbor::ser::Serializer::new( - serde_cbor::ser::SliceWrite::new(buf), - ); + let mut serializer = minicbor_serde::Serializer::new(buf); request .serialize(&mut serializer) .map_err(|_| EncodingError::BufferSize)?; - serializer.into_inner().bytes_written() + serializer.into_encoder().writer().len() } PayloadFormat::Json => serde_json_core::to_slice(&request, buf) .map_err(|_| EncodingError::BufferSize)?, @@ -398,7 +395,7 @@ impl FleetProvisioner { Ok(match payload_format { #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => serde_cbor::de::from_mut_slice::(message.payload_mut())?, + PayloadFormat::Cbor => minicbor_serde::from_slice::(message.payload_mut())?, PayloadFormat::Json => serde_json_core::from_slice::(message.payload())?.0, }) } @@ -412,7 +409,7 @@ impl FleetProvisioner { let response = match format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(message.payload_mut())? + minicbor_serde::from_slice::(message.payload_mut())? 
} PayloadFormat::Json => { serde_json_core::from_slice::(message.payload())?.0 diff --git a/src/shadows/dao.rs b/src/shadows/dao.rs index 1435cd6..abbb85f 100644 --- a/src/shadows/dao.rs +++ b/src/shadows/dao.rs @@ -28,10 +28,8 @@ where } Ok( - serde_cbor::de::from_mut_slice::( - &mut buf[U32_SIZE..len as usize + U32_SIZE], - ) - .map_err(|_| Error::InvalidPayload)?, + minicbor_serde::from_slice::(&mut buf[U32_SIZE..len as usize + U32_SIZE]) + .map_err(|_| Error::InvalidPayload)?, ) } _ => Err(Error::InvalidPayload), @@ -43,14 +41,13 @@ where let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; - let mut serializer = serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new( - &mut buf[U32_SIZE..], - )) - .packed_format(); + let mut serializer = minicbor_serde::Serializer::new(&mut buf[U32_SIZE..]); + state .serialize(&mut serializer) .map_err(|_| Error::InvalidPayload)?; - let len = serializer.into_inner().bytes_written(); + + let len = serializer.into_encoder().writer().len(); if len > S::MAX_PAYLOAD_SIZE { return Err(Error::Overflow); From 7e7eab84fe3a4c9b582d3c929469b033dd54b1fb Mon Sep 17 00:00:00 2001 From: Mathias Date: Thu, 21 Nov 2024 13:26:44 +0100 Subject: [PATCH 33/36] Fix minicbor serialize returning number of written bytes correctly --- src/provisioning/mod.rs | 12 ++++++++---- src/shadows/dao.rs | 6 ++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 3797208..e0bcd95 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -206,11 +206,13 @@ impl FleetProvisioner { Ok(match payload_format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - let mut serializer = minicbor_serde::Serializer::new(buf); + let mut serializer = minicbor_serde::Serializer::new( + minicbor::encode::write::Cursor::new(buf), + ); register_request .serialize(&mut serializer) .map_err(|_| EncodingError::BufferSize)?; - serializer.into_encoder().writer().len() + serializer.into_encoder().writer().position() } PayloadFormat::Json => serde_json_core::to_slice(®ister_request, buf) .map_err(|_| EncodingError::BufferSize)?, @@ -317,11 +319,13 @@ impl FleetProvisioner { Ok(match payload_format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - let mut serializer = minicbor_serde::Serializer::new(buf); + let mut serializer = minicbor_serde::Serializer::new( + minicbor::encode::write::Cursor::new(buf), + ); request .serialize(&mut serializer) .map_err(|_| EncodingError::BufferSize)?; - serializer.into_encoder().writer().len() + serializer.into_encoder().writer().position() } PayloadFormat::Json => serde_json_core::to_slice(&request, buf) .map_err(|_| EncodingError::BufferSize)?, diff --git a/src/shadows/dao.rs b/src/shadows/dao.rs index abbb85f..b170688 100644 --- a/src/shadows/dao.rs +++ b/src/shadows/dao.rs @@ -41,13 +41,15 @@ where let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; - let mut serializer = minicbor_serde::Serializer::new(&mut buf[U32_SIZE..]); + let mut serializer = minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new( + &mut buf[U32_SIZE..], + )); state .serialize(&mut serializer) .map_err(|_| Error::InvalidPayload)?; - let len = serializer.into_encoder().writer().len(); + let len = serializer.into_encoder().writer().position(); if len > S::MAX_PAYLOAD_SIZE { return Err(Error::Overflow); From 2c5757d0a26d81e94a8c968b15b6715a887081f3 Mon Sep 17 00:00:00 2001 From: Mathias Date: Mon, 13 Jan 2025 08:19:00 +0100 Subject: [PATCH 34/36] Bump embassy dependencies 
to newly released versions --- Cargo.toml | 6 ++--- src/provisioning/mod.rs | 60 +++++++++++++---------------------------- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 862531f..c3e36cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,11 +33,11 @@ minicbor-serde = { version = "0.3.2", optional = true } serde-json-core = { version = "0.6" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "74eb53d" } +embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "dc7c390" } futures = { version = "0.3.28", default-features = false } -embassy-time = { version = "0.3" } +embassy-time = { version = "0.4" } embassy-sync = "0.6" embassy-futures = "0.1" @@ -62,7 +62,7 @@ tokio = { version = "1.33", default-features = false, features = [ ] } tokio-native-tls = { version = "0.3.1" } embassy-futures = { version = "0.1.0" } -embassy-time = { version = "0.3", features = ["log", "std", "generic-queue"] } +embassy-time = { version = "0.4", features = ["log", "std", "generic-queue"] } embedded-io-adapters = { version = "0.6.0", features = ["tokio-1"] } ecdsa = { version = "0.16", features = ["pkcs8", "pem"] } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index e0bcd95..3327762 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -136,8 +136,6 @@ impl FleetProvisioner { where C: DeserializeOwned, { - use embedded_mqtt::SliceBufferProvider; - let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; let mut message = create_subscription .next() @@ -146,11 +144,8 @@ impl FleetProvisioner { let ownership_token = match Topic::from_str(message.topic_name()) { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { - let response = Self::deserialize::< - CreateKeysAndCertificateResponse, - M, - SliceBufferProvider<'a>, - >(format, &mut message)?; + let response = + Self::deserialize::(format, &mut message)?; credential_handler .store_credentials(Credentials { @@ -164,11 +159,10 @@ impl FleetProvisioner { } Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - let response = Self::deserialize::< - CreateCertificateFromCsrResponse, - M, - SliceBufferProvider<'a>, - >(format, &mut message)?; + let response = Self::deserialize::( + format, + message.payload_mut(), + )?; credential_handler .store_credentials(Credentials { @@ -186,7 +180,7 @@ impl FleetProvisioner { Topic::CreateKeysAndCertificateRejected(format) | Topic::CreateCertificateFromCsrRejected(format), ) => { - return Err(Self::handle_error(format, message).unwrap_err()); + return Err(Self::handle_error(format, message.payload_mut()).unwrap_err()); } t => { @@ -259,18 +253,17 @@ impl FleetProvisioner { match Topic::from_str(message.topic_name()) { Some(Topic::RegisterThingAccepted(_, format)) => { - let response = Self::deserialize::< - RegisterThingResponse<'_, C>, - M, - SliceBufferProvider<'a>, - >(format, &mut message)?; + let response = Self::deserialize::>( + format, + message.payload_mut(), + )?; Ok(response.device_configuration) } // Error happened! 
Some(Topic::RegisterThingRejected(_, format)) => { - Err(Self::handle_error(format, message).unwrap_err()) + Err(Self::handle_error(format, message.payload_mut()).unwrap_err()) } t => { @@ -387,37 +380,22 @@ impl FleetProvisioner { } } - fn deserialize<'a, R: Deserialize<'a>, M: RawMutex, B: BufferProvider>( + fn deserialize<'a, R: Deserialize<'a>>( payload_format: PayloadFormat, - message: &'a mut Message<'_, M, B>, + payload: &'a mut [u8], ) -> Result { - trace!( - "Accepted Topic {:?}. Payload len: {:?}", - payload_format, - message.payload().len() - ); - Ok(match payload_format { #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => minicbor_serde::from_slice::(message.payload_mut())?, - PayloadFormat::Json => serde_json_core::from_slice::(message.payload())?.0, + PayloadFormat::Cbor => minicbor_serde::from_slice::(payload)?, + PayloadFormat::Json => serde_json_core::from_slice::(payload)?.0, }) } - fn handle_error( - format: PayloadFormat, - mut message: Message<'_, M, B>, - ) -> Result<(), Error> { - error!(">> {:?}", message.topic_name()); - + fn handle_error(format: PayloadFormat, payload: &mut [u8]) -> Result<(), Error> { let response = match format { #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - minicbor_serde::from_slice::(message.payload_mut())? - } - PayloadFormat::Json => { - serde_json_core::from_slice::(message.payload())?.0 - } + PayloadFormat::Cbor => minicbor_serde::from_slice::(payload)?, + PayloadFormat::Json => serde_json_core::from_slice::(payload)?.0, }; error!("{:?}", response); From 3e2cad085c8dcc2747cfbc776700391cbecbf8e9 Mon Sep 17 00:00:00 2001 From: Kenneth Knudsen <98805797+KennethKnudsen97@users.noreply.github.com> Date: Thu, 6 Feb 2025 19:39:08 +0100 Subject: [PATCH 35/36] remove locks (#71) --- src/shadows/mod.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index f0f3b96..e717a6e 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -40,9 +40,6 @@ where mqtt: &'m embedded_mqtt::MqttClient<'a, M>, subscription: Mutex>>, _shadow: PhantomData, - // request_lock is used to ensure that shadow operations such as subscribing, updating, or - // deleting are serialized, preventing multiple concurrent requests to the same MQTT topics. - request_lock: Mutex, } impl<'a, 'm, M: RawMutex, S: ShadowState> ShadowHandler<'a, 'm, M, S> @@ -106,8 +103,6 @@ where /// Internal helper function for applying a delta state to the actual shadow /// state, and update the cloud shadow. async fn report(&self, reported: &R) -> Result<(), Error> { - let _update_requested_lock = self.request_lock.lock().await; - debug!( "[{:?}] Updating reported shadow value.", S::NAME.unwrap_or(CLASSIC_SHADOW), @@ -185,8 +180,6 @@ where /// Initiate a `GetShadow` request, updating the local state from the cloud. 
async fn get_shadow(&self) -> Result, Error> { - let _get_requested_lock = self.request_lock.lock().await; - //Wait for mqtt to connect self.mqtt.wait_connected().await; @@ -233,8 +226,6 @@ where } pub async fn delete_shadow(&self) -> Result<(), Error> { - let _delete_request = self.request_lock.lock().await; - // Wait for mqtt to connect self.mqtt.wait_connected().await; @@ -266,8 +257,6 @@ where } pub async fn create_shadow(&self) -> Result, Error> { - let _create_requested_lock = self.request_lock.lock().await; - debug!( "[{:?}] Creating initial shadow value.", S::NAME.unwrap_or(CLASSIC_SHADOW), @@ -413,7 +402,6 @@ where mqtt, subscription: Mutex::new(None), _shadow: PhantomData, - request_lock: Mutex::new(()), }; Self { @@ -537,7 +525,6 @@ where mqtt, subscription: Mutex::new(None), _shadow: PhantomData, - request_lock: Mutex::new(()), }; Self { handler, state } } From 91034929720e5d8e85688d5151b1442162895220 Mon Sep 17 00:00:00 2001 From: Kenneth Knudsen <98805797+KennethKnudsen97@users.noreply.github.com> Date: Fri, 7 Feb 2025 15:57:20 +0100 Subject: [PATCH 36/36] Feature: Implement defender metrics (#70) * Initial structure of defender metrics * metric structure and tests * generic custom metric * Change Custom metric to use references * aws types with references * impl tuple for Version * remove timestamp as argument for function in custom metric * include aws metrics and use bon crate for building metric struct * error handling * error handling * error handling * feature flag for cbor and temp fix for Header serialize * smal changes * String list example * Metric integration test * Update src/defender_metrics/data_types.rs Co-authored-by: Mathias Koch * Cargo clippy and unit test * cargo clippy fix * fixed unit test and version serialization * fix test --------- Co-authored-by: Mathias Koch --- .vscode/settings.json | 5 +- Cargo.toml | 13 +- rust-toolchain.toml | 2 +- shadow_derive/src/lib.rs | 24 +- src/defender_metrics/aws_types.rs | 81 ++++++ src/defender_metrics/data_types.rs | 84 ++++++ src/defender_metrics/errors.rs | 29 ++ src/defender_metrics/mod.rs | 444 +++++++++++++++++++++++++++++ src/defender_metrics/topics.rs | 78 +++++ src/jobs/describe.rs | 2 +- src/jobs/subscribe.rs | 2 +- src/jobs/update.rs | 4 +- src/lib.rs | 1 + src/ota/control_interface/mqtt.rs | 12 +- src/ota/data_interface/mqtt.rs | 2 +- src/ota/encoding/mod.rs | 2 +- src/ota/mod.rs | 22 +- src/provisioning/mod.rs | 4 +- src/provisioning/topics.rs | 6 +- src/shadows/mod.rs | 6 +- src/shadows/topics.rs | 8 +- tests/common/network.rs | 8 +- tests/metric.rs | 119 ++++++++ 23 files changed, 898 insertions(+), 60 deletions(-) create mode 100644 src/defender_metrics/aws_types.rs create mode 100644 src/defender_metrics/data_types.rs create mode 100644 src/defender_metrics/errors.rs create mode 100644 src/defender_metrics/mod.rs create mode 100644 src/defender_metrics/topics.rs create mode 100644 tests/metric.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index 48fb5ea..c60380a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { - "rust-analyzer.checkOnSave.allTargets": false, + "rust-analyzer.checkOnSave.allTargets": true, "rust-analyzer.cargo.features": ["log"], - "rust-analyzer.cargo.target": "x86_64-unknown-linux-gnu" + "rust-analyzer.cargo.target": "x86_64-unknown-linux-gnu", + } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index c3e36cd..5b1b7a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ embassy-futures = "0.1" log = { version = 
"0.4", default-features = false, optional = true } defmt = { version = "0.3", optional = true } +bon = { version = "3.3.2", default-features = false } [dev-dependencies] native-tls = { version = "0.2" } @@ -62,7 +63,7 @@ tokio = { version = "1.33", default-features = false, features = [ ] } tokio-native-tls = { version = "0.3.1" } embassy-futures = { version = "0.1.0" } -embassy-time = { version = "0.4", features = ["log", "std", "generic-queue"] } +embassy-time = { version = "0.4", features = ["log", "std", "generic-queue-8"] } embedded-io-adapters = { version = "0.6.0", features = ["tokio-1"] } ecdsa = { version = "0.16", features = ["pkcs8", "pem"] } @@ -70,8 +71,11 @@ p256 = "0.13" pkcs8 = { version = "0.10", features = ["encryption", "pem"] } hex = { version = "0.4.3", features = ["alloc"] } + [features] -default = ["ota_mqtt_data", "provision_cbor"] +default = ["ota_mqtt_data", "metric_cbor", "provision_cbor"] + +metric_cbor = ["dep:minicbor", "dep:minicbor-serde"] provision_cbor = ["dep:minicbor", "dep:minicbor-serde"] @@ -89,5 +93,6 @@ defmt = [ ] log = ["dep:log", "embedded-mqtt/log"] -# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] -# embedded-mqtt = { path = "../embedded-mqtt" } + +[patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +embedded-mqtt = { path = "../embedded-mqtt" } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index b6369b9..ed37409 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-07-17" +channel = "nightly-2024-09-06" components = ["rust-src", "rustfmt", "llvm-tools"] targets = [ "x86_64-unknown-linux-gnu", diff --git a/shadow_derive/src/lib.rs b/shadow_derive/src/lib.rs index 09cdef0..06da9f5 100644 --- a/shadow_derive/src/lib.rs +++ b/shadow_derive/src/lib.rs @@ -118,8 +118,7 @@ fn create_assigners(fields: &Vec) -> Vec { if field .attrs .iter() - .find(|a| a.path.is_ident("static_shadow_field")) - .is_some() + .any(|a| a.path.is_ident("static_shadow_field")) { None } else { @@ -133,7 +132,7 @@ fn create_assigners(fields: &Vec) -> Vec { .collect::>() } -fn create_optional_fields(fields: &Vec) -> Vec { +fn create_optional_fields(fields: &[Field]) -> Vec { fields .iter() .filter_map(|field| { @@ -153,8 +152,7 @@ fn create_optional_fields(fields: &Vec) -> Vec if field .attrs .iter() - .find(|a| a.path.is_ident("static_shadow_field")) - .is_some() + .any(|a| a.path.is_ident("static_shadow_field")) { None } else { @@ -183,13 +181,13 @@ fn generate_shadow_state(input: &StructParseInput) -> proc_macro2::TokenStream { None => quote! { None }, }; - return quote! { + quote! { #[automatically_derived] impl #impl_generics rustot::shadows::ShadowState for #ident #ty_generics #where_clause { const NAME: Option<&'static str> = #name; // const MAX_PAYLOAD_SIZE: usize = 512; } - }; + } } fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenStream { @@ -205,10 +203,10 @@ fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenS let optional_ident = format_ident!("Patch{}", ident); - let assigners = create_assigners(&shadow_fields); - let optional_fields = create_optional_fields(&shadow_fields); + let assigners = create_assigners(shadow_fields); + let optional_fields = create_optional_fields(shadow_fields); - return quote! { + quote! 
{ #[automatically_derived] #[derive(Default, Clone, ::serde::Deserialize, ::serde::Serialize)] #(#copy_attrs)* @@ -228,7 +226,7 @@ fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenS )* } } - }; + } } fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStream { @@ -238,7 +236,7 @@ fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStrea let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - return quote! { + quote! { #[automatically_derived] impl #impl_generics rustot::shadows::ShadowPatch for #ident #ty_generics #where_clause { type PatchState = #ident #ty_generics; @@ -247,5 +245,5 @@ fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStrea *self = opt; } } - }; + } } diff --git a/src/defender_metrics/aws_types.rs b/src/defender_metrics/aws_types.rs new file mode 100644 index 0000000..484a5a0 --- /dev/null +++ b/src/defender_metrics/aws_types.rs @@ -0,0 +1,81 @@ +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct TcpConnections<'a> { + #[serde(rename = "ec")] + pub established_connections: Option<&'a EstablishedConnections<'a>>, +} + +#[derive(Debug, Serialize)] +pub struct EstablishedConnections<'a> { + #[serde(rename = "cs")] + pub connections: Option<&'a [&'a Connection<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct Connection<'a> { + #[serde(rename = "rad")] + pub remote_addr: &'a str, + + /// Port number, must be >= 0 + #[serde(rename = "lp")] + pub local_port: Option, + + /// Interface name + #[serde(rename = "li")] + pub local_interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct ListeningTcpPorts<'a> { + #[serde(rename = "pts")] + pub ports: Option<&'a [&'a TcpPort<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct TcpPort<'a> { + #[serde(rename = "pt")] + pub port: u16, + + #[serde(rename = "if")] + pub interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct ListeningUdpPorts<'a> { + #[serde(rename = "pts")] + pub ports: Option<&'a [&'a UdpPort<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct UdpPort<'a> { + #[serde(rename = "pt")] + pub port: u16, + + #[serde(rename = "if")] + pub interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct NetworkStats { + #[serde(rename = "bi")] + pub bytes_in: Option, + + #[serde(rename = "bo")] + pub bytes_out: Option, + + #[serde(rename = "pi")] + pub packets_in: Option, + + #[serde(rename = "po")] + pub packets_out: Option, +} diff --git a/src/defender_metrics/data_types.rs b/src/defender_metrics/data_types.rs new file mode 100644 index 0000000..2f19cd9 --- /dev/null +++ b/src/defender_metrics/data_types.rs @@ -0,0 +1,84 @@ +use core::fmt::{Display, Write}; + +use bon::Builder; +use embassy_time::Instant; +use serde::{ser::SerializeStruct, Deserialize, Serialize}; + +use super::aws_types::{ListeningTcpPorts, ListeningUdpPorts, NetworkStats, TcpConnections}; + +#[derive(Debug, Serialize, Builder)] +pub struct Metric<'a, C: Serialize> { + #[serde(rename = "hed")] + pub header: Header, + + #[serde(rename = "met")] + pub metrics: Option>, + + #[serde(rename = "cmet")] + pub custom_metrics: Option, +} + +#[derive(Debug, Serialize)] +pub struct Metrics<'a> { + listening_tcp_ports: Option>, + listening_udp_ports: Option>, + network_stats: Option, + tcp_connections: Option>, +} + +#[derive(Debug, Serialize)] 
+pub struct Header { + /// Monotonically increasing value. Epoch timestamp recommended. + #[serde(rename = "rid")] + pub report_id: i64, + + /// Version in Major.Minor format. + #[serde(rename = "v")] + pub version: Version, +} + +impl Default for Header { + fn default() -> Self { + Self { + report_id: Instant::now().as_millis() as i64, + version: Default::default(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum CustomMetric<'a> { + Number(i64), + NumberList(&'a [u64]), + StringList(&'a [&'a str]), + IpList(&'a [&'a str]), +} + +/// Format is `Version(Major, Minor)` +#[derive(Debug, PartialEq, Deserialize)] +pub struct Version(pub u8, pub u8); + +impl Serialize for Version { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut st: heapless::String<7> = heapless::String::new(); + st.write_fmt(format_args!("{}.{}", self.0, self.1)).ok(); + + serializer.serialize_str(&st) + } +} + +impl Display for Version { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}.{}", self.0, self.1,) + } +} + +impl Default for Version { + fn default() -> Self { + Self(1, 0) + } +} diff --git a/src/defender_metrics/errors.rs b/src/defender_metrics/errors.rs new file mode 100644 index 0000000..5289e14 --- /dev/null +++ b/src/defender_metrics/errors.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ErrorResponse<'a> { + #[serde(rename = "thingName")] + pub thing_name: &'a str, + pub status: &'a str, + #[serde(rename = "statusDetails")] + pub status_details: StatusDetails<'a>, + pub timestamp: i64, +} +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct StatusDetails<'a> { + #[serde(rename = "ErrorCode")] + pub error_code: MetricError, + #[serde(rename = "ErrorMessage")] + pub error_message: Option<&'a str>, +} +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum MetricError { + Malformed, + InvalidPayload, + Throttled, + MissingHeader, + Other, +} diff --git a/src/defender_metrics/mod.rs b/src/defender_metrics/mod.rs new file mode 100644 index 0000000..47aa3c0 --- /dev/null +++ b/src/defender_metrics/mod.rs @@ -0,0 +1,444 @@ +use crate::shadows::Error; +use data_types::Metric; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{DeferredPayload, Publish, Subscribe, SubscribeTopic, ToPayload}; +use errors::{ErrorResponse, MetricError}; +use futures::StreamExt; +use serde::Serialize; +use topics::Topic; + +// pub mod aws_types; +pub mod aws_types; +pub mod data_types; +pub mod errors; +pub mod topics; + +pub struct MetricHandler<'a, 'm, M: RawMutex> { + mqtt: &'m embedded_mqtt::MqttClient<'a, M>, +} + +impl<'a, 'm, M: RawMutex> MetricHandler<'a, 'm, M> { + pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M>) -> Self { + Self { mqtt } + } + + pub async fn publish_metric<'c, C: Serialize>( + &self, + metric: Metric<'c, C>, + max_payload_size: usize, + ) -> Result<(), MetricError> { + //Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let payload = DeferredPayload::new( + |buf: &mut [u8]| { + #[cfg(feature = "metric_cbor")] + { + let mut serializer = minicbor_serde::Serializer::new( + minicbor::encode::write::Cursor::new(&mut *buf), + ); + + match metric.serialize(&mut serializer) { + Ok(_) => {} + Err(_) => { + error!("An error happened when serializing metric with cbor"); + return 
Err(embedded_mqtt::EncodingError::BufferSize); + } + }; + + Ok(serializer.into_encoder().writer().position()) + } + + #[cfg(not(feature = "metric_cbor"))] + { + serde_json_core::to_slice(&metric, buf) + .map_err(|_| embedded_mqtt::EncodingError::BufferSize) + } + }, + max_payload_size, + ); + + let mut subscription = self + .publish_and_subscribe(payload) + .await + .map_err(|_| MetricError::Other)?; + + loop { + let message = subscription.next().await.ok_or(MetricError::Malformed)?; + + match Topic::from_str(message.topic_name()) { + Some(Topic::Accepted) => return Ok(()), + Some(Topic::Rejected) => { + let error_response = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| MetricError::InvalidPayload)?; + + return Err(error_response.0.status_details.error_code); + } + + _ => (), + }; + } + } + async fn publish_and_subscribe( + &self, + payload: impl ToPayload, + ) -> Result, Error> { + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::Accepted + .format::<64>(self.mqtt.client_id())? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::Rejected + .format::<64>(self.mqtt.client_id())? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + //*** PUBLISH REQUEST ***/ + let topic_name = Topic::Publish.format::<64>(self.mqtt.client_id())?; + + match self + .mqtt + .publish( + Publish::builder() + .topic_name(topic_name.as_str()) + .payload(payload) + .build(), + ) + .await + .map_err(Error::MqttError) + { + Ok(_) => {} + Err(_) => { + error!("ERROR PUBLISHING PAYLOAD"); + return Err(Error::MqttError(embedded_mqtt::Error::BadTopicFilter)); + } + }; + + Ok(sub) + } +} + +#[cfg(test)] +mod tests { + use core::str::FromStr; + + use super::data_types::*; + + use heapless::{LinearMap, String}; + use serde::{ser::SerializeStruct, Serialize}; + + #[test] + fn serialize_version_json() { + let test_cases = [ + (Version(2, 0), "\"2.0\""), + (Version(0, 0), "\"0.0\""), + (Version(0, 1), "\"0.1\""), + (Version(255, 200), "\"255.200\""), + ]; + + for (version, expected) in test_cases.iter() { + let string: String<100> = serde_json_core::to_string(version).unwrap(); + assert_eq!( + string, *expected, + "Serialization failed for Version({}, {}): expected {}, got {}", + version.0, version.1, expected, string + ); + } + } + #[test] + fn serialize_version_cbor() { + let test_cases: [(Version, [u8; 8]); 4] = [ + (Version(2, 0), [99, 50, 46, 48, 0, 0, 0, 0]), + (Version(0, 0), [99, 48, 46, 48, 0, 0, 0, 0]), + (Version(0, 1), [99, 48, 46, 49, 0, 0, 0, 0]), + (Version(255, 200), [103, 50, 53, 53, 46, 50, 48, 48]), + ]; + + for (version, expected) in test_cases.iter() { + let mut buf = [0u8; 200]; + + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(&mut buf[..])); + + version.serialize(&mut serializer).unwrap(); + + let len = serializer.into_encoder().writer().position(); + + assert_eq!( + &buf[..len], + &expected[..len], + "Serialization failed for Version({}, {}): expected {:?}, got {:?}", + version.0, + version.1, + expected, + &buf[..len], + ); + } + } + + #[test] + fn custom_serialization_cbor() { + #[derive(Debug)] + struct WifiMetric { + signal_strength: u8, + } + + impl Serialize for WifiMetric { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("WifiMetricWrapper", 1)?; + + // Define the type we want to wrap our signal_strength 
field in + #[derive(Serialize)] + struct Number { + number: u8, + } + + let number = Number { + number: self.signal_strength, + }; + + // Serialize number and wrap in array + outer.serialize_field("MyMetricOfType_Number", &[number])?; + outer.end() + } + } + + let custom_metrics: WifiMetric = WifiMetric { + signal_strength: 23, + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let mut buf = [255u8; 1000]; + + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(&mut buf[..])); + + metric.serialize(&mut serializer).unwrap(); + + let len = serializer.into_encoder().writer().position(); + + assert_eq!( + &buf[..len], + [ + 163, 99, 104, 101, 100, 162, 99, 114, 105, 100, 0, 97, 118, 99, 49, 46, 48, 99, + 109, 101, 116, 246, 100, 99, 109, 101, 116, 161, 117, 77, 121, 77, 101, 116, 114, + 105, 99, 79, 102, 84, 121, 112, 101, 95, 78, 117, 109, 98, 101, 114, 129, 161, 102, + 110, 117, 109, 98, 101, 114, 23 + ] + ) + } + + #[test] + fn custom_serialization() { + #[derive(Debug)] + struct WifiMetric { + signal_strength: u8, + } + + impl Serialize for WifiMetric { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("WifiMetricWrapper", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct Number { + number: u8, + } + + let number = Number { + number: self.signal_strength, + }; + + // Serialize number and wrap in array + outer.serialize_field("MyMetricOfType_Number", &[number])?; + outer.end() + } + } + + let custom_metrics: WifiMetric = WifiMetric { + signal_strength: 23, + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":23}]}}", payload.as_str()) + } + #[test] + fn custom_serialization_string_list() { + #[derive(Debug)] + struct CellType { + cell_type: String, + } + + impl Serialize for CellType { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("CellType", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct StringList<'a> { + string_list: &'a [&'a str], + } + + let list = StringList { + string_list: &[&self.cell_type.as_str()], + }; + + // Serialize number and wrap in array + outer.serialize_field("cell_type", &[list])?; + outer.end() + } + } + + let custom_metrics: CellType<4> = CellType { + cell_type: String::from_str("gsm").unwrap(), + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"cell_type\":[{\"string_list\":[\"gsm\"]}]}}", payload.as_str()) + } + #[test] + fn number() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + let name_of_metric = String::from_str("myMetric").unwrap(); + + custom_metrics + .insert(name_of_metric, [CustomMetric::Number(23)]) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = 
serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"myMetric\":[{\"number\":23}]}}", payload.as_str()) + } + + #[test] + fn number_list() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + // NUMBER LIST + let my_number_list = String::from_str("my_number_list").unwrap(); + + custom_metrics + .insert(my_number_list, [CustomMetric::NumberList(&[123, 456, 789])]) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"my_number_list\":[{\"number_list\":[123,456,789]}]}}", payload.as_str()) + } + + #[test] + fn string_list() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + // STRING LIST + let my_string_list = String::from_str("my_string_list").unwrap(); + + custom_metrics + .insert( + my_string_list, + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"my_string_list\":[{\"string_list\":[\"value_1\",\"value_2\"]}]}}", payload.as_str()) + } + + #[test] + fn all_types() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 4> = LinearMap::new(); + + let my_number = String::from_str("MyMetricOfType_Number").unwrap(); + custom_metrics + .insert(my_number, [CustomMetric::Number(1)]) + .unwrap(); + + let my_number_list = String::from_str("MyMetricOfType_NumberList").unwrap(); + custom_metrics + .insert(my_number_list, [CustomMetric::NumberList(&[1, 2, 3])]) + .unwrap(); + + let my_string_list = String::from_str("MyMetricOfType_StringList").unwrap(); + custom_metrics + .insert( + my_string_list, + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + let my_ip_list = String::from_str("MyMetricOfType_IpList").unwrap(); + custom_metrics + .insert( + my_ip_list, + [CustomMetric::IpList(&["172.0.0.0", "172.0.0.10"])], + ) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":1}],\"MyMetricOfType_NumberList\":[{\"number_list\":[1,2,3]}],\"MyMetricOfType_StringList\":[{\"string_list\":[\"value_1\",\"value_2\"]}],\"MyMetricOfType_IpList\":[{\"ip_list\":[\"172.0.0.0\",\"172.0.0.10\"]}]}}", payload.as_str()) + } +} diff --git a/src/defender_metrics/topics.rs b/src/defender_metrics/topics.rs new file mode 100644 index 0000000..eba97f7 --- /dev/null +++ b/src/defender_metrics/topics.rs @@ -0,0 +1,78 @@ +#![allow(dead_code)] +use core::fmt::Write; + +use heapless::String; + +use crate::shadows::Error; + +pub enum PayloadFormat { + #[cfg(feature = "metric_cbor")] + Cbor, + #[cfg(not(feature = "metric_cbor"))] + Json, +} + +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum Topic { + Accepted, + Rejected, + Publish, +} + +impl Topic { + const PREFIX: &'static str = "$aws/things"; + const NAME: &'static str = "defender/metrics"; + + #[cfg(feature = "metric_cbor")] + const PAYLOAD_FORMAT: &'static str = 
"cbor"; + + #[cfg(not(feature = "metric_cbor"))] + const PAYLOAD_FORMAT: &'static str = "json"; + + pub fn format(&self, thing_name: &str) -> Result, Error> { + let mut topic_path = String::new(); + + match self { + Self::Accepted => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}/accepted", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + Self::Rejected => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}/rejected", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + Self::Publish => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + } + .map_err(|_| Error::Overflow)?; + + Ok(topic_path) + } + + pub fn from_str(s: &str) -> Option { + let tt = s.splitn(7, '/').collect::>(); + match (tt.first(), tt.get(1), tt.get(3), tt.get(4)) { + (Some(&"$aws"), Some(&"things"), Some(&"defender"), Some(&"metrics")) => { + // This is a defender metric topic, now figure out which one. + + match tt.get(6) { + Some(&"accepted") => Some(Topic::Accepted), + Some(&"rejected") => Some(Topic::Rejected), + _ => None, + } + } + _ => None, + } + } +} diff --git a/src/jobs/describe.rs b/src/jobs/describe.rs index 5c59d5c..81579c7 100644 --- a/src/jobs/describe.rs +++ b/src/jobs/describe.rs @@ -89,7 +89,7 @@ impl<'a> Describe<'a> { let payload_len = serde_json_core::to_slice( &DescribeJobExecutionRequest { execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), + include_job_document: self.include_job_document.then_some(true), client_token: self.client_token, }, buf, diff --git a/src/jobs/subscribe.rs b/src/jobs/subscribe.rs index cf796eb..4fc604b 100644 --- a/src/jobs/subscribe.rs +++ b/src/jobs/subscribe.rs @@ -18,7 +18,7 @@ pub enum Topic<'a> { impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(8, '/').collect::>(); - Some(match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + Some(match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), _, Some(&"jobs")) => { // This is a job topic! Figure out which match (tt.get(4), tt.get(5), tt.get(6), tt.get(7)) { diff --git a/src/jobs/update.rs b/src/jobs/update.rs index 3875bd0..867d0ac 100644 --- a/src/jobs/update.rs +++ b/src/jobs/update.rs @@ -146,9 +146,9 @@ impl<'a> Update<'a> { let payload_len = serde_json_core::to_slice( &UpdateJobExecutionRequest { execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), + include_job_document: self.include_job_document.then_some(true), expected_version: self.expected_version, - include_job_execution_state: self.include_job_execution_state.then(|| true), + include_job_execution_state: self.include_job_execution_state.then_some(true), status: self.status, status_details: self.status_details, step_timeout_in_minutes: self.step_timeout_in_minutes, diff --git a/src/lib.rs b/src/lib.rs index f5d758f..abda902 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ // This mod MUST go first, so that the others see its macros. 
pub(crate) mod fmt; +pub mod defender_metrics; pub mod jobs; #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] pub mod ota; diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 601ad68..fdc8a74 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,15 +1,13 @@ use core::fmt::Write; use embassy_sync::blocking_mutex::raw::RawMutex; -use embassy_time::with_timeout; -use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic}; +use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS}; use super::ControlInterface; -use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse}; -use crate::jobs::{JobError, JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; -use crate::ota::config::Config; +use crate::jobs::data_types::JobStatus; +use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; use crate::ota::encoding::json::JobStatusReason; -use crate::ota::encoding::{self, FileContext}; +use crate::ota::encoding::FileContext; use crate::ota::error::OtaError; use crate::ota::ProgressState; @@ -53,7 +51,7 @@ impl<'a, M: RawMutex> ControlInterface for embedded_mqtt::MqttClient<'a, M> { ) .map_err(|_| OtaError::Overflow)?; - let mut qos = QoS::AtLeastOnce; + let qos = QoS::AtLeastOnce; if let JobStatus::InProgress | JobStatus::Succeeded = status { let received_blocks = progress_state.total_blocks - progress_state.blocks_remaining; diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 4afc2fb..e1e1c5e 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -58,7 +58,7 @@ pub enum Topic<'a> { impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(8, '/').collect::>(); - Some(match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + Some(match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), _, Some(&"streams")) => { // This is a stream topic! Figure out which match (tt.get(4), tt.get(5), tt.get(6), tt.get(7)) { diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index b0349ed..a55f972 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -134,7 +134,7 @@ impl FileContext { }) .collect() }) - .unwrap_or_else(|| StatusDetailsOwned::new()), + .unwrap_or_default(), job_name: heapless::String::try_from(job_data.job_name).unwrap(), block_offset, diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 10b5b15..808d3d9 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -60,7 +60,7 @@ impl Updater { }); // Create the JobUpdater - let mut job_updater = JobUpdater::new(&file_ctx, &progress_state, &config, control); + let mut job_updater = JobUpdater::new(&file_ctx, &progress_state, config, control); match job_updater.initialize::(pal).await? { Some(()) => {} @@ -70,7 +70,7 @@ impl Updater { info!("Job document was accepted. 
Attempting to begin the update"); // Spawn the request momentum future - let momentum_fut = Self::handle_momentum(data, &config, &file_ctx, &progress_state); + let momentum_fut = Self::handle_momentum(data, config, &file_ctx, &progress_state); // Spawn the status update future let status_update_fut = job_updater.handle_status_updates(); @@ -100,7 +100,7 @@ impl Updater { { let mut progress = progress_state.lock().await; - data.request_file_blocks(&file_ctx, &mut progress, &config) + data.request_file_blocks(&file_ctx, &mut progress, config) .await?; } @@ -116,7 +116,7 @@ impl Updater { match Self::ingest_data_block( data, &mut block_writer, - &config, + config, &mut progress, payload.deref_mut(), ) @@ -168,7 +168,7 @@ impl Updater { if progress.request_block_remaining > 1 { progress.request_block_remaining -= 1; } else { - data.request_file_blocks(&file_ctx, &mut progress, &config) + data.request_file_blocks(&file_ctx, &mut progress, config) .await?; } } @@ -442,7 +442,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { let mut progress = self.progress_state.lock().await; self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Succeeded, JobStatusReason::Accepted, @@ -542,7 +542,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { // in self_test active self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::InProgress, JobStatusReason::SelfTestActive, @@ -554,7 +554,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { // complete the job self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Succeeded, JobStatusReason::Accepted, @@ -568,7 +568,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Failed, JobStatusReason::Rejected, @@ -582,7 +582,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Failed, JobStatusReason::Aborted, @@ -607,7 +607,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { let mut progress = self.progress_state.lock().await; self.control - .update_job_status(&self.file_ctx, &mut progress, status, reason) + .update_job_status(self.file_ctx, &mut progress, status, reason) .await?; Ok(()) } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 3327762..9e456f2 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -6,7 +6,7 @@ use core::future::Future; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - BufferProvider, DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, + DeferredPayload, EncodingError, Publish, Subscribe, SubscribeTopic, Subscription, }; use futures::StreamExt; @@ -191,7 +191,7 @@ impl FleetProvisioner { }; let register_request = RegisterThingRequest { - certificate_ownership_token: &ownership_token, + certificate_ownership_token: ownership_token, parameters, }; diff --git a/src/provisioning/topics.rs b/src/provisioning/topics.rs index 9256c55..bef6710 100644 --- a/src/provisioning/topics.rs +++ b/src/provisioning/topics.rs @@ -96,7 +96,7 @@ impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(6, '/').collect::>(); - match (tt.get(0), tt.get(1)) { + match (tt.first(), tt.get(1)) { (Some(&"$aws"), Some(&"provisioning-templates")) => { // This is a register thing topic, now figure out which one. 
@@ -107,7 +107,7 @@ impl<'a> Topic<'a> { Some(payload_format), Some(&"accepted"), ) => Some(Topic::RegisterThingAccepted( - *template_name, + template_name, PayloadFormat::from_str(payload_format).ok()?, )), ( @@ -116,7 +116,7 @@ impl<'a> Topic<'a> { Some(payload_format), Some(&"rejected"), ) => Some(Topic::RegisterThingRejected( - *template_name, + template_name, PayloadFormat::from_str(payload_format).ok()?, )), _ => None, diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index e717a6e..99bddfb 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -1,10 +1,10 @@ pub mod dao; pub mod data_types; -mod error; +pub mod error; mod shadow_diff; pub mod topics; -use core::{marker::PhantomData, ops::DerefMut, sync::atomic}; +use core::{marker::PhantomData, ops::DerefMut}; pub use data_types::Patch; use embassy_sync::{ @@ -118,7 +118,7 @@ where }; let payload = DeferredPayload::new( - |buf| { + |buf: &mut [u8]| { serde_json_core::to_slice(&request, buf) .map_err(|_| embedded_mqtt::EncodingError::BufferSize) }, diff --git a/src/shadows/topics.rs b/src/shadows/topics.rs index 34642d0..b572675 100644 --- a/src/shadows/topics.rs +++ b/src/shadows/topics.rs @@ -42,11 +42,11 @@ impl Topic { pub fn from_str(s: &str) -> Option<(Self, &str, Option<&str>)> { let tt = s.splitn(9, '/').collect::>(); - match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), Some(thing_name), Some(&Self::SHADOW)) => { // This is a shadow topic, now figure out which one. let (shadow_name, next_index) = if let Some(&"name") = tt.get(4) { - (tt.get(5).map(|s| *s), 6) + (tt.get(5).copied(), 6) } else { (None, 4) }; @@ -239,7 +239,7 @@ impl Subscribe { self.topics .iter() - .map(|(topic, qos)| Ok((Topic::from(*topic).format(thing_name, shadow_name)?, *qos))) + .map(|(topic, qos)| Ok(((*topic).format(thing_name, shadow_name)?, *qos))) .collect() } } @@ -278,7 +278,7 @@ impl Unsubscribe { self.topics .iter() - .map(|topic| Topic::from(*topic).format(thing_name, shadow_name)) + .map(|topic| (*topic).format(thing_name, shadow_name)) .collect() } } diff --git a/tests/common/network.rs b/tests/common/network.rs index 0cfe3db..d6a2d09 100644 --- a/tests/common/network.rs +++ b/tests/common/network.rs @@ -1,8 +1,8 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use ::native_tls::Identity; use embedded_io_adapters::tokio_1::FromTokio; -use embedded_nal_async::{AddrType, Dns, IpAddr, Ipv4Addr, Ipv6Addr, TcpConnect}; +use embedded_nal_async::{AddrType, Dns, TcpConnect}; use tokio_native_tls::native_tls; use super::credentials; @@ -25,7 +25,7 @@ impl TcpConnect for Network { async fn connect<'a>( &'a self, - remote: embedded_nal_async::SocketAddr, + remote: SocketAddr, ) -> Result, Self::Error> { let stream = tokio::net::TcpStream::connect(format!("{}", remote)).await?; Ok(FromTokio::new(stream)) @@ -86,7 +86,7 @@ impl TcpConnect for TlsNetwork { async fn connect<'a>( &'a self, - remote: embedded_nal_async::SocketAddr, + remote: SocketAddr, ) -> Result, Self::Error> { log::info!("Connecting to {:?}", remote); let connector = tokio_native_tls::TlsConnector::from( diff --git a/tests/metric.rs b/tests/metric.rs new file mode 100644 index 0000000..4115c17 --- /dev/null +++ b/tests/metric.rs @@ -0,0 +1,119 @@ +//! ## Integration test of `AWS IoT Device defender metrics` +//! +//! +//! This test simulates publishing of metrics and expects a accepted response from aws +//! +//! 
The test runs through the following update sequence: +//! 1. Setup metric state +//! 2. Assert json format +//! 2. Publish metric +//! 3. Assert result from AWS + +mod common; + +use std::str::FromStr; + +use common::credentials; +use common::network::TlsNetwork; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{ + self, transport::embedded_nal::NalTransport, Config, DomainBroker, Publish, State, Subscribe, +}; +use futures::StreamExt; +use heapless::LinearMap; +use rustot::{ + defender_metrics::{ + data_types::{CustomMetric, Metric}, + MetricHandler, + }, + shadows::{derive::ShadowState, Shadow, ShadowState}, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use static_cell::StaticCell; + +fn assert_json_format<'a>(json: &'a str) { + log::debug!("{json}"); + let format = "{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":1}],\"MyMetricOfType_NumberList\":[{\"number_list\":[1,2,3]}],\"MyMetricOfType_StringList\":[{\"string_list\":[\"value_1\",\"value_2\"]}],\"MyMetricOfType_IpList\":[{\"ip_list\":[\"172.0.0.0\",\"172.0.0.10\"]}]}}"; + + assert_eq!(json, format); +} + +#[tokio::test(flavor = "current_thread")] +async fn test_publish_metric() { + env_logger::init(); + + let (thing_name, identity) = credentials::identity(); + let hostname = credentials::HOSTNAME.unwrap(); + + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); + + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); + + // Define metrics + let mut custom_metrics: LinearMap = LinearMap::new(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_Number").unwrap(), + [CustomMetric::Number(1)], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_NumberList").unwrap(), + [CustomMetric::NumberList(&[1, 2, 3])], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_StringList").unwrap(), + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_IpList").unwrap(), + [CustomMetric::IpList(&["172.0.0.0", "172.0.0.10"])], + ) + .unwrap(); + + // Build metric + let mut metric = Metric::builder() + .custom_metrics(custom_metrics) + .header(Default::default()) + .build(); + + // Test the json format + let json = serde_json::to_string(&metric).unwrap(); + + assert_json_format(&json); + + let mut metric_handler = MetricHandler::new(&client); + + // Publish metric with mqtt + let mqtt_fut = async { assert!(metric_handler.publish_metric(metric, 2000).await.is_ok()) }; + + let mut transport = NalTransport::new(network, broker); + let _ = embassy_time::with_timeout( + embassy_time::Duration::from_secs(60), + select::select(stack.run(&mut transport), mqtt_fut), + ) + .await; +}