Skip to content

Commit 5ec6f94

Browse files
committed
feat(socketio/ns): improve emitter trait
1 parent f44d031 commit 5ec6f94

File tree

4 files changed

+82
-95
lines changed

4 files changed

+82
-95
lines changed

crates/socketioxide-core/src/adapter.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,16 @@ pub trait SocketEmitter: Send + Sync + 'static {
171171
/// Get all the socket ids in the namespace.
172172
fn get_all_sids(&self) -> Vec<Sid>;
173173
/// Send data to the list of socket ids.
174-
fn send_many(&self, sids: Vec<Sid>, data: Value) -> Result<(), Vec<SocketError>>;
174+
fn send_many(&self, sids: &[Sid], data: Value) -> Result<(), Vec<SocketError>>;
175175
/// Send data to the list of socket ids and get a stream of acks.
176176
fn send_many_with_ack(
177177
&self,
178-
sids: Vec<Sid>,
178+
sids: &[Sid],
179179
packet: Packet,
180180
timeout: Option<Duration>,
181181
) -> Self::AckStream;
182182
/// Disconnect all the sockets in the list.
183-
fn disconnect_many(&self, sid: Vec<Sid>) -> Result<(), Vec<SocketError>>;
183+
fn disconnect_many(&self, sid: &[Sid]) -> Result<(), Vec<SocketError>>;
184184
/// Get the path of the namespace.
185185
fn path(&self) -> &Str;
186186
/// Get the parser of the namespace.
@@ -373,7 +373,7 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
373373
}
374374

375375
let data = self.sockets.parser().encode(packet);
376-
self.sockets.send_many(sids, data)
376+
self.sockets.send_many(&sids, data)
377377
}
378378

379379
/// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
@@ -390,13 +390,13 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
390390

391391
let count = sids.len() as u32;
392392
// We cannot pre-serialize the packet because we need to change the ack id.
393-
let stream = self.sockets.send_many_with_ack(sids, packet, timeout);
393+
let stream = self.sockets.send_many_with_ack(&sids, packet, timeout);
394394
(stream, count)
395395
}
396396

397397
/// Returns the sockets ids that match the [`BroadcastOptions`].
398398
pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
399-
self.apply_opts(opts)
399+
self.apply_opts(opts).into_vec()
400400
}
401401

402402
//TODO: make this operation O(1)
@@ -429,7 +429,7 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
429429
/// Disconnects the sockets that match the [`BroadcastOptions`].
430430
pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
431431
let sids = self.apply_opts(opts);
432-
self.sockets.disconnect_many(sids)
432+
self.sockets.disconnect_many(&sids)
433433
}
434434

435435
/// Returns all the rooms for this adapter.
@@ -450,33 +450,31 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
450450

451451
impl<E: SocketEmitter> CoreLocalAdapter<E> {
452452
/// Applies the given `opts` and return the sockets that match.
453-
fn apply_opts(&self, opts: BroadcastOptions) -> Vec<Sid> {
453+
fn apply_opts(&self, opts: BroadcastOptions) -> SmallVec<[Sid; 16]> {
454454
let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
455455
let rooms = opts.rooms;
456456

457457
let except = self.get_except_sids(&opts.except);
458+
let is_socket_current = |id| opts.sid.map(|s| s != id).unwrap_or(true);
458459
if !rooms.is_empty() {
459460
let rooms_map = self.rooms.read().unwrap();
460461
rooms
461462
.iter()
462463
.filter_map(|room| rooms_map.get(room))
463464
.flatten()
464465
.copied()
465-
.filter(|id| {
466-
!except.contains(id)
467-
&& (!is_broadcast || opts.sid.map(|s| &s != id).unwrap_or(true))
468-
})
466+
.filter(|id| !except.contains(id) && (!is_broadcast || is_socket_current(*id)))
469467
.collect()
470468
} else if is_broadcast {
471469
self.sockets
472470
.get_all_sids()
473471
.into_iter()
474-
.filter(|id| !except.contains(id) && opts.sid.map(|s| &s != id).unwrap_or(true))
472+
.filter(|id| !except.contains(id) && is_socket_current(*id))
475473
.collect()
476474
} else if let Some(id) = opts.sid {
477-
vec![id]
475+
smallvec::smallvec![id]
478476
} else {
479-
vec![]
477+
smallvec::smallvec![]
480478
}
481479
}
482480

@@ -546,20 +544,15 @@ mod test {
546544
self.sockets.iter().copied().collect()
547545
}
548546

549-
fn send_many(&self, _: Vec<Sid>, _: Value) -> Result<(), Vec<SocketError>> {
547+
fn send_many(&self, _: &[Sid], _: Value) -> Result<(), Vec<SocketError>> {
550548
Ok(())
551549
}
552550

553-
fn send_many_with_ack(
554-
&self,
555-
_: Vec<Sid>,
556-
_: Packet,
557-
_: Option<Duration>,
558-
) -> Self::AckStream {
551+
fn send_many_with_ack(&self, _: &[Sid], _: Packet, _: Option<Duration>) -> Self::AckStream {
559552
StubAckStream
560553
}
561554

562-
fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
555+
fn disconnect_many(&self, _: &[Sid]) -> Result<(), Vec<SocketError>> {
563556
Ok(())
564557
}
565558

crates/socketioxide/src/ack.rs

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,13 @@ impl AckInnerStream {
144144
///
145145
/// The [`AckInnerStream`] will wait for the default timeout specified in the config
146146
/// (5s by default) if no custom timeout is specified.
147-
pub fn broadcast<A: Adapter>(
147+
pub fn broadcast<'a, A: Adapter>(
148148
packet: Packet,
149-
sockets: Vec<Arc<Socket<A>>>,
150-
duration: Option<Duration>,
149+
sockets: impl Iterator<Item = &'a Arc<Socket<A>>>,
150+
duration: Duration,
151151
) -> Self {
152152
let rxs = FuturesUnordered::new();
153153

154-
if sockets.is_empty() {
155-
return AckInnerStream::Stream { rxs };
156-
}
157-
158-
let duration =
159-
duration.unwrap_or_else(|| sockets.first().unwrap().get_io().config().ack_timeout);
160154
for socket in sockets {
161155
let rx = socket.send_with_ack(packet.clone());
162156
rxs.push(AckResultWithId {
@@ -312,16 +306,17 @@ mod test {
312306
Self::new(val, Parser::default())
313307
}
314308
}
309+
const TIMEOUT: Duration = Duration::from_secs(5);
315310

316311
#[tokio::test]
317312
async fn broadcast_ack() {
318313
let socket = create_socket();
319314
let socket2 = create_socket();
320315
let mut packet = get_packet();
321316
packet.inner.set_ack_id(1);
322-
let socks = vec![socket.clone().into(), socket2.clone().into()];
317+
let socks = vec![&socket, &socket2];
323318
let stream: AckStream<String, LocalAdapter> =
324-
AckInnerStream::broadcast(packet, socks, None).into();
319+
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();
325320

326321
let res_packet = Packet::ack("test", value("test"), 1);
327322
socket.recv(res_packet.inner.clone()).unwrap();
@@ -365,9 +360,9 @@ mod test {
365360
let socket2 = create_socket();
366361
let mut packet = get_packet();
367362
packet.inner.set_ack_id(1);
368-
let socks = vec![socket.clone().into(), socket2.clone().into()];
363+
let socks = vec![&socket, &socket2];
369364
let stream: AckStream<String, LocalAdapter> =
370-
AckInnerStream::broadcast(packet, socks, None).into();
365+
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();
371366

372367
let res_packet = Packet::ack("test", value(132), 1);
373368
socket.recv(res_packet.inner.clone()).unwrap();
@@ -422,9 +417,9 @@ mod test {
422417
let socket2 = create_socket();
423418
let mut packet = get_packet();
424419
packet.inner.set_ack_id(1);
425-
let socks = vec![socket.clone().into(), socket2.clone().into()];
420+
let socks = vec![&socket, &socket2];
426421
let stream: AckStream<String, LocalAdapter> =
427-
AckInnerStream::broadcast(packet, socks, None).into();
422+
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();
428423

429424
let res_packet = Packet::ack("test", value("test"), 1);
430425
socket.clone().recv(res_packet.inner.clone()).unwrap();
@@ -478,9 +473,9 @@ mod test {
478473
let socket2 = create_socket();
479474
let mut packet = get_packet();
480475
packet.inner.set_ack_id(1);
481-
let socks = vec![socket.clone().into(), socket2.clone().into()];
476+
let socks = vec![&socket, &socket2];
482477
let stream: AckStream<String, LocalAdapter> =
483-
AckInnerStream::broadcast(packet, socks, Some(Duration::from_millis(10))).into();
478+
AckInnerStream::broadcast(packet, socks.into_iter(), Duration::from_millis(10)).into();
484479

485480
socket
486481
.recv(Packet::ack("test", value("test"), 1).inner)

crates/socketioxide/src/client.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl<A: Adapter> Client<A> {
8282
// We have to create a new `Str` otherwise, we would keep a ref to the original connect packet
8383
// for the entire lifetime of the Namespace.
8484
let path = Str::copy_from_slice(&ns_path);
85-
let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, self.config.parser);
85+
let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config);
8686
let this = self.clone();
8787
let esocket = esocket.clone();
8888
tokio::spawn(async move {
@@ -157,12 +157,7 @@ impl<A: Adapter> Client<A> {
157157
tracing::debug!("adding namespace {}", path);
158158

159159
let ns_path = Str::from(&path);
160-
let ns = Namespace::new(
161-
ns_path.clone(),
162-
callback,
163-
&self.adapter_state,
164-
self.config.parser,
165-
);
160+
let ns = Namespace::new(ns_path.clone(), callback, &self.adapter_state, &self.config);
166161
// We spawn the adapter init task and therefore it might fail but the namespace is still added.
167162
// The best solution would be to make the fn async and returning the error to the user.
168163
// However this would require all .ns() calls to be async.
@@ -472,12 +467,7 @@ mod test {
472467
#[tokio::test]
473468
async fn get_ns() {
474469
let client = create_client();
475-
let ns = Namespace::new(
476-
Str::from("/"),
477-
|| {},
478-
&client.adapter_state,
479-
client.config.parser,
480-
);
470+
let ns = Namespace::new(Str::from("/"), || {}, &client.adapter_state, &client.config);
481471
client.nsps.write().unwrap().insert(Str::from("/"), ns);
482472
assert!(client.get_ns("/").is_some());
483473
}

0 commit comments

Comments
 (0)