Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 18 additions & 27 deletions crates/socketioxide-core/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,18 @@
type AckStream: Stream<Item = AckStreamItem<Self::AckError>> + FusedStream + Send + 'static;

/// Get all the socket ids in the namespace.
fn get_all_sids(&self) -> Vec<Sid>;
fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid>;
/// Send data to the list of socket ids.
fn send_many(&self, sids: Vec<Sid>, data: Value) -> Result<(), Vec<SocketError>>;
fn send_many(&self, sids: &[Sid], data: Value) -> Result<(), Vec<SocketError>>;
/// Send data to the list of socket ids and get a stream of acks.
fn send_many_with_ack(
&self,
sids: Vec<Sid>,
sids: &[Sid],
packet: Packet,
timeout: Option<Duration>,
) -> Self::AckStream;
/// Disconnect all the sockets in the list.
fn disconnect_many(&self, sid: Vec<Sid>) -> Result<(), Vec<SocketError>>;
fn disconnect_many(&self, sid: &[Sid]) -> Result<(), Vec<SocketError>>;
/// Get the path of the namespace.
fn path(&self) -> &Str;
/// Get the parser of the namespace.
Expand Down Expand Up @@ -373,7 +373,7 @@
}

let data = self.sockets.parser().encode(packet);
self.sockets.send_many(sids, data)
self.sockets.send_many(&sids, data)
}

/// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
Expand All @@ -390,13 +390,13 @@

let count = sids.len() as u32;
// We cannot pre-serialize the packet because we need to change the ack id.
let stream = self.sockets.send_many_with_ack(sids, packet, timeout);
let stream = self.sockets.send_many_with_ack(&sids, packet, timeout);
(stream, count)
}

/// Returns the sockets ids that match the [`BroadcastOptions`].
pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
self.apply_opts(opts)
self.apply_opts(opts).into_vec()
}

//TODO: make this operation O(1)
Expand Down Expand Up @@ -429,7 +429,7 @@
/// Disconnects the sockets that match the [`BroadcastOptions`].
pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
let sids = self.apply_opts(opts);
self.sockets.disconnect_many(sids)
self.sockets.disconnect_many(&sids)
}

/// Returns all the rooms for this adapter.
Expand All @@ -450,33 +450,29 @@

impl<E: SocketEmitter> CoreLocalAdapter<E> {
/// Applies the given `opts` and return the sockets that match.
fn apply_opts(&self, opts: BroadcastOptions) -> Vec<Sid> {
fn apply_opts(&self, opts: BroadcastOptions) -> SmallVec<[Sid; 16]> {
let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
let rooms = opts.rooms;

let except = self.get_except_sids(&opts.except);
let is_socket_current = |id| opts.sid.map(|s| s != id).unwrap_or(true);
if !rooms.is_empty() {
let rooms_map = self.rooms.read().unwrap();
rooms
.iter()
.filter_map(|room| rooms_map.get(room))
.flatten()
.copied()
.filter(|id| {
!except.contains(id)
&& (!is_broadcast || opts.sid.map(|s| &s != id).unwrap_or(true))
})
.filter(|id| !except.contains(id) && (!is_broadcast || is_socket_current(*id)))
.collect()
} else if is_broadcast {
self.sockets
.get_all_sids()
.into_iter()
.filter(|id| !except.contains(id) && opts.sid.map(|s| &s != id).unwrap_or(true))
.collect()
.get_all_sids(|id| !except.contains(id) && is_socket_current(*id))
.into()
} else if let Some(id) = opts.sid {
vec![id]
smallvec::smallvec![id]
} else {
vec![]
smallvec::smallvec![]
}
}

Expand Down Expand Up @@ -546,20 +542,15 @@
self.sockets.iter().copied().collect()
}

fn send_many(&self, _: Vec<Sid>, _: Value) -> Result<(), Vec<SocketError>> {
fn send_many(&self, _: &[Sid], _: Value) -> Result<(), Vec<SocketError>> {
Ok(())
}

fn send_many_with_ack(
&self,
_: Vec<Sid>,
_: Packet,
_: Option<Duration>,
) -> Self::AckStream {
fn send_many_with_ack(&self, _: &[Sid], _: Packet, _: Option<Duration>) -> Self::AckStream {
StubAckStream
}

fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
fn disconnect_many(&self, _: &[Sid]) -> Result<(), Vec<SocketError>> {
Ok(())
}

Expand Down
29 changes: 12 additions & 17 deletions crates/socketioxide/src/ack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,13 @@ impl AckInnerStream {
///
/// The [`AckInnerStream`] will wait for the default timeout specified in the config
/// (5s by default) if no custom timeout is specified.
pub fn broadcast<A: Adapter>(
pub fn broadcast<'a, A: Adapter>(
packet: Packet,
sockets: Vec<Arc<Socket<A>>>,
duration: Option<Duration>,
sockets: impl Iterator<Item = &'a Arc<Socket<A>>>,
duration: Duration,
) -> Self {
let rxs = FuturesUnordered::new();

if sockets.is_empty() {
return AckInnerStream::Stream { rxs };
}

let duration =
duration.unwrap_or_else(|| sockets.first().unwrap().get_io().config().ack_timeout);
for socket in sockets {
let rx = socket.send_with_ack(packet.clone());
rxs.push(AckResultWithId {
Expand Down Expand Up @@ -312,16 +306,17 @@ mod test {
Self::new(val, Parser::default())
}
}
const TIMEOUT: Duration = Duration::from_secs(5);

#[tokio::test]
async fn broadcast_ack() {
let socket = create_socket();
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, None).into();
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();

let res_packet = Packet::ack("test", value("test"), 1);
socket.recv(res_packet.inner.clone()).unwrap();
Expand Down Expand Up @@ -365,9 +360,9 @@ mod test {
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, None).into();
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();

let res_packet = Packet::ack("test", value(132), 1);
socket.recv(res_packet.inner.clone()).unwrap();
Expand Down Expand Up @@ -422,9 +417,9 @@ mod test {
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, None).into();
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();

let res_packet = Packet::ack("test", value("test"), 1);
socket.clone().recv(res_packet.inner.clone()).unwrap();
Expand Down Expand Up @@ -478,9 +473,9 @@ mod test {
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, Some(Duration::from_millis(10))).into();
AckInnerStream::broadcast(packet, socks.into_iter(), Duration::from_millis(10)).into();

socket
.recv(Packet::ack("test", value("test"), 1).inner)
Expand Down
16 changes: 3 additions & 13 deletions crates/socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<A: Adapter> Client<A> {
// We have to create a new `Str` otherwise, we would keep a ref to the original connect packet
// for the entire lifetime of the Namespace.
let path = Str::copy_from_slice(&ns_path);
let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, self.config.parser);
let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config);
let this = self.clone();
let esocket = esocket.clone();
tokio::spawn(async move {
Expand Down Expand Up @@ -157,12 +157,7 @@ impl<A: Adapter> Client<A> {
tracing::debug!("adding namespace {}", path);

let ns_path = Str::from(&path);
let ns = Namespace::new(
ns_path.clone(),
callback,
&self.adapter_state,
self.config.parser,
);
let ns = Namespace::new(ns_path.clone(), callback, &self.adapter_state, &self.config);
// We spawn the adapter init task and therefore it might fail but the namespace is still added.
// The best solution would be to make the fn async and returning the error to the user.
// However this would require all .ns() calls to be async.
Expand Down Expand Up @@ -472,12 +467,7 @@ mod test {
#[tokio::test]
async fn get_ns() {
let client = create_client();
let ns = Namespace::new(
Str::from("/"),
|| {},
&client.adapter_state,
client.config.parser,
);
let ns = Namespace::new(Str::from("/"), || {}, &client.adapter_state, &client.config);
client.nsps.write().unwrap().insert(Str::from("/"), ns);
assert!(client.get_ns("/").is_some());
}
Expand Down
Loading
Loading