Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 408959c

Browse files
committed
Merge pull request #5788 from matrix-org/rav/metaredactions
2 parents c24b899 + fb86217 commit 408959c

File tree

3 files changed

+183
-101
lines changed

3 files changed

+183
-101
lines changed

changelog.d/5788.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Correctly handle redactions of redactions.

synapse/storage/events_worker.py

Lines changed: 112 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,7 @@
2929
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
3030
from synapse.events.snapshot import EventContext # noqa: F401
3131
from synapse.events.utils import prune_event
32-
from synapse.logging.context import (
33-
LoggingContext,
34-
PreserveLoggingContext,
35-
make_deferred_yieldable,
36-
run_in_background,
37-
)
32+
from synapse.logging.context import LoggingContext, PreserveLoggingContext
3833
from synapse.metrics.background_process_metrics import run_as_background_process
3934
from synapse.types import get_domain_from_id
4035
from synapse.util import batch_iter
@@ -342,13 +337,12 @@ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
342337
log_ctx = LoggingContext.current_context()
343338
log_ctx.record_event_fetch(len(missing_events_ids))
344339

345-
# Note that _enqueue_events is also responsible for turning db rows
340+
# Note that _get_events_from_db is also responsible for turning db rows
346341
# into FrozenEvents (via _get_event_from_row), which involves seeing if
347342
# the events have been redacted, and if so pulling the redaction event out
348343
# of the database to check it.
349344
#
350-
# _enqueue_events is a bit of a rubbish name but naming is hard.
351-
missing_events = yield self._enqueue_events(
345+
missing_events = yield self._get_events_from_db(
352346
missing_events_ids, allow_rejected=allow_rejected
353347
)
354348

@@ -421,28 +415,28 @@ def _fetch_event_list(self, conn, event_list):
421415
The fetch requests. Each entry consists of a list of event
422416
ids to be fetched, and a deferred to be completed once the
423417
events have been fetched.
418+
419+
The deferreds are callbacked with a dictionary mapping from event id
420+
to event row. Note that it may well contain additional events that
421+
were not part of this request.
424422
"""
425423
with Measure(self._clock, "_fetch_event_list"):
426424
try:
427-
event_id_lists = list(zip(*event_list))[0]
428-
event_ids = [item for sublist in event_id_lists for item in sublist]
425+
events_to_fetch = set(
426+
event_id for events, _ in event_list for event_id in events
427+
)
429428

430429
row_dict = self._new_transaction(
431-
conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
430+
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
432431
)
433432

434433
# We only want to resolve deferreds from the main thread
435-
def fire(lst, res):
436-
for ids, d in lst:
437-
if not d.called:
438-
try:
439-
with PreserveLoggingContext():
440-
d.callback([res[i] for i in ids if i in res])
441-
except Exception:
442-
logger.exception("Failed to callback")
434+
def fire():
435+
for _, d in event_list:
436+
d.callback(row_dict)
443437

444438
with PreserveLoggingContext():
445-
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
439+
self.hs.get_reactor().callFromThread(fire)
446440
except Exception as e:
447441
logger.exception("do_fetch")
448442

@@ -457,13 +451,98 @@ def fire(evs, exc):
457451
self.hs.get_reactor().callFromThread(fire, event_list, e)
458452

459453
@defer.inlineCallbacks
460-
def _enqueue_events(self, events, allow_rejected=False):
454+
def _get_events_from_db(self, event_ids, allow_rejected=False):
455+
"""Fetch a bunch of events from the database.
456+
457+
Returned events will be added to the cache for future lookups.
458+
459+
Args:
460+
event_ids (Iterable[str]): The event_ids of the events to fetch
461+
allow_rejected (bool): Whether to include rejected events
462+
463+
Returns:
464+
Deferred[Dict[str, _EventCacheEntry]]:
465+
map from event id to result. May return extra events which
466+
weren't asked for.
467+
"""
468+
fetched_events = {}
469+
events_to_fetch = event_ids
470+
471+
while events_to_fetch:
472+
row_map = yield self._enqueue_events(events_to_fetch)
473+
474+
# we need to recursively fetch any redactions of those events
475+
redaction_ids = set()
476+
for event_id in events_to_fetch:
477+
row = row_map.get(event_id)
478+
fetched_events[event_id] = row
479+
if row:
480+
redaction_ids.update(row["redactions"])
481+
482+
events_to_fetch = redaction_ids.difference(fetched_events.keys())
483+
if events_to_fetch:
484+
logger.debug("Also fetching redaction events %s", events_to_fetch)
485+
486+
# build a map from event_id to EventBase
487+
event_map = {}
488+
for event_id, row in fetched_events.items():
489+
if not row:
490+
continue
491+
assert row["event_id"] == event_id
492+
493+
rejected_reason = row["rejected_reason"]
494+
495+
if not allow_rejected and rejected_reason:
496+
continue
497+
498+
d = json.loads(row["json"])
499+
internal_metadata = json.loads(row["internal_metadata"])
500+
501+
format_version = row["format_version"]
502+
if format_version is None:
503+
# This means that we stored the event before we had the concept
504+
# of an event format version, so it must be a V1 event.
505+
format_version = EventFormatVersions.V1
506+
507+
original_ev = event_type_from_format_version(format_version)(
508+
event_dict=d,
509+
internal_metadata_dict=internal_metadata,
510+
rejected_reason=rejected_reason,
511+
)
512+
513+
event_map[event_id] = original_ev
514+
515+
# finally, we can decide whether each one needs redacting, and build
516+
# the cache entries.
517+
result_map = {}
518+
for event_id, original_ev in event_map.items():
519+
redactions = fetched_events[event_id]["redactions"]
520+
redacted_event = self._maybe_redact_event_row(
521+
original_ev, redactions, event_map
522+
)
523+
524+
cache_entry = _EventCacheEntry(
525+
event=original_ev, redacted_event=redacted_event
526+
)
527+
528+
self._get_event_cache.prefill((event_id,), cache_entry)
529+
result_map[event_id] = cache_entry
530+
531+
return result_map
532+
533+
@defer.inlineCallbacks
534+
def _enqueue_events(self, events):
461535
"""Fetches events from the database using the _event_fetch_list. This
462536
allows batch and bulk fetching of events - it allows us to fetch events
463537
without having to create a new transaction for each request for events.
538+
539+
Args:
540+
events (Iterable[str]): events to be fetched.
541+
542+
Returns:
543+
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
544+
May contain events that weren't requested.
464545
"""
465-
if not events:
466-
return {}
467546

468547
events_d = defer.Deferred()
469548
with self._event_fetch_lock:
@@ -482,32 +561,12 @@ def _enqueue_events(self, events, allow_rejected=False):
482561
"fetch_events", self.runWithConnection, self._do_fetch
483562
)
484563

485-
logger.debug("Loading %d events", len(events))
564+
logger.debug("Loading %d events: %s", len(events), events)
486565
with PreserveLoggingContext():
487-
rows = yield events_d
488-
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
489-
490-
if not allow_rejected:
491-
rows[:] = [r for r in rows if r["rejected_reason"] is None]
492-
493-
res = yield make_deferred_yieldable(
494-
defer.gatherResults(
495-
[
496-
run_in_background(
497-
self._get_event_from_row,
498-
row["internal_metadata"],
499-
row["json"],
500-
row["redactions"],
501-
rejected_reason=row["rejected_reason"],
502-
format_version=row["format_version"],
503-
)
504-
for row in rows
505-
],
506-
consumeErrors=True,
507-
)
508-
)
566+
row_map = yield events_d
567+
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
509568

510-
return {e.event.event_id: e for e in res if e}
569+
return row_map
511570

512571
def _fetch_event_rows(self, txn, event_ids):
513572
"""Fetch event rows from the database
@@ -580,57 +639,16 @@ def _fetch_event_rows(self, txn, event_ids):
580639

581640
return event_dict
582641

583-
@defer.inlineCallbacks
584-
def _get_event_from_row(
585-
self, internal_metadata, js, redactions, format_version, rejected_reason=None
586-
):
587-
"""Parse an event row which has been read from the database
588-
589-
Args:
590-
internal_metadata (str): json-encoded internal_metadata column
591-
js (str): json-encoded event body from event_json
592-
redactions (list[str]): a list of the events which claim to have redacted
593-
this event, from the redactions table
594-
format_version: (str): the 'format_version' column
595-
rejected_reason (str|None): the reason this event was rejected, if any
596-
597-
Returns:
598-
_EventCacheEntry
599-
"""
600-
with Measure(self._clock, "_get_event_from_row"):
601-
d = json.loads(js)
602-
internal_metadata = json.loads(internal_metadata)
603-
604-
if format_version is None:
605-
# This means that we stored the event before we had the concept
606-
# of an event format version, so it must be a V1 event.
607-
format_version = EventFormatVersions.V1
608-
609-
original_ev = event_type_from_format_version(format_version)(
610-
event_dict=d,
611-
internal_metadata_dict=internal_metadata,
612-
rejected_reason=rejected_reason,
613-
)
614-
615-
redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
616-
617-
cache_entry = _EventCacheEntry(
618-
event=original_ev, redacted_event=redacted_event
619-
)
620-
621-
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
622-
623-
return cache_entry
624-
625-
@defer.inlineCallbacks
626-
def _maybe_redact_event_row(self, original_ev, redactions):
642+
def _maybe_redact_event_row(self, original_ev, redactions, event_map):
627643
"""Given an event object and a list of possible redacting event ids,
628644
determine whether to honour any of those redactions and if so return a redacted
629645
event.
630646
631647
Args:
632648
original_ev (EventBase):
633649
redactions (iterable[str]): list of event ids of potential redaction events
650+
event_map (dict[str, EventBase]): other events which have been fetched, in
651+
which we can look up the redaction events. Map from event id to event.
634652
635653
Returns:
636654
Deferred[EventBase|None]: if the event should be redacted, a pruned
@@ -640,15 +658,9 @@ def _maybe_redact_event_row(self, original_ev, redactions):
640658
# we choose to ignore redactions of m.room.create events.
641659
return None
642660

643-
if original_ev.type == "m.room.redaction":
644-
# ... and redaction events
645-
return None
646-
647-
redaction_map = yield self._get_events_from_cache_or_db(redactions)
648-
649661
for redaction_id in redactions:
650-
redaction_entry = redaction_map.get(redaction_id)
651-
if not redaction_entry:
662+
redaction_event = event_map.get(redaction_id)
663+
if not redaction_event or redaction_event.rejected_reason:
652664
# we don't have the redaction event, or the redaction event was not
653665
# authorized.
654666
logger.debug(
@@ -658,7 +670,6 @@ def _maybe_redact_event_row(self, original_ev, redactions):
658670
)
659671
continue
660672

661-
redaction_event = redaction_entry.event
662673
if redaction_event.room_id != original_ev.room_id:
663674
logger.debug(
664675
"%s was redacted by %s but redaction was in a different room!",

tests/storage/test_redaction.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from mock import Mock
1919

20+
from twisted.internet import defer
21+
2022
from synapse.api.constants import EventTypes, Membership
2123
from synapse.api.room_versions import RoomVersions
2224
from synapse.types import RoomID, UserID
@@ -216,3 +218,71 @@ def test_redact_join(self):
216218
},
217219
event.unsigned["redacted_because"],
218220
)
221+
222+
def test_circular_redaction(self):
223+
redaction_event_id1 = "$redaction1_id:test"
224+
redaction_event_id2 = "$redaction2_id:test"
225+
226+
class EventIdManglingBuilder:
227+
def __init__(self, base_builder, event_id):
228+
self._base_builder = base_builder
229+
self._event_id = event_id
230+
231+
@defer.inlineCallbacks
232+
def build(self, prev_event_ids):
233+
built_event = yield self._base_builder.build(prev_event_ids)
234+
built_event.event_id = self._event_id
235+
built_event._event_dict["event_id"] = self._event_id
236+
return built_event
237+
238+
@property
239+
def room_id(self):
240+
return self._base_builder.room_id
241+
242+
event_1, context_1 = self.get_success(
243+
self.event_creation_handler.create_new_client_event(
244+
EventIdManglingBuilder(
245+
self.event_builder_factory.for_room_version(
246+
RoomVersions.V1,
247+
{
248+
"type": EventTypes.Redaction,
249+
"sender": self.u_alice.to_string(),
250+
"room_id": self.room1.to_string(),
251+
"content": {"reason": "test"},
252+
"redacts": redaction_event_id2,
253+
},
254+
),
255+
redaction_event_id1,
256+
)
257+
)
258+
)
259+
260+
self.get_success(self.store.persist_event(event_1, context_1))
261+
262+
event_2, context_2 = self.get_success(
263+
self.event_creation_handler.create_new_client_event(
264+
EventIdManglingBuilder(
265+
self.event_builder_factory.for_room_version(
266+
RoomVersions.V1,
267+
{
268+
"type": EventTypes.Redaction,
269+
"sender": self.u_alice.to_string(),
270+
"room_id": self.room1.to_string(),
271+
"content": {"reason": "test"},
272+
"redacts": redaction_event_id1,
273+
},
274+
),
275+
redaction_event_id2,
276+
)
277+
)
278+
)
279+
self.get_success(self.store.persist_event(event_2, context_2))
280+
281+
# fetch one of the redactions
282+
fetched = self.get_success(self.store.get_event(redaction_event_id1))
283+
284+
# it should have been redacted
285+
self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2)
286+
self.assertEqual(
287+
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
288+
)

0 commit comments

Comments
 (0)