@@ -261,6 +261,62 @@ public void testRestoreState(StatisticsType type, int parallelismAdjustment) thr
261261 }
262262 }
263263
264+ @ ParameterizedTest
265+ @ MethodSource ("provideRestoreStateParameters" )
266+ public void testRestoreStateClearGlobalStatistic (StatisticsType type , int parallelismAdjustment )
267+ throws Exception {
268+ Map <SortKey , Long > keyFrequency =
269+ ImmutableMap .of (CHAR_KEYS .get ("a" ), 2L , CHAR_KEYS .get ("b" ), 1L , CHAR_KEYS .get ("c" ), 1L );
270+ SortKey [] rangeBounds = new SortKey [] {CHAR_KEYS .get ("a" )};
271+ MapAssignment mapAssignment =
272+ MapAssignment .fromKeyFrequency (2 , keyFrequency , 0.0d , SORT_ORDER_COMPARTOR );
273+ DataStatisticsOperator operator = createOperator (type , Fixtures .NUM_SUBTASKS );
274+ OperatorSubtaskState snapshot ;
275+ try (OneInputStreamOperatorTestHarness <RowData , StatisticsOrRecord > testHarness1 =
276+ createHarness (operator )) {
277+ GlobalStatistics statistics ;
278+ if (StatisticsUtil .collectType (type ) == StatisticsType .Map ) {
279+ statistics = GlobalStatistics .fromMapAssignment (1L , mapAssignment );
280+ } else {
281+ statistics = GlobalStatistics .fromRangeBounds (1L , rangeBounds );
282+ }
283+
284+ StatisticsEvent event =
285+ StatisticsEvent .createGlobalStatisticsEvent (
286+ statistics , Fixtures .GLOBAL_STATISTICS_SERIALIZER , false );
287+ operator .handleOperatorEvent (event );
288+
289+ GlobalStatistics globalStatistics = operator .globalStatistics ();
290+ assertThat (globalStatistics .type ()).isEqualTo (StatisticsUtil .collectType (type ));
291+ if (StatisticsUtil .collectType (type ) == StatisticsType .Map ) {
292+ assertThat (globalStatistics .mapAssignment ()).isEqualTo (mapAssignment );
293+ assertThat (globalStatistics .rangeBounds ()).isNull ();
294+ } else {
295+ assertThat (globalStatistics .mapAssignment ()).isNull ();
296+ assertThat (globalStatistics .rangeBounds ()).isEqualTo (rangeBounds );
297+ }
298+
299+ snapshot = testHarness1 .snapshot (1L , 0 );
300+ }
301+
302+ // Use the snapshot to initialize state for another new operator and then verify that the global
303+ // statistics for the new operator is same as before
304+ MockOperatorEventGateway spyGateway = Mockito .spy (new MockOperatorEventGateway ());
305+ DataStatisticsOperator restoredOperator =
306+ createOperator (type , Fixtures .NUM_SUBTASKS + parallelismAdjustment , spyGateway );
307+ try (OneInputStreamOperatorTestHarness <RowData , StatisticsOrRecord > testHarness2 =
308+ new OneInputStreamOperatorTestHarness <>(restoredOperator , 2 , 2 , 1 )) {
309+ testHarness2 .setup ();
310+ testHarness2 .initializeState (snapshot );
311+
312+ // When we restore from the savepoint, we should ensure that `globalStatisticsState` has been
313+ // completely cleaned up
314+ Iterable <GlobalStatistics > globalStatisticsIterable =
315+ restoredOperator .globalStatisticsState ().get ();
316+ assertThat (globalStatisticsIterable ).isEmpty ();
317+ }
318+ }
319+
264320 @ SuppressWarnings ("unchecked" )
265321 @ Test
266322 public void testMigrationWithLocalStatsOverThreshold () throws Exception {
0 commit comments