Skip to content

Commit ee4b9dd

Browse files
authored
[Pretrain] Fix eval during pretrain (#7827)
* add unified checkpoint training args doc * fix eval during pretrain * fix
1 parent f039d09 commit ee4b9dd

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

llm/run_pretrain.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import copy
1415
import math
1516
import os
1617
import sys
@@ -261,12 +262,12 @@ def print_dataset(data, mode="train"):
261262
def _collate_data(data, stack_fn=Stack()):
262263
tokens_ = stack_fn([x["text"] for x in data])
263264

264-
labels = tokens_[:, 1:]
265+
labels = copy.deepcopy(tokens_)[:, 1:]
265266
tokens = tokens_[:, :-1]
266267

267268
return {
268-
"input_ids": paddle.to_tensor(tokens),
269-
"labels": paddle.to_tensor(labels),
269+
"input_ids": tokens,
270+
"labels": labels,
270271
}
271272

272273
if need_data:

0 commit comments

Comments
 (0)