Skip to content

Commit 16d8436

Browse files
authored
fix Tensor.numpy()[0] to float(Tensor) to adapt 0D (#2884)
1 parent 089c060 commit 16d8436

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/tess/cls0/local/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _collate_features(batch):
121121
optimizer.clear_grad()
122122

123123
# Calculate loss
124-
avg_loss += loss.numpy()[0]
124+
avg_loss += float(loss)
125125

126126
# Calculate metrics
127127
preds = paddle.argmax(logits, axis=1)

0 commit comments

Comments
 (0)