
Commit f69bb23

fix consolidated locks

1 parent: d8dde2e

3 files changed: +6, -10 lines

test/test_rb.py (4 additions, 0 deletions)

@@ -1791,6 +1791,10 @@ def test_batch_errors():
 
 
 def test_add_warning():
+    from torchrl._utils import RL_WARNINGS
+
+    if not RL_WARNINGS:
+        return
     rb = ReplayBuffer(storage=ListStorage(10), batch_size=3)
     with pytest.warns(
         UserWarning,
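A minimal sketch of the guard pattern this change adds, assuming RL_WARNINGS is a module-level boolean flag in torchrl._utils; the standalone RL_WARNINGS, emit_warning_if_enabled, and test below are illustrative stand-ins, not torchrl code. When warnings are globally disabled, the test returns early instead of failing inside pytest.warns on a warning that will never be emitted.

import os
import warnings

import pytest

# Stand-in for torchrl._utils.RL_WARNINGS; the env-var lookup is an assumption.
RL_WARNINGS = os.environ.get("RL_WARNINGS", "1") == "1"


def emit_warning_if_enabled():
    # Stand-in for library code that only warns when RL_WARNINGS is enabled.
    if RL_WARNINGS:
        warnings.warn("sampling without a sampler", UserWarning)


def test_add_warning_sketch():
    if not RL_WARNINGS:
        # No warning will be emitted, so pytest.warns would fail spuriously.
        return
    with pytest.warns(UserWarning):
        emit_warning_if_enabled()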

torchrl/data/llm/history.py (0 additions, 9 deletions)

@@ -1219,15 +1219,6 @@ def append(
                 f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}."
             )
         dim = _maybe_correct_neg_dim(dim, self.batch_size)
-        # if self.ndim > 1 and dim >= self.ndim - 1:
-        #     # then we need to append each element independently
-        #     result = []
-        #     for hist, new_hist in zip(self.unbind(0), history.unbind(0)):
-        #         hist_c = hist.append(new_hist, inplace=inplace, dim=dim - 1)
-        #         result.append(hist_c)
-        #     if inplace:
-        #         return self
-        #     return lazy_stack(result)
         if inplace:
             if (
                 isinstance(self._tensordict, LazyStackedTensorDict)

torchrl/envs/batched_envs.py (2 additions, 1 deletion)

@@ -1843,7 +1843,7 @@ def _step_no_buffers(
         if self.consolidate:
             try:
                 data = tensordict.consolidate(
-                    share_memory=True, inplace=True, num_threads=1
+                    share_memory=True, inplace=False, num_threads=1
                 )
             except Exception as err:
                 raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err

@@ -2677,6 +2677,7 @@ def _run_worker_pipe_direct(
         # data = data[idx]
         data, reset_kwargs = data
         if data is not None:
+            data.unlock_()
            data._fast_apply(
                 lambda x: x.clone() if x.device.type == "cuda" else x, out=data
             )
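A minimal sketch of the failure mode these two changes address, outside the torchrl code path and under the assumption that the TensorDict reaching the worker can be locked (for instance because it is shared across processes): writes into a locked TensorDict raise, which is what _fast_apply(..., out=data) amounts to, hence the added data.unlock_(); and consolidate(inplace=False) builds a fresh consolidated TensorDict instead of restructuring the locked input.

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3])
td.lock_()

# A locked TensorDict rejects writes, which is what writing back via
# _fast_apply(..., out=data) would attempt; hence data.unlock_() in the worker.
try:
    td["done"] = torch.zeros(3, 1, dtype=torch.bool)
except RuntimeError as err:
    print(f"write rejected while locked: {err}")

# inplace=False consolidates into a new TensorDict and leaves the
# (possibly locked) input untouched.
flat = td.consolidate(share_memory=True, inplace=False, num_threads=1)
print(flat is not td)

# After unlocking, writes are allowed again.
td.unlock_()
td["done"] = torch.zeros(3, 1, dtype=torch.bool)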
