Skip to content

Commit d817a04

Browse files
DeepMind Lab Team authored and tkoeppe committed
[setting] Add 'mixerSeed' setting to enable splitting seed space between training and test runs.
Added a new setting for DMLab so that seeds for training and test runs are separate. This happens internally (using full 64-bit seeds instead of the 32-bit ones exposed by the environment API). This prevents any unnecessary folding due to training.
1 parent 9154942 commit d817a04

16 files changed

+131
-38
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Added further pre-built maps, which removes the need for the expensive
1313
:map_assets build step.
1414
2. Allow game to be rendered with top-left as origin instead of bottom-left.
15+
3. Add 'mixerSeed' setting to change behaviour of all random number generators.
1516

1617
## release-2018-02-07 February 2018 release
1718

deepmind/engine/context.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ static bool get_native_app(void* userdata) {
158158
return static_cast<Context*>(userdata)->NativeApp();
159159
}
160160

161+
static void set_mixer_seed(void* userdata, int v) {
162+
return static_cast<Context*>(userdata)->SetMixerSeed(
163+
static_cast<std::uint32_t>(v));
164+
}
165+
161166
static void set_actions(void* userdata, double look_down_up,
162167
double look_left_right, signed char move_back_forward,
163168
signed char strafe_left_right, signed char crouch_jump,
@@ -480,7 +485,7 @@ lua::NResultsOr MapMakerModule(lua_State* L) {
480485
LuaTextLevelMaker::CreateObject(
481486
L, ctx->ExecutableRunfiles(), ctx->TempDirectory(),
482487
ctx->UseLocalLevelCache(), ctx->UseGlobalLevelCache(),
483-
ctx->LevelCacheParams());
488+
ctx->LevelCacheParams(), ctx->MixerSeed());
484489
return 1;
485490
} else {
486491
return "Missing context!";
@@ -507,6 +512,7 @@ Context::Context(lua::Vm lua_vm, const char* executable_runfiles,
507512
: lua_vm_(std::move(lua_vm)),
508513
native_app_(false),
509514
actions_{},
515+
mixer_seed_(0),
510516
level_cache_params_{},
511517
game_(executable_runfiles, calls, file_reader_override,
512518
temp_folder != nullptr ? temp_folder : ""),
@@ -528,6 +534,7 @@ Context::Context(lua::Vm lua_vm, const char* executable_runfiles,
528534
hooks->run_lua_snippet = run_lua_snippet;
529535
hooks->set_native_app = set_native_app;
530536
hooks->get_native_app = get_native_app;
537+
hooks->set_mixer_seed = set_mixer_seed;
531538
hooks->set_actions = set_actions;
532539
hooks->get_actions = get_actions;
533540
hooks->find_model = find_model;
@@ -652,7 +659,8 @@ int Context::Init() {
652659
lua_vm_.AddCModuleToSearchers(
653660
"dmlab.system.tensor", tensor::LuaTensorConstructors);
654661
lua_vm_.AddCModuleToSearchers(
655-
"dmlab.system.maze_generation", LuaMazeGeneration::Require);
662+
"dmlab.system.maze_generation", &lua::Bind<LuaMazeGeneration::Require>,
663+
{reinterpret_cast<void*>(static_cast<std::uintptr_t>(mixer_seed_))});
656664
lua_vm_.AddCModuleToSearchers(
657665
"dmlab.system.map_maker", &lua::Bind<MapMakerModule>, {this});
658666
lua_vm_.AddCModuleToSearchers(
@@ -668,7 +676,9 @@ int Context::Init() {
668676
&lua::Bind<ContextPickups::Module>,
669677
{MutablePickups()});
670678
lua_vm_.AddCModuleToSearchers(
671-
"dmlab.system.random", &lua::Bind<LuaRandom::Require>, {UserPrbg()});
679+
"dmlab.system.random", &lua::Bind<LuaRandom::Require>,
680+
{UserPrbg(),
681+
reinterpret_cast<void*>(static_cast<std::uintptr_t>(mixer_seed_))});
672682
lua_vm_.AddCModuleToSearchers(
673683
"dmlab.system.model", &lua::Bind<ModelModule>,
674684
{const_cast<DeepmindCalls*>(Game().Calls())});
@@ -704,7 +714,8 @@ int Context::Init() {
704714
}
705715

706716
int Context::Start(int episode, int seed) {
707-
EnginePrbg()->seed(seed);
717+
EnginePrbg()->seed(static_cast<std::uint64_t>(seed) ^
718+
(static_cast<std::uint64_t>(mixer_seed_) << 32));
708719
MutableGame()->NextMap();
709720
lua_State* L = lua_vm_.get();
710721
script_table_ref_.PushMemberFunction("start");

deepmind/engine/context.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ class Context {
227227
// generate new positive integers.
228228
int MakeRandomSeed();
229229

230+
// Specifies a mixer value to be combined with all the seeds passed to this
231+
// environment, before using them with the internal PRBGs. This is done in
232+
// a way which guarantees that the resulting seeds span disjoint subsets of
233+
// the integers in [0, 2^64) for each different mixer value. However, the
234+
// sequences produced by the environment's PRBGs are not necessarily disjoint.
235+
void SetMixerSeed(std::uint32_t s) { mixer_seed_ = s; }
236+
237+
std::uint32_t MixerSeed() const { return mixer_seed_; }
238+
230239
std::mt19937_64* UserPrbg() { return &user_prbg_; }
231240

232241
std::mt19937_64* EnginePrbg() { return &engine_prbg_; }
@@ -480,6 +489,9 @@ class Context {
480489
// A pseudo-random-bit generator for exclusive use by users.
481490
std::mt19937_64 user_prbg_;
482491

492+
// Stores the mixer seed for the PRBG.
493+
std::uint32_t mixer_seed_;
494+
483495
// A pseudo-random-bit generator for exclusive use of the engine. Seeded each
484496
// episode with the episode start seed.
485497
std::mt19937_64 engine_prbg_;

deepmind/engine/lua_maze_generation.cc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ namespace lab {
3737
namespace {
3838

3939
std::mt19937_64* GetRandomNumberGenerator(lua::TableRef* table,
40-
std::mt19937_64* seeded_rng) {
40+
std::mt19937_64* seeded_rng,
41+
std::uint64_t mixer_seq) {
4142
std::mt19937_64* prng = nullptr;
4243
lua_State* L = table->LuaState();
4344
table->LookUpToStack("random");
@@ -51,7 +52,7 @@ std::mt19937_64* GetRandomNumberGenerator(lua::TableRef* table,
5152
if (prng == nullptr) {
5253
int seed = 0;
5354
if (table->LookUp("seed", &seed)) {
54-
seeded_rng->seed(seed);
55+
seeded_rng->seed(static_cast<std::uint64_t>(seed) ^ mixer_seq);
5556
prng = seeded_rng;
5657
}
5758
}
@@ -112,11 +113,19 @@ class LuaRoom : public lua::Class<LuaRoom> {
112113
std::vector<maze_generation::Pos> room_;
113114
};
114115

116+
// Bit toggle sequence applied to the 32 MSB of the 64-bit seeds fed to the maze
117+
// generation PRBGs, with the intention of creating disjoint seed subspaces for
118+
// each different mixer_seed value as described in python_api.md
119+
std::uint64_t LuaMazeGeneration::mixer_seq_ = 0;
120+
115121
const char* LuaMazeGeneration::ClassName() {
116122
return "deepmind.lab.LuaMazeGeneration";
117123
}
118124

119-
int LuaMazeGeneration::Require(lua_State* L) {
125+
lua::NResultsOr LuaMazeGeneration::Require(lua_State* L) {
126+
std::uintptr_t mixer_seed =
127+
reinterpret_cast<std::uintptr_t>(lua_touserdata(L, lua_upvalueindex(1)));
128+
mixer_seq_ = static_cast<std::uint64_t>(mixer_seed) << 32;
120129
auto table = lua::TableRef::Create(L);
121130
table.Insert("mazeGeneration", &lua::Bind<LuaMazeGeneration::Create>);
122131
table.Insert("randomMazeGeneration",
@@ -162,7 +171,8 @@ lua::NResultsOr LuaMazeGeneration::CreateRandom(lua_State* L) {
162171
lua::Read(L, -1, &table);
163172

164173
std::mt19937_64 seeded_rng;
165-
std::mt19937_64* prng = GetRandomNumberGenerator(&table, &seeded_rng);
174+
std::mt19937_64* prng =
175+
GetRandomNumberGenerator(&table, &seeded_rng, mixer_seq_);
166176
if (prng == nullptr) {
167177
return "[randomMazeGeneration] - Must construct with 'random' a random "
168178
"number generator. ('seed' is deprecated.)";
@@ -544,7 +554,8 @@ lua::NResultsOr LuaMazeGeneration::VisitRandomPath(lua_State* L) {
544554
return "[visitRandomPath] - must supply table";
545555
}
546556
std::mt19937_64 seeded_rng;
547-
std::mt19937_64* prng = GetRandomNumberGenerator(&table, &seeded_rng);
557+
std::mt19937_64* prng =
558+
GetRandomNumberGenerator(&table, &seeded_rng, mixer_seq_);
548559
if (prng == nullptr) {
549560
return "[visitRandomPath] - must supply 'random' with random number "
550561
"generator. ('seed' is deprecated.)";

deepmind/engine/lua_maze_generation.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2016 Google Inc.
1+
// Copyright (C) 2016-2018 Google Inc.
22
//
33
// This program is free software; you can redistribute it and/or modify
44
// it under the terms of the GNU General Public License as published by
@@ -47,7 +47,7 @@ class LuaMazeGeneration : public lua::Class<LuaMazeGeneration> {
4747

4848
// Returns table of constructors and standalone functions.
4949
// [0, 1, -]
50-
static int Require(lua_State* L);
50+
static lua::NResultsOr Require(lua_State* L);
5151

5252
private:
5353
// Constructs a LuaMazeGeneration.
@@ -197,6 +197,8 @@ class LuaMazeGeneration : public lua::Class<LuaMazeGeneration> {
197197
lua::NResultsOr CountVariations(lua_State* L);
198198

199199
maze_generation::TextMaze text_maze_;
200+
201+
static std::uint64_t mixer_seq_;
200202
};
201203

202204
} // namespace lab

deepmind/engine/lua_maze_generation_test.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,13 @@ class LuaMazeGenerationTest : public lua::testing::TestWithVm {
4242
protected:
4343
LuaMazeGenerationTest() {
4444
LuaMazeGeneration::Register(L);
45-
vm()->AddCModuleToSearchers("dmlab.system.maze_generation",
46-
LuaMazeGeneration::Require);
45+
vm()->AddCModuleToSearchers(
46+
"dmlab.system.maze_generation", &lua::Bind<LuaMazeGeneration::Require>,
47+
{reinterpret_cast<void*>(static_cast<std::uintptr_t>(0))});
4748
LuaRandom::Register(L);
48-
vm()->AddCModuleToSearchers("dmlab.system.sys_random",
49-
&lua::Bind<LuaRandom::Require>, {&prbg_});
49+
vm()->AddCModuleToSearchers(
50+
"dmlab.system.sys_random", &lua::Bind<LuaRandom::Require>,
51+
{&prbg_, reinterpret_cast<void*>(static_cast<std::uintptr_t>(0))});
5052
}
5153

5254
std::mt19937_64 prbg_;

deepmind/engine/lua_random.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ bool ReadLargeNumber(lua_State* L, int idx, RbgNumType* num) {
6060
lua::NResultsOr LuaRandom::Require(lua_State* L) {
6161
if (auto* prbg = static_cast<std::mt19937_64*>(
6262
lua_touserdata(L, lua_upvalueindex(1)))) {
63-
LuaRandom::CreateObject(L, prbg);
63+
std::uintptr_t mixer_seed = reinterpret_cast<std::uintptr_t>(
64+
lua_touserdata(L, lua_upvalueindex(2)));
65+
LuaRandom::CreateObject(L, prbg, mixer_seed);
6466
return 1;
6567
} else {
6668
return "Missing std::mt19937_64 pointer in up value!";
@@ -87,7 +89,7 @@ lua::NResultsOr LuaRandom::Seed(lua_State* L) {
8789
RbgNumType k;
8890

8991
if (ReadLargeNumber(L, -1, &k)) {
90-
prbg_->seed(k);
92+
prbg_->seed(k ^ mixer_seq_);
9193
return 0;
9294
} else if (lua::Read(L, -1, &s)) {
9395
auto& err = errno; // cache TLS-lookup
@@ -96,7 +98,7 @@ lua::NResultsOr LuaRandom::Seed(lua_State* L) {
9698
unsigned long long int n = std::strtoull(s.data(), &ep, 0);
9799
if (ep != s.data() && *ep == '\0' && err == 0 &&
98100
n <= std::numeric_limits<RbgNumType>::max()) {
99-
prbg_->seed(n);
101+
prbg_->seed(n ^ mixer_seq_);
100102
return 0;
101103
}
102104
}

deepmind/engine/lua_random.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ class LuaRandom : public lua::Class<LuaRandom> {
4343

4444
public:
4545
// Constructed with a non-owning view of a PRBG instance.
46-
explicit LuaRandom(std::mt19937_64* prbg) : prbg_(prbg) {}
46+
explicit LuaRandom(std::mt19937_64* prbg, std::uint32_t mixer_seed)
47+
: prbg_(prbg), mixer_seq_(static_cast<std::uint64_t>(mixer_seed) << 32) {}
4748

4849
// Registers the class as well as member functions:
4950
//
@@ -124,6 +125,7 @@ class LuaRandom : public lua::Class<LuaRandom> {
124125

125126
private:
126127
std::mt19937_64* prbg_;
128+
std::uint64_t mixer_seq_;
127129
};
128130

129131
} // namespace lab

deepmind/engine/lua_random_test.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ class LuaRandomTest : public lua::testing::TestWithVm {
3939
protected:
4040
LuaRandomTest() {
4141
LuaRandom::Register(L);
42-
vm()->AddCModuleToSearchers("dmlab.system.sys_random",
43-
&lua::Bind<LuaRandom::Require>, {&prbg_});
42+
vm()->AddCModuleToSearchers(
43+
"dmlab.system.sys_random", &lua::Bind<LuaRandom::Require>,
44+
{&prbg_, reinterpret_cast<void*>(static_cast<std::uintptr_t>(0))});
4445
}
4546

4647
std::mt19937_64 prbg_;

deepmind/engine/lua_text_level_maker.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,11 @@ bool NoOp(std::size_t, std::size_t, char,
314314
LuaTextLevelMaker::LuaTextLevelMaker(
315315
const std::string& self, const std::string& output_folder,
316316
bool use_local_level_cache, bool use_global_level_cache,
317-
DeepMindLabLevelCacheParams level_cache_params)
318-
: prng_(0), rundir_(self), output_folder_(output_folder) {
317+
DeepMindLabLevelCacheParams level_cache_params, std::uint32_t mixer_seed)
318+
: prng_(0),
319+
mixer_seed_(mixer_seed),
320+
rundir_(self),
321+
output_folder_(output_folder) {
319322
settings_.use_local_level_cache = use_local_level_cache;
320323
settings_.use_global_level_cache = use_global_level_cache;
321324
settings_.level_cache_params = level_cache_params;
@@ -404,7 +407,7 @@ lua::NResultsOr LuaTextLevelMaker::MapFromTextLevel(lua_State* L) {
404407

405408

406409
lua::NResultsOr LuaTextLevelMaker::ViewRandomness(lua_State* L) {
407-
LuaRandom::CreateObject(L, &prng_);
410+
LuaRandom::CreateObject(L, &prng_, mixer_seed_);
408411
return 1;
409412
}
410413

0 commit comments

Comments
 (0)