Skip to content

Commit 162654f

Browse files
committed
Flink: Clear globalStatisticsState in init to avoid duplication
1 parent a473b1c commit 162654f

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

flink/v2.1/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ public void initializeState(StateInitializationContext context) throws Exception
122122
this.globalStatistics = restoredStatistics;
123123
}
124124

125+
// Perform a cleanup first to ensure that the state is empty.
126+
globalStatisticsState.clear();
125127
// Always request for new statistics from coordinator upon task initialization.
126128
// There are a few scenarios this is needed
127129
// 1. downstream writer parallelism changed due to rescale.
@@ -266,4 +268,9 @@ DataStatistics localStatistics() {
266268
GlobalStatistics globalStatistics() {
267269
return globalStatistics;
268270
}
271+
272+
/**
 * Returns the {@code globalStatisticsState} list state held by this operator.
 *
 * <p>Exposed only so tests can inspect the contents of the state directly.
 */
@VisibleForTesting
273+
ListState<GlobalStatistics> globalStatisticsState() {
274+
return globalStatisticsState;
275+
}
269276
}

flink/v2.1/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,62 @@ public void testRestoreState(StatisticsType type, int parallelismAdjustment) thr
261261
}
262262
}
263263

264+
/**
 * Verifies that initializing a new operator from a snapshot leaves its {@code
 * globalStatisticsState} list state empty.
 *
 * <p>Flow: the first operator receives a global statistics event, the test asserts that {@code
 * operator.globalStatistics()} reflects it, and a snapshot is taken. A second operator is then
 * created with {@code Fixtures.NUM_SUBTASKS + parallelismAdjustment}, initialized from that
 * snapshot, and its {@code globalStatisticsState()} is asserted to be empty.
 *
 * <p>NOTE(review): the gateway spy passed to the restored operator is never verified in this
 * test.
 */
@ParameterizedTest
265+
@MethodSource("provideRestoreStateParameters")
266+
public void testRestoreStateClearGlobalStatistic(StatisticsType type, int parallelismAdjustment)
267+
throws Exception {
268+
Map<SortKey, Long> keyFrequency =
269+
ImmutableMap.of(CHAR_KEYS.get("a"), 2L, CHAR_KEYS.get("b"), 1L, CHAR_KEYS.get("c"), 1L);
270+
SortKey[] rangeBounds = new SortKey[] {CHAR_KEYS.get("a")};
271+
MapAssignment mapAssignment =
272+
MapAssignment.fromKeyFrequency(2, keyFrequency, 0.0d, SORT_ORDER_COMPARTOR);
273+
DataStatisticsOperator operator = createOperator(type, Fixtures.NUM_SUBTASKS);
274+
OperatorSubtaskState snapshot;
275+
try (OneInputStreamOperatorTestHarness<RowData, StatisticsOrRecord> testHarness1 =
276+
createHarness(operator)) {
277+
GlobalStatistics statistics;
278+
if (StatisticsUtil.collectType(type) == StatisticsType.Map) {
279+
statistics = GlobalStatistics.fromMapAssignment(1L, mapAssignment);
280+
} else {
281+
statistics = GlobalStatistics.fromRangeBounds(1L, rangeBounds);
282+
}
283+
284+
StatisticsEvent event =
285+
StatisticsEvent.createGlobalStatisticsEvent(
286+
statistics, Fixtures.GLOBAL_STATISTICS_SERIALIZER, false);
287+
operator.handleOperatorEvent(event);
288+
289+
GlobalStatistics globalStatistics = operator.globalStatistics();
290+
assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type));
291+
if (StatisticsUtil.collectType(type) == StatisticsType.Map) {
292+
assertThat(globalStatistics.mapAssignment()).isEqualTo(mapAssignment);
293+
assertThat(globalStatistics.rangeBounds()).isNull();
294+
} else {
295+
assertThat(globalStatistics.mapAssignment()).isNull();
296+
assertThat(globalStatistics.rangeBounds()).isEqualTo(rangeBounds);
297+
}
298+
299+
snapshot = testHarness1.snapshot(1L, 0);
300+
}
301+
302+
// Use the snapshot to initialize state for a new operator and then verify that the restored
303+
// operator's global statistics list state is empty after initialization
304+
MockOperatorEventGateway spyGateway = Mockito.spy(new MockOperatorEventGateway());
305+
DataStatisticsOperator restoredOperator =
306+
createOperator(type, Fixtures.NUM_SUBTASKS + parallelismAdjustment, spyGateway);
307+
try (OneInputStreamOperatorTestHarness<RowData, StatisticsOrRecord> testHarness2 =
308+
new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) {
309+
testHarness2.setup();
310+
testHarness2.initializeState(snapshot);
311+
312+
// When we restore from the savepoint, we should ensure that `globalStatisticsState` has been
313+
// completely cleaned up
314+
Iterable<GlobalStatistics> globalStatisticsIterable =
315+
restoredOperator.globalStatisticsState().get();
316+
assertThat(globalStatisticsIterable).isEmpty();
317+
}
318+
}
319+
264320
@SuppressWarnings("unchecked")
265321
@Test
266322
public void testMigrationWithLocalStatsOverThreshold() throws Exception {

0 commit comments

Comments
 (0)