Skip to content

Commit

Permalink
websockets: make error handling more explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Dec 24, 2023
1 parent 1cf44dd commit f9efbdd
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 75 deletions.
3 changes: 3 additions & 0 deletions websockets/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ serde = { version = "1.0.193", optional = true }
serde_json = { version = "1.0.108", optional = true }
sha-1 = "0.10.1"
stopper = "0.2.3"
thiserror = "1.0.51"
trillium = { path = "../trillium", version = "0.2.0" }
trillium-http = { path = "../http", version = "0.3.8" }

Expand All @@ -39,6 +40,8 @@ broadcaster = "1.0.0"
trillium-smol = { path = "../smol" }
trillium-testing = { path = "../testing" }
trillium-websockets = { features = ["json"], path = "." }
trillium-logger = { path = "../logger" }
env_logger = "0.10.1"

[package.metadata.cargo-udeps.ignore]
development = ["trillium-testing"]
23 changes: 19 additions & 4 deletions websockets/examples/json_broadcast_websocket.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use broadcaster::BroadcastChannel;
use std::net::IpAddr;
use trillium::{async_trait, Conn};
use trillium_websockets::{Message, WebSocket, WebSocketConn, WebSocketHandler};

struct EchoServer {
Expand All @@ -12,22 +14,35 @@ impl EchoServer {
}
}

#[trillium::async_trait]
#[async_trait]
impl WebSocketHandler for EchoServer {
type OutboundStream = BroadcastChannel<Message>;

async fn connect(&self, conn: WebSocketConn) -> Option<(WebSocketConn, Self::OutboundStream)> {
Some((conn, self.channel.clone()))
}

async fn inbound(&self, message: Message, _conn: &mut WebSocketConn) {
async fn inbound(&self, message: Message, conn: &mut WebSocketConn) {
if let Message::Text(input) = message {
let message = Message::text(format!("received message: {}", &input));
let ip = conn
.state()
.map_or(String::from("<unknown>"), IpAddr::to_string);
let message = Message::text(format!("received message `{}` from {}", input, ip));
trillium::log_error!(self.channel.send(&message).await);
}
}
}

fn main() {
trillium_smol::run(WebSocket::new(EchoServer::new()));
env_logger::init();
trillium_smol::run((
trillium_logger::logger(),
|mut conn: Conn| async move {
if let Some(ip) = conn.peer_ip() {
conn.set_state(ip);
};
conn
},
WebSocket::new(EchoServer::new()),
));
}
16 changes: 11 additions & 5 deletions websockets/examples/json_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use async_channel::{unbounded, Receiver, Sender};
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use trillium::{async_trait, log_error};
use trillium_websockets::{json_websocket, JsonWebSocketHandler, WebSocketConn};
use trillium_websockets::{json_websocket, JsonWebSocketHandler, Result, WebSocketConn};

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
struct Response {
inbound_message: Inbound,
enum Response {
Ack(Inbound),
ParseError(String),
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
Expand All @@ -30,13 +31,18 @@ impl JsonWebSocketHandler for SomeJsonChannel {

async fn receive_message(
&self,
inbound_message: Self::InboundMessage,
inbound_message: Result<Self::InboundMessage>,
conn: &mut WebSocketConn,
) {
let response = match inbound_message {
Ok(message) => Response::Ack(message),
Err(e) => Response::ParseError(e.to_string()),
};

log_error!(
conn.state::<Sender<Response>>()
.unwrap()
.send(Response { inbound_message })
.send(response)
.await
);
}
Expand Down
12 changes: 10 additions & 2 deletions websockets/examples/websockets.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
use futures_util::StreamExt;
use trillium_logger::logger;
use trillium_websockets::{websocket, Message, WebSocketConn};

async fn websocket_handler(mut conn: WebSocketConn) {
while let Some(Ok(Message::Text(input))) = conn.next().await {
conn.send_string(format!("received your message: {}", &input))
let result = conn
.send_string(format!("received your message: {}", &input))
.await;

if let Err(e) = result {
log::error!("{e}");
break;
}
}
}

pub fn main() {
trillium_smol::run(websocket(websocket_handler));
env_logger::init();
trillium_smol::run((logger(), websocket(websocket_handler)));
}
7 changes: 3 additions & 4 deletions websockets/src/bidirectional_stream.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use futures_lite::{Stream, StreamExt};
use std::{
fmt::Debug,
fmt::{Debug, Formatter, Result},
pin::Pin,
task::{Context, Poll},
};

use futures_lite::{Stream, StreamExt};

pub(crate) struct BidirectionalStream<I, O> {
pub(crate) inbound: Option<I>,
pub(crate) outbound: O,
}

impl<I, O> Debug for BidirectionalStream<I, O> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
f.debug_struct("BidirectionalStream")
.field(
"inbound",
Expand Down
48 changes: 31 additions & 17 deletions websockets/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use async_channel::{unbounded, Receiver, Sender};
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use trillium::{async_trait, log_error};
use trillium_websockets::{json_websocket, JsonWebSocketHandler, WebSocketConn};
use trillium_websockets::{json_websocket, JsonWebSocketHandler, WebSocketConn, Result};
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
struct Response {
Expand Down Expand Up @@ -58,15 +58,17 @@ impl JsonWebSocketHandler for SomeJsonChannel {
async fn receive_message(
&self,
inbound_message: Self::InboundMessage,
inbound_message: Result<Self::InboundMessage>,
conn: &mut WebSocketConn,
) {
log_error!(
conn.state::<Sender<Response>>()
.unwrap()
.send(Response { inbound_message })
.await
);
if let Ok(inbound_message) = inbound_message {
log_error!(
conn.state::<Sender<Response>>()
.unwrap()
.send(Response { inbound_message })
.await
);
}
}
}
Expand Down Expand Up @@ -110,7 +112,11 @@ pub trait JsonWebSocketHandler: Send + Sync + 'static {
InboundMessage along with the websocket conn that it was received
from.
*/
async fn receive_message(&self, message: Self::InboundMessage, conn: &mut WebSocketConn);
async fn receive_message(
&self,
message: crate::Result<Self::InboundMessage>,
conn: &mut WebSocketConn,
);

/**
`disconnect` is called when websocket clients disconnect, along
Expand Down Expand Up @@ -173,7 +179,13 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(
ready!(self.0.poll_next(cx))
.and_then(|i| serde_json::to_string(&i).ok())
.and_then(|i| match serde_json::to_string(&i) {
Ok(j) => Some(j),
Err(e) => {
log::error!("serialization error: {e}");
None
}
})
.map(Message::Text),
)
}
Expand All @@ -195,13 +207,15 @@ where
}

async fn inbound(&self, message: Message, conn: &mut WebSocketConn) {
if let Some(message) = message
.to_text()
.ok()
.and_then(|m| serde_json::from_str(m).ok())
{
self.handler.receive_message(message, conn).await;
}
self.handler
.receive_message(
message
.to_text()
.map_err(Into::into)
.and_then(|m| serde_json::from_str(m).map_err(Into::into)),
conn,
)
.await;
}

async fn disconnect(&self, conn: &mut WebSocketConn, close_frame: Option<CloseFrame<'static>>) {
Expand Down
32 changes: 19 additions & 13 deletions websockets/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use std::{
ops::{Deref, DerefMut},
};
use trillium::{
async_trait, conn_unwrap, Conn, Handler,
async_trait, Conn, Handler,
KnownHeaderName::{
Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
Upgrade as UpgradeHeader,
Expand All @@ -76,15 +76,30 @@ use trillium::{

pub use async_tungstenite::{
self,
tungstenite::{self, protocol::WebSocketConfig, Error, Message},
tungstenite::{self, protocol::WebSocketConfig, Message},
};
pub use websocket_connection::WebSocketConn;
pub use websocket_handler::WebSocketHandler;

const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// a Result type for websocket messages
pub type Result = std::result::Result<Message, Error>;
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
/// An Error type that represents all exceptional conditions that can be encoutered in the operation
/// of this crate
pub enum Error {
#[error(transparent)]
/// an error in the underlying websocket implementation
WebSocket(#[from] async_tungstenite::tungstenite::Error),

#[cfg(feature = "json")]
#[error(transparent)]
/// an error in json serialization or deserialization
Json(#[from] serde_json::Error),
}

/// a Result type for this crate
pub type Result<T = Message> = std::result::Result<T, Error>;

#[cfg(feature = "json")]
mod json;
Expand Down Expand Up @@ -175,15 +190,6 @@ struct IsWebsocket;
#[cfg(test)]
mod tests;

macro_rules! unwrap_or_return {
($expr:expr) => {
match $expr {
Some(x) => x,
None => return,
}
};
}

// this is a workaround for the fact that Upgrade is a public struct,
// so adding peer_ip to that struct would be a breaking change. We
// stash a copy in state for now.
Expand Down
31 changes: 16 additions & 15 deletions websockets/src/websocket_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,25 @@ type Wss = WebSocketStream<BoxedTransport>;

impl WebSocketConn {
/// send a [`Message::Text`] variant
pub async fn send_string(&mut self, string: String) {
self.send(Message::Text(string)).await.ok();
pub async fn send_string(&mut self, string: String) -> Result<()> {
self.send(Message::Text(string)).await.map_err(Into::into)
}

/// send a [`Message::Binary`] variant
pub async fn send_bytes(&mut self, bin: Vec<u8>) {
self.send(Message::Binary(bin)).await.ok();
pub async fn send_bytes(&mut self, bin: Vec<u8>) -> Result<()> {
self.send(Message::Binary(bin)).await.map_err(Into::into)
}

#[cfg(feature = "json")]
/// send a [`Message::Text`] that contains json
/// note that json messages are not actually part of the websocket specification
pub async fn send_json(&mut self, json: &impl serde::Serialize) -> serde_json::Result<()> {
self.send_string(serde_json::to_string(json)?).await;
Ok(())
pub async fn send_json(&mut self, json: &impl serde::Serialize) -> Result<()> {
self.send_string(serde_json::to_string(json)?).await
}

/// Sends a [`Message`] to the client
pub async fn send(&mut self, message: Message) -> tungstenite::Result<()> {
self.sink.send(message).await
pub async fn send(&mut self, message: Message) -> Result<()> {
self.sink.send(message).await.map_err(Into::into)
}

pub(crate) async fn new(upgrade: Upgrade, config: Option<WebSocketConfig>) -> Self {
Expand Down Expand Up @@ -103,7 +102,7 @@ impl WebSocketConn {
}

/// close the websocket connection gracefully
pub async fn close(&mut self) -> tungstenite::Result<()> {
pub async fn close(&mut self) -> Result<()> {
self.send(Message::Close(None)).await
}

Expand All @@ -122,7 +121,7 @@ impl WebSocketConn {
any query component
*/
pub fn path(&self) -> &str {
self.path.split('?').next().unwrap()
self.path.split('?').next().unwrap_or_default()
}

/**
Expand Down Expand Up @@ -176,23 +175,25 @@ impl WebSocketConn {
}

/// take the inbound Message stream from this conn
pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = Result>> {
pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult>> {
self.stream.take()
}

/// borrow the inbound Message stream from this conn
pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = Result> + '_> {
pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + '_> {
self.stream.as_mut()
}
}

type MessageResult = std::result::Result<Message, tungstenite::Error>;

#[derive(Debug)]
pub struct WStream {
stream: StreamStopper<SplitStream<Wss>>,
}

impl Stream for WStream {
type Item = Result;
type Item = MessageResult;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.stream.poll_next_unpin(cx)
Expand All @@ -212,7 +213,7 @@ impl AsRef<StateSet> for WebSocketConn {
}

impl Stream for WebSocketConn {
type Item = Result;
type Item = MessageResult;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.stream.as_mut() {
Expand Down
4 changes: 2 additions & 2 deletions websockets/src/websocket_handler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::WebSocketConn;
use async_tungstenite::tungstenite::{protocol::CloseFrame, Error, Message};
use crate::{Error, WebSocketConn};
use async_tungstenite::tungstenite::{protocol::CloseFrame, Message};
use futures_lite::stream::{Pending, Stream};
use std::future::Future;
use trillium::async_trait;
Expand Down
Loading

0 comments on commit f9efbdd

Please sign in to comment.