diff --git a/.vscode/settings.json b/.vscode/settings.json index 48fb5ea..c60380a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { - "rust-analyzer.checkOnSave.allTargets": false, + "rust-analyzer.checkOnSave.allTargets": true, "rust-analyzer.cargo.features": ["log"], - "rust-analyzer.cargo.target": "x86_64-unknown-linux-gnu" + "rust-analyzer.cargo.target": "x86_64-unknown-linux-gnu", + } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index c3e36cd..5b1b7a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ embassy-futures = "0.1" log = { version = "0.4", default-features = false, optional = true } defmt = { version = "0.3", optional = true } +bon = { version = "3.3.2", default-features = false } [dev-dependencies] native-tls = { version = "0.2" } @@ -62,7 +63,7 @@ tokio = { version = "1.33", default-features = false, features = [ ] } tokio-native-tls = { version = "0.3.1" } embassy-futures = { version = "0.1.0" } -embassy-time = { version = "0.4", features = ["log", "std", "generic-queue"] } +embassy-time = { version = "0.4", features = ["log", "std", "generic-queue-8"] } embedded-io-adapters = { version = "0.6.0", features = ["tokio-1"] } ecdsa = { version = "0.16", features = ["pkcs8", "pem"] } @@ -70,8 +71,11 @@ p256 = "0.13" pkcs8 = { version = "0.10", features = ["encryption", "pem"] } hex = { version = "0.4.3", features = ["alloc"] } + [features] -default = ["ota_mqtt_data", "provision_cbor"] +default = ["ota_mqtt_data", "metric_cbor", "provision_cbor"] + +metric_cbor = ["dep:minicbor", "dep:minicbor-serde"] provision_cbor = ["dep:minicbor", "dep:minicbor-serde"] @@ -89,5 +93,6 @@ defmt = [ ] log = ["dep:log", "embedded-mqtt/log"] -# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] -# embedded-mqtt = { path = "../embedded-mqtt" } + +[patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +embedded-mqtt = { path = "../embedded-mqtt" } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index b6369b9..ed37409 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-07-17" +channel = "nightly-2024-09-06" components = ["rust-src", "rustfmt", "llvm-tools"] targets = [ "x86_64-unknown-linux-gnu", diff --git a/shadow_derive/src/lib.rs b/shadow_derive/src/lib.rs index 09cdef0..06da9f5 100644 --- a/shadow_derive/src/lib.rs +++ b/shadow_derive/src/lib.rs @@ -118,8 +118,7 @@ fn create_assigners(fields: &Vec) -> Vec { if field .attrs .iter() - .find(|a| a.path.is_ident("static_shadow_field")) - .is_some() + .any(|a| a.path.is_ident("static_shadow_field")) { None } else { @@ -133,7 +132,7 @@ fn create_assigners(fields: &Vec) -> Vec { .collect::>() } -fn create_optional_fields(fields: &Vec) -> Vec { +fn create_optional_fields(fields: &[Field]) -> Vec { fields .iter() .filter_map(|field| { @@ -153,8 +152,7 @@ fn create_optional_fields(fields: &Vec) -> Vec if field .attrs .iter() - .find(|a| a.path.is_ident("static_shadow_field")) - .is_some() + .any(|a| a.path.is_ident("static_shadow_field")) { None } else { @@ -183,13 +181,13 @@ fn generate_shadow_state(input: &StructParseInput) -> proc_macro2::TokenStream { None => quote! { None }, }; - return quote! { + quote! { #[automatically_derived] impl #impl_generics rustot::shadows::ShadowState for #ident #ty_generics #where_clause { const NAME: Option<&'static str> = #name; // const MAX_PAYLOAD_SIZE: usize = 512; } - }; + } } fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenStream { @@ -205,10 +203,10 @@ fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenS let optional_ident = format_ident!("Patch{}", ident); - let assigners = create_assigners(&shadow_fields); - let optional_fields = create_optional_fields(&shadow_fields); + let assigners = create_assigners(shadow_fields); + let optional_fields = create_optional_fields(shadow_fields); - return quote! { + quote! { #[automatically_derived] #[derive(Default, Clone, ::serde::Deserialize, ::serde::Serialize)] #(#copy_attrs)* @@ -228,7 +226,7 @@ fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenS )* } } - }; + } } fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStream { @@ -238,7 +236,7 @@ fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStrea let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - return quote! { + quote! { #[automatically_derived] impl #impl_generics rustot::shadows::ShadowPatch for #ident #ty_generics #where_clause { type PatchState = #ident #ty_generics; @@ -247,5 +245,5 @@ fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStrea *self = opt; } } - }; + } } diff --git a/src/defender_metrics/aws_types.rs b/src/defender_metrics/aws_types.rs new file mode 100644 index 0000000..484a5a0 --- /dev/null +++ b/src/defender_metrics/aws_types.rs @@ -0,0 +1,81 @@ +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct TcpConnections<'a> { + #[serde(rename = "ec")] + pub established_connections: Option<&'a EstablishedConnections<'a>>, +} + +#[derive(Debug, Serialize)] +pub struct EstablishedConnections<'a> { + #[serde(rename = "cs")] + pub connections: Option<&'a [&'a Connection<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct Connection<'a> { + #[serde(rename = "rad")] + pub remote_addr: &'a str, + + /// Port number, must be >= 0 + #[serde(rename = "lp")] + pub local_port: Option, + + /// Interface name + #[serde(rename = "li")] + pub local_interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct ListeningTcpPorts<'a> { + #[serde(rename = "pts")] + pub ports: Option<&'a [&'a TcpPort<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct TcpPort<'a> { + #[serde(rename = "pt")] + pub port: u16, + + #[serde(rename = "if")] + pub interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct ListeningUdpPorts<'a> { + #[serde(rename = "pts")] + pub ports: Option<&'a [&'a UdpPort<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct UdpPort<'a> { + #[serde(rename = "pt")] + pub port: u16, + + #[serde(rename = "if")] + pub interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct NetworkStats { + #[serde(rename = "bi")] + pub bytes_in: Option, + + #[serde(rename = "bo")] + pub bytes_out: Option, + + #[serde(rename = "pi")] + pub packets_in: Option, + + #[serde(rename = "po")] + pub packets_out: Option, +} diff --git a/src/defender_metrics/data_types.rs b/src/defender_metrics/data_types.rs new file mode 100644 index 0000000..2f19cd9 --- /dev/null +++ b/src/defender_metrics/data_types.rs @@ -0,0 +1,84 @@ +use core::fmt::{Display, Write}; + +use bon::Builder; +use embassy_time::Instant; +use serde::{ser::SerializeStruct, Deserialize, Serialize}; + +use super::aws_types::{ListeningTcpPorts, ListeningUdpPorts, NetworkStats, TcpConnections}; + +#[derive(Debug, Serialize, Builder)] +pub struct Metric<'a, C: Serialize> { + #[serde(rename = "hed")] + pub header: Header, + + #[serde(rename = "met")] + pub metrics: Option>, + + #[serde(rename = "cmet")] + pub custom_metrics: Option, +} + +#[derive(Debug, Serialize)] +pub struct Metrics<'a> { + listening_tcp_ports: Option>, + listening_udp_ports: Option>, + network_stats: Option, + tcp_connections: Option>, +} + +#[derive(Debug, Serialize)] +pub struct Header { + /// Monotonically increasing value. Epoch timestamp recommended. + #[serde(rename = "rid")] + pub report_id: i64, + + /// Version in Major.Minor format. + #[serde(rename = "v")] + pub version: Version, +} + +impl Default for Header { + fn default() -> Self { + Self { + report_id: Instant::now().as_millis() as i64, + version: Default::default(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum CustomMetric<'a> { + Number(i64), + NumberList(&'a [u64]), + StringList(&'a [&'a str]), + IpList(&'a [&'a str]), +} + +/// Format is `Version(Major, Minor)` +#[derive(Debug, PartialEq, Deserialize)] +pub struct Version(pub u8, pub u8); + +impl Serialize for Version { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut st: heapless::String<7> = heapless::String::new(); + st.write_fmt(format_args!("{}.{}", self.0, self.1)).ok(); + + serializer.serialize_str(&st) + } +} + +impl Display for Version { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}.{}", self.0, self.1,) + } +} + +impl Default for Version { + fn default() -> Self { + Self(1, 0) + } +} diff --git a/src/defender_metrics/errors.rs b/src/defender_metrics/errors.rs new file mode 100644 index 0000000..5289e14 --- /dev/null +++ b/src/defender_metrics/errors.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ErrorResponse<'a> { + #[serde(rename = "thingName")] + pub thing_name: &'a str, + pub status: &'a str, + #[serde(rename = "statusDetails")] + pub status_details: StatusDetails<'a>, + pub timestamp: i64, +} +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct StatusDetails<'a> { + #[serde(rename = "ErrorCode")] + pub error_code: MetricError, + #[serde(rename = "ErrorMessage")] + pub error_message: Option<&'a str>, +} +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum MetricError { + Malformed, + InvalidPayload, + Throttled, + MissingHeader, + Other, +} diff --git a/src/defender_metrics/mod.rs b/src/defender_metrics/mod.rs new file mode 100644 index 0000000..47aa3c0 --- /dev/null +++ b/src/defender_metrics/mod.rs @@ -0,0 +1,444 @@ +use crate::shadows::Error; +use data_types::Metric; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{DeferredPayload, Publish, Subscribe, SubscribeTopic, ToPayload}; +use errors::{ErrorResponse, MetricError}; +use futures::StreamExt; +use serde::Serialize; +use topics::Topic; + +// pub mod aws_types; +pub mod aws_types; +pub mod data_types; +pub mod errors; +pub mod topics; + +pub struct MetricHandler<'a, 'm, M: RawMutex> { + mqtt: &'m embedded_mqtt::MqttClient<'a, M>, +} + +impl<'a, 'm, M: RawMutex> MetricHandler<'a, 'm, M> { + pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M>) -> Self { + Self { mqtt } + } + + pub async fn publish_metric<'c, C: Serialize>( + &self, + metric: Metric<'c, C>, + max_payload_size: usize, + ) -> Result<(), MetricError> { + //Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let payload = DeferredPayload::new( + |buf: &mut [u8]| { + #[cfg(feature = "metric_cbor")] + { + let mut serializer = minicbor_serde::Serializer::new( + minicbor::encode::write::Cursor::new(&mut *buf), + ); + + match metric.serialize(&mut serializer) { + Ok(_) => {} + Err(_) => { + error!("An error happened when serializing metric with cbor"); + return Err(embedded_mqtt::EncodingError::BufferSize); + } + }; + + Ok(serializer.into_encoder().writer().position()) + } + + #[cfg(not(feature = "metric_cbor"))] + { + serde_json_core::to_slice(&metric, buf) + .map_err(|_| embedded_mqtt::EncodingError::BufferSize) + } + }, + max_payload_size, + ); + + let mut subscription = self + .publish_and_subscribe(payload) + .await + .map_err(|_| MetricError::Other)?; + + loop { + let message = subscription.next().await.ok_or(MetricError::Malformed)?; + + match Topic::from_str(message.topic_name()) { + Some(Topic::Accepted) => return Ok(()), + Some(Topic::Rejected) => { + let error_response = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| MetricError::InvalidPayload)?; + + return Err(error_response.0.status_details.error_code); + } + + _ => (), + }; + } + } + async fn publish_and_subscribe( + &self, + payload: impl ToPayload, + ) -> Result, Error> { + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::Accepted + .format::<64>(self.mqtt.client_id())? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::Rejected + .format::<64>(self.mqtt.client_id())? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + //*** PUBLISH REQUEST ***/ + let topic_name = Topic::Publish.format::<64>(self.mqtt.client_id())?; + + match self + .mqtt + .publish( + Publish::builder() + .topic_name(topic_name.as_str()) + .payload(payload) + .build(), + ) + .await + .map_err(Error::MqttError) + { + Ok(_) => {} + Err(_) => { + error!("ERROR PUBLISHING PAYLOAD"); + return Err(Error::MqttError(embedded_mqtt::Error::BadTopicFilter)); + } + }; + + Ok(sub) + } +} + +#[cfg(test)] +mod tests { + use core::str::FromStr; + + use super::data_types::*; + + use heapless::{LinearMap, String}; + use serde::{ser::SerializeStruct, Serialize}; + + #[test] + fn serialize_version_json() { + let test_cases = [ + (Version(2, 0), "\"2.0\""), + (Version(0, 0), "\"0.0\""), + (Version(0, 1), "\"0.1\""), + (Version(255, 200), "\"255.200\""), + ]; + + for (version, expected) in test_cases.iter() { + let string: String<100> = serde_json_core::to_string(version).unwrap(); + assert_eq!( + string, *expected, + "Serialization failed for Version({}, {}): expected {}, got {}", + version.0, version.1, expected, string + ); + } + } + #[test] + fn serialize_version_cbor() { + let test_cases: [(Version, [u8; 8]); 4] = [ + (Version(2, 0), [99, 50, 46, 48, 0, 0, 0, 0]), + (Version(0, 0), [99, 48, 46, 48, 0, 0, 0, 0]), + (Version(0, 1), [99, 48, 46, 49, 0, 0, 0, 0]), + (Version(255, 200), [103, 50, 53, 53, 46, 50, 48, 48]), + ]; + + for (version, expected) in test_cases.iter() { + let mut buf = [0u8; 200]; + + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(&mut buf[..])); + + version.serialize(&mut serializer).unwrap(); + + let len = serializer.into_encoder().writer().position(); + + assert_eq!( + &buf[..len], + &expected[..len], + "Serialization failed for Version({}, {}): expected {:?}, got {:?}", + version.0, + version.1, + expected, + &buf[..len], + ); + } + } + + #[test] + fn custom_serialization_cbor() { + #[derive(Debug)] + struct WifiMetric { + signal_strength: u8, + } + + impl Serialize for WifiMetric { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("WifiMetricWrapper", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct Number { + number: u8, + } + + let number = Number { + number: self.signal_strength, + }; + + // Serialize number and wrap in array + outer.serialize_field("MyMetricOfType_Number", &[number])?; + outer.end() + } + } + + let custom_metrics: WifiMetric = WifiMetric { + signal_strength: 23, + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let mut buf = [255u8; 1000]; + + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(&mut buf[..])); + + metric.serialize(&mut serializer).unwrap(); + + let len = serializer.into_encoder().writer().position(); + + assert_eq!( + &buf[..len], + [ + 163, 99, 104, 101, 100, 162, 99, 114, 105, 100, 0, 97, 118, 99, 49, 46, 48, 99, + 109, 101, 116, 246, 100, 99, 109, 101, 116, 161, 117, 77, 121, 77, 101, 116, 114, + 105, 99, 79, 102, 84, 121, 112, 101, 95, 78, 117, 109, 98, 101, 114, 129, 161, 102, + 110, 117, 109, 98, 101, 114, 23 + ] + ) + } + + #[test] + fn custom_serialization() { + #[derive(Debug)] + struct WifiMetric { + signal_strength: u8, + } + + impl Serialize for WifiMetric { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("WifiMetricWrapper", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct Number { + number: u8, + } + + let number = Number { + number: self.signal_strength, + }; + + // Serialize number and wrap in array + outer.serialize_field("MyMetricOfType_Number", &[number])?; + outer.end() + } + } + + let custom_metrics: WifiMetric = WifiMetric { + signal_strength: 23, + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":23}]}}", payload.as_str()) + } + #[test] + fn custom_serialization_string_list() { + #[derive(Debug)] + struct CellType { + cell_type: String, + } + + impl Serialize for CellType { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("CellType", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct StringList<'a> { + string_list: &'a [&'a str], + } + + let list = StringList { + string_list: &[&self.cell_type.as_str()], + }; + + // Serialize number and wrap in array + outer.serialize_field("cell_type", &[list])?; + outer.end() + } + } + + let custom_metrics: CellType<4> = CellType { + cell_type: String::from_str("gsm").unwrap(), + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"cell_type\":[{\"string_list\":[\"gsm\"]}]}}", payload.as_str()) + } + #[test] + fn number() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + let name_of_metric = String::from_str("myMetric").unwrap(); + + custom_metrics + .insert(name_of_metric, [CustomMetric::Number(23)]) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"myMetric\":[{\"number\":23}]}}", payload.as_str()) + } + + #[test] + fn number_list() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + // NUMBER LIST + let my_number_list = String::from_str("my_number_list").unwrap(); + + custom_metrics + .insert(my_number_list, [CustomMetric::NumberList(&[123, 456, 789])]) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"my_number_list\":[{\"number_list\":[123,456,789]}]}}", payload.as_str()) + } + + #[test] + fn string_list() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + // STRING LIST + let my_string_list = String::from_str("my_string_list").unwrap(); + + custom_metrics + .insert( + my_string_list, + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"my_string_list\":[{\"string_list\":[\"value_1\",\"value_2\"]}]}}", payload.as_str()) + } + + #[test] + fn all_types() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 4> = LinearMap::new(); + + let my_number = String::from_str("MyMetricOfType_Number").unwrap(); + custom_metrics + .insert(my_number, [CustomMetric::Number(1)]) + .unwrap(); + + let my_number_list = String::from_str("MyMetricOfType_NumberList").unwrap(); + custom_metrics + .insert(my_number_list, [CustomMetric::NumberList(&[1, 2, 3])]) + .unwrap(); + + let my_string_list = String::from_str("MyMetricOfType_StringList").unwrap(); + custom_metrics + .insert( + my_string_list, + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + let my_ip_list = String::from_str("MyMetricOfType_IpList").unwrap(); + custom_metrics + .insert( + my_ip_list, + [CustomMetric::IpList(&["172.0.0.0", "172.0.0.10"])], + ) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":1}],\"MyMetricOfType_NumberList\":[{\"number_list\":[1,2,3]}],\"MyMetricOfType_StringList\":[{\"string_list\":[\"value_1\",\"value_2\"]}],\"MyMetricOfType_IpList\":[{\"ip_list\":[\"172.0.0.0\",\"172.0.0.10\"]}]}}", payload.as_str()) + } +} diff --git a/src/defender_metrics/topics.rs b/src/defender_metrics/topics.rs new file mode 100644 index 0000000..eba97f7 --- /dev/null +++ b/src/defender_metrics/topics.rs @@ -0,0 +1,78 @@ +#![allow(dead_code)] +use core::fmt::Write; + +use heapless::String; + +use crate::shadows::Error; + +pub enum PayloadFormat { + #[cfg(feature = "metric_cbor")] + Cbor, + #[cfg(not(feature = "metric_cbor"))] + Json, +} + +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum Topic { + Accepted, + Rejected, + Publish, +} + +impl Topic { + const PREFIX: &'static str = "$aws/things"; + const NAME: &'static str = "defender/metrics"; + + #[cfg(feature = "metric_cbor")] + const PAYLOAD_FORMAT: &'static str = "cbor"; + + #[cfg(not(feature = "metric_cbor"))] + const PAYLOAD_FORMAT: &'static str = "json"; + + pub fn format(&self, thing_name: &str) -> Result, Error> { + let mut topic_path = String::new(); + + match self { + Self::Accepted => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}/accepted", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + Self::Rejected => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}/rejected", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + Self::Publish => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + } + .map_err(|_| Error::Overflow)?; + + Ok(topic_path) + } + + pub fn from_str(s: &str) -> Option { + let tt = s.splitn(7, '/').collect::>(); + match (tt.first(), tt.get(1), tt.get(3), tt.get(4)) { + (Some(&"$aws"), Some(&"things"), Some(&"defender"), Some(&"metrics")) => { + // This is a defender metric topic, now figure out which one. + + match tt.get(6) { + Some(&"accepted") => Some(Topic::Accepted), + Some(&"rejected") => Some(Topic::Rejected), + _ => None, + } + } + _ => None, + } + } +} diff --git a/src/jobs/describe.rs b/src/jobs/describe.rs index 5c59d5c..81579c7 100644 --- a/src/jobs/describe.rs +++ b/src/jobs/describe.rs @@ -89,7 +89,7 @@ impl<'a> Describe<'a> { let payload_len = serde_json_core::to_slice( &DescribeJobExecutionRequest { execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), + include_job_document: self.include_job_document.then_some(true), client_token: self.client_token, }, buf, diff --git a/src/jobs/subscribe.rs b/src/jobs/subscribe.rs index cf796eb..4fc604b 100644 --- a/src/jobs/subscribe.rs +++ b/src/jobs/subscribe.rs @@ -18,7 +18,7 @@ pub enum Topic<'a> { impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(8, '/').collect::>(); - Some(match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + Some(match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), _, Some(&"jobs")) => { // This is a job topic! Figure out which match (tt.get(4), tt.get(5), tt.get(6), tt.get(7)) { diff --git a/src/jobs/update.rs b/src/jobs/update.rs index 3875bd0..867d0ac 100644 --- a/src/jobs/update.rs +++ b/src/jobs/update.rs @@ -146,9 +146,9 @@ impl<'a> Update<'a> { let payload_len = serde_json_core::to_slice( &UpdateJobExecutionRequest { execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), + include_job_document: self.include_job_document.then_some(true), expected_version: self.expected_version, - include_job_execution_state: self.include_job_execution_state.then(|| true), + include_job_execution_state: self.include_job_execution_state.then_some(true), status: self.status, status_details: self.status_details, step_timeout_in_minutes: self.step_timeout_in_minutes, diff --git a/src/lib.rs b/src/lib.rs index f5d758f..abda902 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; +pub mod defender_metrics; pub mod jobs; #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] pub mod ota; diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 601ad68..fdc8a74 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,15 +1,13 @@ use core::fmt::Write; use embassy_sync::blocking_mutex::raw::RawMutex; -use embassy_time::with_timeout; -use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic}; +use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS}; use super::ControlInterface; -use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse}; -use crate::jobs::{JobError, JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; -use crate::ota::config::Config; +use crate::jobs::data_types::JobStatus; +use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; use crate::ota::encoding::json::JobStatusReason; -use crate::ota::encoding::{self, FileContext}; +use crate::ota::encoding::FileContext; use crate::ota::error::OtaError; use crate::ota::ProgressState; @@ -53,7 +51,7 @@ impl<'a, M: RawMutex> ControlInterface for embedded_mqtt::MqttClient<'a, M> { ) .map_err(|_| OtaError::Overflow)?; - let mut qos = QoS::AtLeastOnce; + let qos = QoS::AtLeastOnce; if let JobStatus::InProgress | JobStatus::Succeeded = status { let received_blocks = progress_state.total_blocks - progress_state.blocks_remaining; diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 4afc2fb..e1e1c5e 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -58,7 +58,7 @@ pub enum Topic<'a> { impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(8, '/').collect::>(); - Some(match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + Some(match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), _, Some(&"streams")) => { // This is a stream topic! Figure out which match (tt.get(4), tt.get(5), tt.get(6), tt.get(7)) { diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index b0349ed..a55f972 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -134,7 +134,7 @@ impl FileContext { }) .collect() }) - .unwrap_or_else(|| StatusDetailsOwned::new()), + .unwrap_or_default(), job_name: heapless::String::try_from(job_data.job_name).unwrap(), block_offset, diff --git a/src/ota/mod.rs b/src/ota/mod.rs index 10b5b15..808d3d9 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -60,7 +60,7 @@ impl Updater { }); // Create the JobUpdater - let mut job_updater = JobUpdater::new(&file_ctx, &progress_state, &config, control); + let mut job_updater = JobUpdater::new(&file_ctx, &progress_state, config, control); match job_updater.initialize::(pal).await? { Some(()) => {} @@ -70,7 +70,7 @@ impl Updater { info!("Job document was accepted. Attempting to begin the update"); // Spawn the request momentum future - let momentum_fut = Self::handle_momentum(data, &config, &file_ctx, &progress_state); + let momentum_fut = Self::handle_momentum(data, config, &file_ctx, &progress_state); // Spawn the status update future let status_update_fut = job_updater.handle_status_updates(); @@ -100,7 +100,7 @@ impl Updater { { let mut progress = progress_state.lock().await; - data.request_file_blocks(&file_ctx, &mut progress, &config) + data.request_file_blocks(&file_ctx, &mut progress, config) .await?; } @@ -116,7 +116,7 @@ impl Updater { match Self::ingest_data_block( data, &mut block_writer, - &config, + config, &mut progress, payload.deref_mut(), ) @@ -168,7 +168,7 @@ impl Updater { if progress.request_block_remaining > 1 { progress.request_block_remaining -= 1; } else { - data.request_file_blocks(&file_ctx, &mut progress, &config) + data.request_file_blocks(&file_ctx, &mut progress, config) .await?; } } @@ -442,7 +442,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { let mut progress = self.progress_state.lock().await; self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Succeeded, JobStatusReason::Accepted, @@ -542,7 +542,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { // in self_test active self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::InProgress, JobStatusReason::SelfTestActive, @@ -554,7 +554,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { // complete the job self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Succeeded, JobStatusReason::Accepted, @@ -568,7 +568,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Failed, JobStatusReason::Rejected, @@ -582,7 +582,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { self.control .update_job_status( - &self.file_ctx, + self.file_ctx, &mut progress, JobStatus::Failed, JobStatusReason::Aborted, @@ -607,7 +607,7 @@ impl<'a, C: ControlInterface> JobUpdater<'a, C> { let mut progress = self.progress_state.lock().await; self.control - .update_job_status(&self.file_ctx, &mut progress, status, reason) + .update_job_status(self.file_ctx, &mut progress, status, reason) .await?; Ok(()) } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 3327762..9e456f2 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -6,7 +6,7 @@ use core::future::Future; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - BufferProvider, DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, + DeferredPayload, EncodingError, Publish, Subscribe, SubscribeTopic, Subscription, }; use futures::StreamExt; @@ -191,7 +191,7 @@ impl FleetProvisioner { }; let register_request = RegisterThingRequest { - certificate_ownership_token: &ownership_token, + certificate_ownership_token: ownership_token, parameters, }; diff --git a/src/provisioning/topics.rs b/src/provisioning/topics.rs index 9256c55..bef6710 100644 --- a/src/provisioning/topics.rs +++ b/src/provisioning/topics.rs @@ -96,7 +96,7 @@ impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(6, '/').collect::>(); - match (tt.get(0), tt.get(1)) { + match (tt.first(), tt.get(1)) { (Some(&"$aws"), Some(&"provisioning-templates")) => { // This is a register thing topic, now figure out which one. @@ -107,7 +107,7 @@ impl<'a> Topic<'a> { Some(payload_format), Some(&"accepted"), ) => Some(Topic::RegisterThingAccepted( - *template_name, + template_name, PayloadFormat::from_str(payload_format).ok()?, )), ( @@ -116,7 +116,7 @@ impl<'a> Topic<'a> { Some(payload_format), Some(&"rejected"), ) => Some(Topic::RegisterThingRejected( - *template_name, + template_name, PayloadFormat::from_str(payload_format).ok()?, )), _ => None, diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index e717a6e..99bddfb 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -1,10 +1,10 @@ pub mod dao; pub mod data_types; -mod error; +pub mod error; mod shadow_diff; pub mod topics; -use core::{marker::PhantomData, ops::DerefMut, sync::atomic}; +use core::{marker::PhantomData, ops::DerefMut}; pub use data_types::Patch; use embassy_sync::{ @@ -118,7 +118,7 @@ where }; let payload = DeferredPayload::new( - |buf| { + |buf: &mut [u8]| { serde_json_core::to_slice(&request, buf) .map_err(|_| embedded_mqtt::EncodingError::BufferSize) }, diff --git a/src/shadows/topics.rs b/src/shadows/topics.rs index 34642d0..b572675 100644 --- a/src/shadows/topics.rs +++ b/src/shadows/topics.rs @@ -42,11 +42,11 @@ impl Topic { pub fn from_str(s: &str) -> Option<(Self, &str, Option<&str>)> { let tt = s.splitn(9, '/').collect::>(); - match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), Some(thing_name), Some(&Self::SHADOW)) => { // This is a shadow topic, now figure out which one. let (shadow_name, next_index) = if let Some(&"name") = tt.get(4) { - (tt.get(5).map(|s| *s), 6) + (tt.get(5).copied(), 6) } else { (None, 4) }; @@ -239,7 +239,7 @@ impl Subscribe { self.topics .iter() - .map(|(topic, qos)| Ok((Topic::from(*topic).format(thing_name, shadow_name)?, *qos))) + .map(|(topic, qos)| Ok(((*topic).format(thing_name, shadow_name)?, *qos))) .collect() } } @@ -278,7 +278,7 @@ impl Unsubscribe { self.topics .iter() - .map(|topic| Topic::from(*topic).format(thing_name, shadow_name)) + .map(|topic| (*topic).format(thing_name, shadow_name)) .collect() } } diff --git a/tests/common/network.rs b/tests/common/network.rs index 0cfe3db..d6a2d09 100644 --- a/tests/common/network.rs +++ b/tests/common/network.rs @@ -1,8 +1,8 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use ::native_tls::Identity; use embedded_io_adapters::tokio_1::FromTokio; -use embedded_nal_async::{AddrType, Dns, IpAddr, Ipv4Addr, Ipv6Addr, TcpConnect}; +use embedded_nal_async::{AddrType, Dns, TcpConnect}; use tokio_native_tls::native_tls; use super::credentials; @@ -25,7 +25,7 @@ impl TcpConnect for Network { async fn connect<'a>( &'a self, - remote: embedded_nal_async::SocketAddr, + remote: SocketAddr, ) -> Result, Self::Error> { let stream = tokio::net::TcpStream::connect(format!("{}", remote)).await?; Ok(FromTokio::new(stream)) @@ -86,7 +86,7 @@ impl TcpConnect for TlsNetwork { async fn connect<'a>( &'a self, - remote: embedded_nal_async::SocketAddr, + remote: SocketAddr, ) -> Result, Self::Error> { log::info!("Connecting to {:?}", remote); let connector = tokio_native_tls::TlsConnector::from( diff --git a/tests/metric.rs b/tests/metric.rs new file mode 100644 index 0000000..4115c17 --- /dev/null +++ b/tests/metric.rs @@ -0,0 +1,119 @@ +//! ## Integration test of `AWS IoT Device defender metrics` +//! +//! +//! This test simulates publishing of metrics and expects a accepted response from aws +//! +//! The test runs through the following update sequence: +//! 1. Setup metric state +//! 2. Assert json format +//! 2. Publish metric +//! 3. Assert result from AWS + +mod common; + +use std::str::FromStr; + +use common::credentials; +use common::network::TlsNetwork; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{ + self, transport::embedded_nal::NalTransport, Config, DomainBroker, Publish, State, Subscribe, +}; +use futures::StreamExt; +use heapless::LinearMap; +use rustot::{ + defender_metrics::{ + data_types::{CustomMetric, Metric}, + MetricHandler, + }, + shadows::{derive::ShadowState, Shadow, ShadowState}, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use static_cell::StaticCell; + +fn assert_json_format<'a>(json: &'a str) { + log::debug!("{json}"); + let format = "{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":1}],\"MyMetricOfType_NumberList\":[{\"number_list\":[1,2,3]}],\"MyMetricOfType_StringList\":[{\"string_list\":[\"value_1\",\"value_2\"]}],\"MyMetricOfType_IpList\":[{\"ip_list\":[\"172.0.0.0\",\"172.0.0.10\"]}]}}"; + + assert_eq!(json, format); +} + +#[tokio::test(flavor = "current_thread")] +async fn test_publish_metric() { + env_logger::init(); + + let (thing_name, identity) = credentials::identity(); + let hostname = credentials::HOSTNAME.unwrap(); + + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); + + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); + + // Define metrics + let mut custom_metrics: LinearMap = LinearMap::new(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_Number").unwrap(), + [CustomMetric::Number(1)], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_NumberList").unwrap(), + [CustomMetric::NumberList(&[1, 2, 3])], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_StringList").unwrap(), + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_IpList").unwrap(), + [CustomMetric::IpList(&["172.0.0.0", "172.0.0.10"])], + ) + .unwrap(); + + // Build metric + let mut metric = Metric::builder() + .custom_metrics(custom_metrics) + .header(Default::default()) + .build(); + + // Test the json format + let json = serde_json::to_string(&metric).unwrap(); + + assert_json_format(&json); + + let mut metric_handler = MetricHandler::new(&client); + + // Publish metric with mqtt + let mqtt_fut = async { assert!(metric_handler.publish_metric(metric, 2000).await.is_ok()) }; + + let mut transport = NalTransport::new(network, broker); + let _ = embassy_time::with_timeout( + embassy_time::Duration::from_secs(60), + select::select(stack.run(&mut transport), mqtt_fut), + ) + .await; +}