Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/trainingdata/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ InputPlanes PlanesFromTrainingData(const V6TrainingData& data) {
return result;
}

// V7 overload. A V7 frame is-a V6 frame, and plane extraction is implemented
// by the V6 overload, so view the frame through its V6 base and forward it.
InputPlanes PlanesFromTrainingData(const V7TrainingData& data) {
  // Binding to a V6 reference selects the V6 overload (no recursion here).
  const V6TrainingData& base_frame = data;
  return PlanesFromTrainingData(base_frame);
}

TrainingDataReader::TrainingDataReader(std::string filename)
: filename_(filename) {
fin_ = gzopen(filename_.c_str(), "rb");
Expand Down Expand Up @@ -205,7 +209,7 @@ bool TrainingDataReader::ReadChunk(V6TrainingData* data) {
data->played_idx = 0;
data->best_idx = 0;
data->policy_kld = 0.0f;
data->reserved = 0;
data->q_st = 0.0f;
return true;
}
case 6: {
Expand Down
2 changes: 2 additions & 0 deletions src/trainingdata/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#pragma once

#include "trainingdata/trainingdata.h"
#include "trainingdata/trainingdata_v7.h"

namespace lczero {

Expand All @@ -38,6 +39,7 @@ namespace lczero {
// will be used with DecodeMoveFromInput or PopulateBoard which assume the
// InputPlanes are not transformed.
InputPlanes PlanesFromTrainingData(const V6TrainingData& data);
InputPlanes PlanesFromTrainingData(const V7TrainingData& data);

class TrainingDataReader {
public:
Expand Down
83 changes: 53 additions & 30 deletions src/trainingdata/rescorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ void DataAssert(bool check_result,
}
}

void Validate(std::span<const V6TrainingData> fileContents) {
template <typename FrameType>
void Validate(std::span<const FrameType> fileContents) {
if (fileContents.empty()) throw Exception("Empty File");

for (size_t i = 0; i < fileContents.size(); i++) {
Expand Down Expand Up @@ -217,8 +218,8 @@ void Validate(std::span<const V6TrainingData> fileContents) {
}
}

void Validate(std::span<const V6TrainingData> fileContents,
const MoveList& moves) {
template <typename FrameType>
void Validate(std::span<const FrameType> fileContents, const MoveList& moves) {
PositionHistory history;
int rule50ply;
int gameply;
Expand Down Expand Up @@ -460,13 +461,15 @@ struct ProcessFileFlags {
bool nnue_best_move : 1;
};

template <typename FrameType>
struct FileData {
std::vector<V6TrainingData> fileContents;
std::vector<FrameType> fileContents;
MoveList moves;
pblczero::NetworkFormat::InputFormat input_format;
};

bool IsAllDraws(const FileData& data) {
template <typename FrameType>
bool IsAllDraws(const FileData<FrameType>& data) {
for (const auto& chunk : data.fileContents) {
if (ResultForData(chunk) != 0) {
return false;
Expand All @@ -487,11 +490,12 @@ std::vector<V6TrainingData> ReadFile(const std::string& file) {
return fileContents;
}

FileData ProcessAndValidateFileData(std::vector<V6TrainingData> fileContents) {
FileData data;
template <typename FrameType>
FileData<FrameType> ProcessAndValidateFileData(std::vector<FrameType> fileContents) {
FileData<FrameType> data;
data.fileContents = std::move(fileContents);

Validate(data.fileContents);
Validate(std::span<const FrameType>(data.fileContents));
games += 1;
positions += data.fileContents.size();
// Decode moves from input data
Expand All @@ -504,15 +508,16 @@ FileData ProcessAndValidateFileData(std::vector<V6TrainingData> fileContents) {
// position before.
data.moves.back().Flip();
}
Validate(data.fileContents, data.moves);
Validate(std::span<const FrameType>(data.fileContents), data.moves);

data.input_format = static_cast<pblczero::NetworkFormat::InputFormat>(
data.fileContents[0].input_format);

return data;
}

void ApplyPolicySubstitutions(FileData& data) {
template <typename FrameType>
void ApplyPolicySubstitutions(FileData<FrameType>& data) {
if (policy_subs.empty()) return;
PositionHistory history;
int rule50ply;
Expand Down Expand Up @@ -545,7 +550,8 @@ void ApplyPolicySubstitutions(FileData& data) {
}
}

void ApplySyzygyRescoring(FileData& data, SyzygyTablebase* tablebase) {
template <typename FrameType>
void ApplySyzygyRescoring(FileData<FrameType>& data, SyzygyTablebase* tablebase) {
PositionHistory history;
int rule50ply;
int gameply;
Expand Down Expand Up @@ -704,7 +710,8 @@ void ApplySyzygyRescoring(FileData& data, SyzygyTablebase* tablebase) {
}
}

void ApplyPolicyAdjustments(FileData& data, SyzygyTablebase* tablebase,
template <typename FrameType>
void ApplyPolicyAdjustments(FileData<FrameType>& data, SyzygyTablebase* tablebase,
float distTemp, float distOffset, float dtzBoost) {
if (distTemp == 1.0f && distOffset == 0.0f && dtzBoost == 0.0f) {
return; // No adjustments needed
Expand Down Expand Up @@ -812,7 +819,8 @@ void ApplyPolicyAdjustments(FileData& data, SyzygyTablebase* tablebase,
}
}

void EstimateAndCorrectPliesLeft(FileData& data) {
template <typename FrameType>
void EstimateAndCorrectPliesLeft(FileData<FrameType>& data) {
// Make move_count field plies_left for moves left head.
int offset = 0;
for (auto& chunk : data.fileContents) {
Expand All @@ -826,7 +834,8 @@ void EstimateAndCorrectPliesLeft(FileData& data) {
}
}

void ApplyGaviotaCorrections(FileData& data) {
template <typename FrameType>
void ApplyGaviotaCorrections(FileData<FrameType>& data) {
if (!gaviotaEnabled) return;

if (IsAllDraws(data)) return;
Expand Down Expand Up @@ -899,7 +908,8 @@ void ApplyGaviotaCorrections(FileData& data) {
}
}

void ApplyDTZCorrections(FileData& data, SyzygyTablebase* tablebase) {
template <typename FrameType>
void ApplyDTZCorrections(FileData<FrameType>& data, SyzygyTablebase* tablebase) {
// Correct move_count using DTZ for 3 piece no-pawn positions only.
// If Gaviota TBs are enabled no need to use syzygy.
if (gaviotaEnabled) return;
Expand Down Expand Up @@ -976,7 +986,8 @@ void ApplyDTZCorrections(FileData& data, SyzygyTablebase* tablebase) {
}
}

void ApplyDeblunder(FileData& data, SyzygyTablebase* tablebase) {
template <typename FrameType>
void ApplyDeblunder(FileData<FrameType>& data, SyzygyTablebase* tablebase) {
// Deblunder only works from v6 data onwards. We therefore check
// the visits field which is 0 if we're dealing with upgraded data.
if (!deblunderEnabled || data.fileContents.back().visits == 0) {
Expand Down Expand Up @@ -1052,7 +1063,8 @@ void ApplyDeblunder(FileData& data, SyzygyTablebase* tablebase) {
}
}

void ConvertInputFormat(FileData& data, int newInputFormat) {
template <typename FrameType>
void ConvertInputFormat(FileData<FrameType>& data, int newInputFormat) {
if (newInputFormat == -1) return;

PositionHistory history;
Expand All @@ -1071,7 +1083,8 @@ void ConvertInputFormat(FileData& data, int newInputFormat) {
}
}

void WriteNnueOutput(const FileData& data, const std::string& nnue_plain_file,
template <typename FrameType>
void WriteNnueOutput(const FileData<FrameType>& data, const std::string& nnue_plain_file,
ProcessFileFlags flags) {
// Output data in Stockfish plain format.
if (!nnue_plain_file.empty()) {
Expand Down Expand Up @@ -1116,7 +1129,8 @@ void WriteNnueOutput(const FileData& data, const std::string& nnue_plain_file,
}
}

void WriteOutputs(const FileData& data, const std::string& file,
template <typename FrameType>
void WriteOutputs(const FileData<FrameType>& data, const std::string& file,
const std::string& outputDir) {
// Write processed training data
if (!outputDir.empty()) {
Comment on lines +1132 to 1136
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The WriteOutputs function calls writer.WriteChunk(chunk) with FrameType objects, but TrainingDataWriter::WriteChunk only accepts const V6TrainingData&. When FrameType is V7TrainingData, this will cause implicit slicing, losing V7-specific fields (d_st, opp_played_idx, next_played_idx, reserved). Consider adding a V7TrainingData overload to TrainingDataWriter or using a constrained template to prevent misuse.

Copilot uses AI. Check for mistakes.
Expand All @@ -1131,12 +1145,13 @@ void WriteOutputs(const FileData& data, const std::string& file,
}
}

FileData ProcessFileInternal(std::vector<V6TrainingData> fileContents,
SyzygyTablebase* tablebase, float distTemp,
float distOffset, float dtzBoost,
int newInputFormat) {
template <typename FrameType>
FileData<FrameType> ProcessFileInternal(std::vector<FrameType> fileContents,
SyzygyTablebase* tablebase, float distTemp,
float distOffset, float dtzBoost,
int newInputFormat) {
// Process and validate file data
FileData data = ProcessAndValidateFileData(std::move(fileContents));
FileData<FrameType> data = ProcessAndValidateFileData(std::move(fileContents));

// Apply policy substitutions if available
ApplyPolicySubstitutions(data);
Expand Down Expand Up @@ -1219,7 +1234,7 @@ void BuildSubs(const std::vector<std::string>& files) {
while (reader.ReadChunk(&data)) {
fileContents.push_back(data);
}
Validate(fileContents);
Validate(std::span<const V6TrainingData>(fileContents));
MoveList moves;
for (size_t i = 1; i < fileContents.size(); i++) {
moves.push_back(
Expand All @@ -1230,7 +1245,7 @@ void BuildSubs(const std::vector<std::string>& files) {
// position before.
moves.back().Flip();
}
Validate(fileContents, moves);
Validate(std::span<const V6TrainingData>(fileContents), moves);

// Subs are 'valid'.
PositionHistory history;
Expand Down Expand Up @@ -1417,10 +1432,12 @@ void RunRescorer() {
<< std::endl;
}

std::vector<V6TrainingData> RescoreTrainingData(
std::vector<V6TrainingData> fileContents, SyzygyTablebase* tablebase,
float distTemp, float distOffset, float dtzBoost, int newInputFormat) {
FileData data =
template <typename FrameType>
std::vector<FrameType> RescoreTrainingData(std::vector<FrameType> fileContents,
SyzygyTablebase* tablebase, float distTemp,
float distOffset, float dtzBoost,
int newInputFormat) {
FileData<FrameType> data =
ProcessFileInternal(std::move(fileContents), tablebase, distTemp,
distOffset, dtzBoost, newInputFormat);
return data.fileContents;
Expand Down Expand Up @@ -1467,4 +1484,10 @@ bool RescorerPolicySubstitutionSetup(std::string policySubsDir) {
return !policy_subs.empty();
}

// Explicit template instantiations.
template std::vector<V6TrainingData> RescoreTrainingData(
std::vector<V6TrainingData>, SyzygyTablebase*, float, float, float, int);
template std::vector<V7TrainingData> RescoreTrainingData(
std::vector<V7TrainingData>, SyzygyTablebase*, float, float, float, int);

} // namespace lczero
20 changes: 16 additions & 4 deletions src/trainingdata/rescorer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@

#pragma once

#include <type_traits>
#include <vector>

#include "syzygy/syzygy.h"
#include "trainingdata/trainingdata_v6.h"
#include "trainingdata/trainingdata_v7.h"

namespace lczero {

Expand All @@ -40,9 +42,19 @@ void RunRescorer();
bool RescorerDeblunderSetup(float threshold, float width);
bool RescorerGaviotaSetup(std::string dtmPaths);
bool RescorerPolicySubstitutionSetup(std::string policySubsDir);
std::vector<V6TrainingData> RescoreTrainingData(
std::vector<V6TrainingData> fileContents, SyzygyTablebase* tablebase,
float distTemp = 1.0f, float distOffset = 0.0f, float dtzBoost = 0.0f,
int newInputFormat = -1);

// Rescores the training frames in `fileContents` and returns the processed
// frames. The input is handled as a single game.
//
// Only the declaration is visible here: the definition lives in rescorer.cc
// and is explicitly instantiated for V6TrainingData and V7TrainingData (see
// the extern template declarations below). Any other T has no definition to
// link against.
//
// `tablebase` is forwarded to the tablebase-based processing steps.
// With the default arguments the optional steps are no-ops:
//   distTemp == 1.0f, distOffset == 0.0f, dtzBoost == 0.0f
//       -> no policy adjustments are applied;
//   newInputFormat == -1
//       -> the input format is left unchanged.
//
// Validation in rescorer.cc throws Exception("Empty File") when
// `fileContents` is empty.
template <typename T>
std::vector<T> RescoreTrainingData(std::vector<T> fileContents,
SyzygyTablebase* tablebase,
float distTemp = 1.0f,
float distOffset = 0.0f,
float dtzBoost = 0.0f,
int newInputFormat = -1);

// Explicit instantiation declarations.
extern template std::vector<V6TrainingData> RescoreTrainingData(
std::vector<V6TrainingData>, SyzygyTablebase*, float, float, float, int);
extern template std::vector<V7TrainingData> RescoreTrainingData(
std::vector<V7TrainingData>, SyzygyTablebase*, float, float, float, int);

} // namespace lczero
15 changes: 6 additions & 9 deletions src/trainingdata/trainingdata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,11 @@ void V6TrainingDataArray::Write(TrainingDataWriter* writer, GameResult result,
}
}

void V6TrainingDataArray::Add(const classic::Node* node,
const PositionHistory& history,
classic::Eval best_eval,
classic::Eval played_eval, bool best_is_proven,
Move best_move, Move played_move,
std::span<Move> legal_moves,
const std::optional<EvalResult>& nneval,
float policy_softmax_temp) {
void V6TrainingDataArray::Add(
const classic::Node* node, const PositionHistory& history,
classic::Eval best_eval, classic::Eval played_eval, bool best_is_proven,
Move best_move, Move played_move, std::span<Move> legal_moves,
const std::optional<EvalResult>& nneval, float policy_softmax_temp) {
V6TrainingData result;
const auto& position = history.Last();

Expand Down Expand Up @@ -262,7 +259,7 @@ void V6TrainingDataArray::Add(const classic::Node* node,
}
result.best_idx = MoveToNNIndex(best_move, transform);
result.played_idx = MoveToNNIndex(played_move, transform);
result.reserved = 0;
result.q_st = 0.0f;

// Unknown here - will be filled in once the full data has been collected.
result.plies_left = 0;
Expand Down
3 changes: 2 additions & 1 deletion src/trainingdata/trainingdata_v6.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ struct V6TrainingData {
uint16_t best_idx;
// Kullback-Leibler divergence between visits and policy (denominator)
float policy_kld;
uint32_t reserved;
// Q standard deviation. Was uint32_t reserved in original v6.
float q_st;
} PACKED_STRUCT;
static_assert(sizeof(V6TrainingData) == 8356, "Wrong struct size");

Expand Down
50 changes: 50 additions & 0 deletions src/trainingdata/trainingdata_v7.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2021 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.

Additional permission under GNU GPL version 3 section 7

If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/

#pragma once

#include <cstdint>

#include "trainingdata/trainingdata_v6.h"
#include "utils/cppattributes.h"

namespace lczero {

#pragma pack(push, 1)

// Version 7 training frame: a V6TrainingData with extra fields added by
// inheritance. The V7-only members below add 40 bytes on top of the V6 frame.
// NOTE(review): this relies on the compiler placing the V6 base subobject
// first with no padding under the surrounding #pragma pack(1) /
// PACKED_STRUCT; the static_assert below checks only the total size.
struct V7TrainingData : V6TrainingData {
// D standard deviation (q_st is already defined in V6TrainingData).
float d_st;
// NOTE(review): presumably the policy index of the opponent's played move,
// encoded like played_idx in V6 - confirm against the code that fills it.
uint16_t opp_played_idx;
// NOTE(review): presumably the policy index of the next played move,
// encoded like played_idx in V6 - confirm against the code that fills it.
uint16_t next_played_idx;
// 8 floats (32 bytes) not yet assigned a meaning.
// NOTE(review): presumably kept for future fields - confirm.
float reserved[8];
} PACKED_STRUCT;
// 8356 (V6TrainingData) + 4 (d_st) + 2 + 2 (the two indices) + 32 (reserved)
// = 8396 bytes.
static_assert(sizeof(V7TrainingData) == 8396, "Wrong struct size");

#pragma pack(pop)

} // namespace lczero