Commit cb3f824

⚡️ Speed up method TensorChunker._split_value by 12% in PR #272 (14__robusttraining)
Replaces the per-index loop in `_split_value`, which appended dummy chunks and flags one at a time, with a pre-computed count and a single list extension.

Improvements made:
1. Instead of using a conditional inside a loop to append dummy chunks, the number of dummy chunks needed is computed once and the chunk list is extended in one operation.
2. The `dummy_chunk_flags` list is built in one expression, which avoids repeated `append` calls.

With these changes, the function should run faster while maintaining the intended behavior.
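One detail of the list-repetition idiom used in this change: `[x] * n` stores `n` references to the same object rather than `n` separate objects. The sketch below illustrates this with a plain list standing in for the `chunks[0][:1]` tensor slice; the names `first_row`, `appended`, and `repeated` are hypothetical and not taken from the codebase.

```python
# A plain list stands in for the tensor slice chunks[0][:1] (illustrative assumption).
first_row = [1.0, 2.0]

# Loop-and-append, as in the original code: a new slice object per iteration.
appended = []
for _ in range(3):
    appended.append(first_row[:1])

# Repetition, as in the optimized code: one slice object, referenced three times.
repeated = [first_row[:1]] * 3

# The contents are the same either way ...
assert appended == repeated
# ... but the repeated list holds a single shared object.
assert repeated[0] is repeated[1]
assert appended[0] is not appended[1]
```

With tensors, a basic slice is a view onto the same storage in both versions, so the practical difference is limited to object identity.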
1 parent 62a499f commit cb3f824

1 file changed: +11 -11 lines changed


src/ldp/nn/handlers/chunking.py

Lines changed: 11 additions & 11 deletions
@@ -154,18 +154,18 @@ def _split_value(self, value):
         """
         if isinstance(value, torch.Tensor):
             chunks = list(torch.chunk(value, self.num_chunks, dim=0))
-            dummy_chunk_flags = []
-            for i in range(self.num_chunks):
-                if i >= len(chunks):
-                    # Chunk 0 will always exist, and we need only a batch of one ([:1])
-                    # to activate the model.
-                    # We use the first element of the existing chunks as real data to avoid
-                    # errors in the model that may expect a specific token structure.
-                    chunks.append(chunks[0][:1])
-                    dummy_chunk_flags.append(True)
-                else:
-                    dummy_chunk_flags.append(False)
+            real_chunks_len = len(chunks)
+
+            # Pre-determining if dummy chunks are needed and their count
+            if real_chunks_len < self.num_chunks:
+                dummy_count = self.num_chunks - real_chunks_len
+                dummy_chunks = [chunks[0][:1]] * dummy_count
+                dummy_chunk_flags = [False] * real_chunks_len + [True] * dummy_count
+                chunks.extend(dummy_chunks)
+            else:
+                dummy_chunk_flags = [False] * self.num_chunks
 
             return chunks, dummy_chunk_flags
+
         # Non-tensor values are replicated
         return [value] * self.num_chunks, None
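
To sanity-check that the two versions in the diff above produce the same result, here is a minimal stdlib-only sketch. It is not the project's code: `chunk_list` is a hypothetical stand-in that applies the `torch.chunk` sizing rule (pieces of size ceil(len / num_chunks), possibly fewer pieces than requested) to a plain list, and `split_loop` and `split_prealloc` are illustrative names for the old and new logic.

```python
import math


def chunk_list(values, num_chunks):
    """Hypothetical stand-in for torch.chunk along dim 0, applied to a list.

    Uses pieces of size ceil(len / num_chunks), so it may return fewer than
    num_chunks pieces. Assumes a non-empty input.
    """
    size = math.ceil(len(values) / num_chunks)
    return [values[i:i + size] for i in range(0, len(values), size)]


def split_loop(values, num_chunks):
    """Original approach: append one dummy chunk and one flag per iteration."""
    chunks = chunk_list(values, num_chunks)
    flags = []
    for i in range(num_chunks):
        if i >= len(chunks):
            chunks.append(chunks[0][:1])
            flags.append(True)
        else:
            flags.append(False)
    return chunks, flags


def split_prealloc(values, num_chunks):
    """Optimized approach: compute the dummy count once, then extend."""
    chunks = chunk_list(values, num_chunks)
    real_len = len(chunks)
    if real_len < num_chunks:
        dummy_count = num_chunks - real_len
        flags = [False] * real_len + [True] * dummy_count
        chunks.extend([chunks[0][:1]] * dummy_count)
    else:
        flags = [False] * num_chunks
    return chunks, flags


# Both versions agree across a range of input lengths and chunk counts.
for n in range(1, 12):
    for k in range(1, 8):
        data = list(range(n))
        assert split_loop(data, k) == split_prealloc(data, k)

print(split_prealloc([10, 20, 30, 40, 50], 4))
# ([[10, 20], [30, 40], [50], [10]], [False, False, False, True])
```

Five items split four ways yield only three real chunks of size two, two, and one, so a single dummy chunk built from the first element of chunk 0 is appended and flagged `True`.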
