We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f039d09 · commit ee4b9dd · Copy full SHA for ee4b9dd
llm/run_pretrain.py
@@ -11,6 +11,7 @@
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
+import copy
15
import math
16
import os
17
import sys
@@ -261,12 +262,12 @@ def print_dataset(data, mode="train"):
261
262
def _collate_data(data, stack_fn=Stack()):
263
tokens_ = stack_fn([x["text"] for x in data])
264
- labels = tokens_[:, 1:]
265
+ labels = copy.deepcopy(tokens_)[:, 1:]
266
tokens = tokens_[:, :-1]
267
268
return {
- "input_ids": paddle.to_tensor(tokens),
269
- "labels": paddle.to_tensor(labels),
+ "input_ids": tokens,
270
+ "labels": labels,
271
}
272
273
if need_data:
0 commit comments