
Commit d53de53

Dylan Wong (dylanwong250) authored and committed
[SPARK-52989][SS] Add explicit close() API to State Store iterators
### What changes were proposed in this pull request?

Add an explicit `close()` API to State Store iterators. This PR changes the `ReadStateStore` trait's `prefixScan` and `iterator` methods to return `StateStoreIterator[UnsafeRowPair]` instead of `Iterator[UnsafeRowPair]`; the new type carries the `close()` method. The `exists()` method of `MapStateImpl` is also changed to close its iterator explicitly once the result is known. Additionally, `close()` calls are added in `TimerStateImpl` and `MapStateImplWithTTL`, in their iterators that consume the state store iterators.

### Why are the changes needed?

These changes expose the `close()` method on state store iterators, allowing users of a `StateStoreIterator` to explicitly close it and release its underlying resources as soon as it is no longer needed. This avoids having to hold on to an iterator until all rows are consumed and `close()` is called internally, or until the task completion/failure listener closes it.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Existing unit tests, new tests for the wrapper `StateStoreIterator` class, and a new test verifying that `close()` closes the underlying RocksDB iterator.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51701 from dylanwong250/SPARK-52989.

Lead-authored-by: Dylan Wong <[email protected]>
Co-authored-by: dylanwong250 <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
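For illustration, here is a minimal sketch of the calling pattern this enables. It is hypothetical caller code, not part of the PR: `store` stands for any `ReadStateStore`, and `prefixKeyRow`/`stateName` are placeholder values.

```scala
// Sketch only: probe for existence, then eagerly release the underlying
// (possibly native) iterator instead of waiting for the task listener.
val iter = store.prefixScan(prefixKeyRow, stateName)
try {
  val exists = iter.nonEmpty // pulls at most one element
  println(s"state exists: $exists")
} finally {
  iter.close() // frees the iterator's resources immediately
}
```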
1 parent a895d55 · commit d53de53

File tree

11 files changed: +202 −36 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala

Lines changed: 4 additions & 1 deletion
@@ -56,7 +56,10 @@ class MapStateImpl[K, V](
 
   /** Whether state exists or not. */
   override def exists(): Boolean = {
-    store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty
+    val iter = store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName)
+    val result = iter.nonEmpty
+    iter.close()
+    result
   }
 
   /** Get the state value if it exists */
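As an aside: because `StateStoreIterator` implements `java.io.Closeable` (see the `StateStore.scala` change below), the same `exists()` logic could also be written with Scala 2.13's `scala.util.Using`, which closes the resource even if the body throws. This is a sketch of an alternative, not code from this PR:

```scala
import scala.util.Using

// Sketch only: Using.resource closes the iterator on all exit paths,
// including when nonEmpty throws.
override def exists(): Boolean = {
  Using.resource(store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName)) { iter =>
    iter.nonEmpty
  }
}
```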

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala

Lines changed: 3 additions & 1 deletion
@@ -199,7 +199,9 @@ class TimerStateImpl(
         }
       }
 
-      override protected def close(): Unit = { }
+      override protected def close(): Unit = {
+        iter.close()
+      }
     }
   }
 }
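`TimerStateImpl` (above) and `MapStateImplWithTTL` (below) follow the same pattern: an outer `NextIterator` that consumes a state store iterator forwards its `close()` to the wrapped iterator. A minimal sketch of that shape, with hypothetical names (`timersFor`, `values`):

```scala
import org.apache.spark.util.NextIterator

// Sketch only: the outer NextIterator forwards close() so that closing it
// also releases the underlying StateStoreIterator.
def timersFor(values: StateStoreIterator[Long]): Iterator[Long] =
  new NextIterator[Long] {
    override protected def getNext(): Long = {
      if (values.hasNext) {
        values.next()
      } else {
        finished = true // signals NextIterator that iteration is done
        -1L             // sentinel; ignored once finished is set
      }
    }

    override protected def close(): Unit = {
      values.close()
    }
  }
```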

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala

Lines changed: 3 additions & 1 deletion
@@ -128,7 +128,9 @@ metrics: Map[String, SQLMetric])
         }
       }
 
-      override protected def close(): Unit = {}
+      override protected def close(): Unit = {
+        unsafeRowPairIterator.close()
+      }
     }
   }
 

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 16 additions & 10 deletions
@@ -82,8 +82,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
   override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = map.get(key)
 
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
-    map.iterator()
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
+    val iter = map.iterator()
+    new StateStoreIterator(iter)
   }
 
   override def abort(): Unit = {}
@@ -94,9 +95,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     s"HDFSReadStateStore[stateStoreId=$stateStoreId_, version=$version]"
   }
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-    Iterator[UnsafeRowPair] = {
-    map.prefixScan(prefixKey)
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
+    val iter = map.prefixScan(prefixKey)
+    new StateStoreIterator(iter)
   }
 
   override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
@@ -214,15 +217,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
    * Get an iterator of all the store data.
    * This can be called only after committing all the updates made in the current thread.
    */
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     assertUseOfDefaultColFamily(colFamilyName)
-    mapToUpdate.iterator()
+    val iter = mapToUpdate.iterator()
+    new StateStoreIterator(iter)
   }
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-    Iterator[UnsafeRowPair] = {
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     assertUseOfDefaultColFamily(colFamilyName)
-    mapToUpdate.prefixScan(prefixKey)
+    val iter = mapToUpdate.prefixScan(prefixKey)
+    new StateStoreIterator(iter)
   }
 
   override def metrics: StateStoreMetrics = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 3 additions & 3 deletions
@@ -964,7 +964,7 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
-  def iterator(): Iterator[ByteArrayPair] = {
+  def iterator(): NextIterator[ByteArrayPair] = {
     updateMemoryUsageIfNeeded()
     val iter = db.newIterator()
     logInfo(log"Getting iterator from version ${MDC(LogKeys.LOADED_VERSION, loadedVersion)}")
@@ -1001,7 +1001,7 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs for the given column family.
    */
-  def iterator(cfName: String): Iterator[ByteArrayPair] = {
+  def iterator(cfName: String): NextIterator[ByteArrayPair] = {
     updateMemoryUsageIfNeeded()
     if (!useColumnFamilies) {
       iterator()
@@ -1051,7 +1051,7 @@ class RocksDB(
 
   def prefixScan(
       prefix: Array[Byte],
-      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ByteArrayPair] = {
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = {
     updateMemoryUsageIfNeeded()
     val iter = db.newIterator()
     val updatedPrefix = if (useColumnFamilies) {
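The RocksDB-level return type is narrowed from `Iterator[ByteArrayPair]` to Spark's `org.apache.spark.util.NextIterator[ByteArrayPair]`, whose `closeIfNeeded()` runs the subclass's `close()` at most once. A condensed paraphrase of that contract (simplified, with an invented class name; the real utility also manages `hasNext`/`next` state):

```scala
// Paraphrased sketch of the close-once contract in
// org.apache.spark.util.NextIterator (not the actual implementation).
abstract class CloseOnceIterator[U] extends Iterator[U] {
  private var closed = false

  /** Subclasses release underlying resources, e.g. a native RocksDB iterator. */
  protected def close(): Unit

  /** Idempotent: both an eager caller and a task-completion listener
   *  may invoke this without double-freeing the resource. */
  def closeIfNeeded(): Unit = {
    if (!closed) {
      closed = true
      close()
    }
  }
}
```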

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 19 additions & 7 deletions
@@ -315,17 +315,18 @@ private[sql] class RocksDBStateStoreProvider
     rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName)
   }
 
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     validateAndTransitionState(UPDATE)
     // Note this verify function only verify on the colFamilyName being valid,
     // we are actually doing prefix when useColumnFamilies,
     // but pass "iterator" to throw correct error message
     verifyColFamilyOperations("iterator", colFamilyName)
     val kvEncoder = keyValueEncoderMap.get(colFamilyName)
     val rowPair = new UnsafeRowPair()
-
     if (useColumnFamilies) {
-      rocksDB.iterator(colFamilyName).map { kv =>
+      val rocksDbIter = rocksDB.iterator(colFamilyName)
+
+      val iter = rocksDbIter.map { kv =>
         rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
           kvEncoder._2.decodeValue(kv.value))
         if (!isValidated && rowPair.value != null && !useColumnFamilies) {
@@ -335,8 +336,12 @@ private[sql] class RocksDBStateStoreProvider
         }
         rowPair
       }
+
+      new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
     } else {
-      rocksDB.iterator().map { kv =>
+      val rocksDbIter = rocksDB.iterator()
+
+      val iter = rocksDbIter.map { kv =>
         rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
           kvEncoder._2.decodeValue(kv.value))
         if (!isValidated && rowPair.value != null && !useColumnFamilies) {
@@ -346,11 +351,14 @@ private[sql] class RocksDBStateStoreProvider
         }
         rowPair
       }
+
+      new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
     }
   }
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-    Iterator[UnsafeRowPair] = {
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     validateAndTransitionState(UPDATE)
     verifyColFamilyOperations("prefixScan", colFamilyName)
 
@@ -360,11 +368,15 @@ private[sql] class RocksDBStateStoreProvider
 
     val rowPair = new UnsafeRowPair()
     val prefix = kvEncoder._1.encodePrefixKey(prefixKey)
-    rocksDB.prefixScan(prefix, colFamilyName).map { kv =>
+
+    val rocksDbIter = rocksDB.prefixScan(prefix, colFamilyName)
+    val iter = rocksDbIter.map { kv =>
       rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
         kvEncoder._2.decodeValue(kv.value))
       rowPair
     }
+
+    new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
   }
 
   var checkpointInfo: Option[StateStoreCheckpointInfo] = None
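In the hunks above, `rocksDbIter.closeIfNeeded` is passed where `StateStoreIterator` expects a `() => Unit`; Scala eta-expands the zero-argument method into a function value, so `StateStoreIterator.close()` ends up invoking `closeIfNeeded()` on the RocksDB-backed iterator. A toy demonstration of the same wiring, with invented names and no RocksDB involved:

```scala
import org.apache.spark.util.NextIterator

// Toy demonstration: close() on the wrapper forwards to closeIfNeeded()
// on the source iterator via eta-expansion.
val source = new NextIterator[Int] {
  private val data = Iterator(1, 2, 3)
  override protected def getNext(): Int = {
    if (data.hasNext) {
      data.next()
    } else {
      finished = true
      0 // ignored once finished is set
    }
  }
  override protected def close(): Unit = println("source closed")
}

val wrapped = new StateStoreIterator(source.map(_ * 10), source.closeIfNeeded)
println(wrapped.next()) // 10
wrapped.close()         // prints "source closed"
wrapped.close()         // no-op: closeIfNeeded guards against double close
```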

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 29 additions & 8 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.io.Closeable
 import java.util.UUID
 import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
@@ -44,6 +45,25 @@ import org.apache.spark.sql.execution.streaming.state.MaintenanceTaskType._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, ThreadUtils, Utils}
 
+/**
+ * Represents an iterator that provides additional functionalities for state store use cases.
+ *
+ * `close()` is useful for freeing underlying iterator resources when the iterator is no longer
+ * needed.
+ *
+ * The caller MUST call `close()` on the iterator if it was not fully consumed, and it is no
+ * longer needed.
+ */
+class StateStoreIterator[A](
+    val iter: Iterator[A],
+    val onClose: () => Unit = () => {}) extends Iterator[A] with Closeable {
+  override def hasNext: Boolean = iter.hasNext
+
+  override def next(): A = iter.next()
+
+  override def close(): Unit = onClose()
+}
+
 sealed trait StateStoreEncoding {
   override def toString: String = this match {
     case StateStoreEncoding.UnsafeRow => "unsaferow"
@@ -117,10 +137,11 @@ trait ReadStateStore {
    */
   def prefixScan(
       prefixKey: UnsafeRow,
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair]
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]
 
   /** Return an iterator containing all the key-value pairs in the StateStore. */
-  def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair]
+  def iterator(
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]
 
   /**
    * Clean up the resource.
@@ -227,8 +248,8 @@ trait StateStore extends ReadStateStore {
    * performed after initialization of the iterator. Callers should perform all updates before
    * calling this method if all updates should be visible in the returned iterator.
    */
-  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[UnsafeRowPair]
+  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair]
 
   /** Current metrics of the state store */
   def metrics: StateStoreMetrics
@@ -260,16 +281,16 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = store.get(key,
     colFamilyName)
 
-  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[UnsafeRowPair] = store.iterator(colFamilyName)
+  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = store.iterator(colFamilyName)
 
   override def abort(): Unit = store.abort()
 
   override def release(): Unit = store.release()
 
   override def prefixScan(prefixKey: UnsafeRow,
-    colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] =
-    store.prefixScan(prefixKey, colFamilyName)
+    colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = store.prefixScan(prefixKey, colFamilyName)
 
   override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
     store.valuesIterator(key, colFamilyName)
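Since `StateStoreIterator` is a thin `Closeable` wrapper, its behavior is easy to exercise in isolation. A self-contained illustration (not taken from the PR's test suite):

```scala
// The wrapper delegates iteration and runs the supplied callback on close().
var closed = false
val iter = new StateStoreIterator(Iterator(1, 2, 3), () => closed = true)

assert(iter.hasNext)
assert(iter.next() == 1)
iter.close() // runs the onClose callback
assert(closed)
```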

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala

Lines changed: 7 additions & 3 deletions
@@ -26,8 +26,10 @@ class MemoryStateStore extends StateStore() {
   import scala.jdk.CollectionConverters._
   private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
-    map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
+    val iter =
+      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
+    new StateStoreIterator(iter)
   }
 
   override def createColFamilyIfAbsent(
@@ -66,7 +68,9 @@ class MemoryStateStore extends StateStore() {
 
   override def hasCommitted: Boolean = true
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): Iterator[UnsafeRowPair] = {
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     throw new UnsupportedOperationException("Doesn't support prefix scan!")
   }
 

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala

Lines changed: 4 additions & 2 deletions
@@ -77,12 +77,14 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
 
   override def prefixScan(
       prefixKey: UnsafeRow,
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = {
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+      : StateStoreIterator[UnsafeRowPair] = {
     innerStore.prefixScan(prefixKey, colFamilyName)
   }
 
   override def iterator(
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = {
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+      : StateStoreIterator[UnsafeRowPair] = {
     innerStore.iterator(colFamilyName)
   }
 

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala

Lines changed: 74 additions & 0 deletions
@@ -1650,6 +1650,80 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  testWithColumnFamiliesAndEncodingTypes(
+    "closing the iterator also closes the underlying rocksdb iterator",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
+
+    // use the same schema as value schema for single col key schema
+    tryWithProviderResource(newStoreProvider(valueSchema,
+      RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) { provider =>
+      val store = provider.getStore(0)
+      try {
+        val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
+        if (colFamiliesEnabled) {
+          store.createColFamilyIfAbsent(cfName,
+            valueSchema, valueSchema,
+            RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)))
+        }
+
+        val timerTimestamps = Seq(1, 2, 3, 22)
+        timerTimestamps.foreach { ts =>
+          val keyRow = dataToValueRow(ts)
+          val valueRow = dataToValueRow(1)
+          store.put(keyRow, valueRow, cfName)
+          assert(valueRowToData(store.get(keyRow, cfName)) === 1)
+        }
+
+        val iter1 = store.iterator(cfName)
+        for (i <- 1 to 4) {
+          assert(iter1.hasNext)
+          iter1.next()
+        }
+        // We were fully able to process the 4 elements
+        assert(!iter1.hasNext)
+
+        val iter2 = store.iterator(cfName)
+        for (i <- 1 to 2) {
+          assert(iter2.hasNext)
+          iter2.next()
+        }
+        // Close the iterator
+        iter2.close()
+        // After closing, this will call AbstractRocksIterator.isValid, which should throw an
+        // exception since it no longer owns the underlying rocksdb iterator
+        val exception1 = intercept[AssertionError] {
+          iter2.next()
+        }
+        // Check that the exception is thrown from AbstractRocksIterator.isValid
+        assert(exception1.getStackTrace()(0).getClassName.contains("AbstractRocksIterator"))
+        assert(exception1.getStackTrace()(0).getMethodName.contains("isValid"))
+
+        // also check for prefix scan
+        val prefix = dataToValueRow(2)
+        val iter3 = store.prefixScan(prefix, cfName)
+
+        iter3.next()
+        assert(!iter3.hasNext)
+
+        val iter4 = store.prefixScan(prefix, cfName)
+        // Immediately close the iterator without calling next
+        iter4.close()
+
+        // Since we closed the iterator, this will throw an exception when we try to call next
+        val exception2 = intercept[AssertionError] {
+          iter4.next()
+        }
+        // Check that the exception is thrown from AbstractRocksIterator.isValid
+        assert(exception2.getStackTrace()(0).getClassName.contains("AbstractRocksIterator"))
+        assert(exception2.getStackTrace()(0).getMethodName.contains("isValid"))
+
+        store.commit()
+      } finally {
+        if (!store.hasCommitted) store.abort()
+      }
+    }
+  }
+
   test("validate rocksdb values iterator correctness") {
     withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
       tryWithProviderResource(newStoreProvider(useColumnFamilies = true,
