Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ name = "random_ant"
[[example]]
name = "sac_ant"

[[example]]
name = "sac_ant_gpu"

[[example]]
name = "dqn_atari_vec"
# test = true
Expand Down
71 changes: 38 additions & 33 deletions border-tch-agent/src/dqn/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ impl DQNBuilder {

/// Constructs DQN agent.
///
/// This is used with non-vectorized environments.
/// This method is used with non-vectorized environments.
///
/// * `device` - The device where the model is put.
pub fn build<E, Q, O, A>(self, qnet: DQNModel<Q>, device: tch::Device) -> DQN<E, Q, O, A>
where
E: Env,
Expand All @@ -149,7 +151,7 @@ impl DQNBuilder {
A: TchBuffer<Item = E::Act, SubBatch = Tensor>, // Todo: consider replacing Tensor with M::Output
{
let qnet_tgt = qnet.clone();
let replay_buffer = ReplayBuffer::new(self.replay_burffer_capacity);
let replay_buffer = ReplayBuffer::new(self.replay_burffer_capacity, device);

DQN {
qnet,
Expand Down Expand Up @@ -209,34 +211,37 @@ impl DQNBuilder {
}
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::dqn::{DQNBuilder, EpsilonGreedy};
    use crate::util::OptInterval;
    use tempdir::TempDir;

    /// Round-trips a configured [DQNBuilder] through a YAML file in a temporary
    /// directory and checks that the loaded builder equals the saved one.
    #[test]
    fn test_serde_dqn_builder() -> Result<()> {
        let original = DQNBuilder::default()
            .opt_interval(OptInterval::Steps(50))
            .n_updates_per_opt(1)
            .min_transitions_warmup(100)
            .batch_size(32)
            .discount_factor(0.99)
            .tau(0.005)
            .explorer(EpsilonGreedy::with_final_step(1000));

        // The temporary directory must outlive the save/load calls below.
        let tmp = TempDir::new("dqn_builder")?;
        let yaml_path = tmp.path().join("dqn_builder.yaml");
        println!("{:?}", yaml_path);

        // Save to disk, load back, and compare with the in-memory builder.
        original.save(&yaml_path)?;
        let restored = DQNBuilder::load(&yaml_path)?;
        assert_eq!(original, restored);

        // Also print the YAML form for manual inspection.
        let dumped = serde_yaml::to_string(&original)?;
        println!("{}", dumped);

        Ok(())
    }
}
// Commented out because the `tempdir` crate fails to resolve, for reasons not yet understood
// (`tempdir` is listed in Cargo.toml).
//
// #[cfg(test)]
// mod test {
// use super::*;
// use tempdir::TempDir;
// use crate::{dqn::{EpsilonGreedy, DQNBuilder}, util::OptInterval};

// #[test]
// fn test_serde_dqn_builder() -> Result<()> {
// let builder = DQNBuilder::default()
// .opt_interval(OptInterval::Steps(50))
// .n_updates_per_opt(1)
// .min_transitions_warmup(100)
// .batch_size(32)
// .discount_factor(0.99)
// .tau(0.005)
// .explorer(EpsilonGreedy::with_final_step(1000));

// let dir = TempDir::new("dqn_builder")?;
// let path = dir.path().join("dqn_builder.yaml");
// println!("{:?}", path);

// builder.save(&path)?;
// let builder_ = DQNBuilder::load(&path)?;
// assert_eq!(builder, builder_);

// let yaml = serde_yaml::to_string(&builder)?;
// println!("{}", yaml);

// Ok(())
// }
// }
2 changes: 1 addition & 1 deletion border-tch-agent/src/iqn/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl IQNBuilder {
{
let iqn = iqn_model;
let iqn_tgt = iqn.clone();
let replay_buffer = ReplayBuffer::new(self.replay_buffer_capacity);
let replay_buffer = ReplayBuffer::new(self.replay_buffer_capacity, device);

IQN {
iqn,
Expand Down
42 changes: 33 additions & 9 deletions border-tch-agent/src/replay_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::marker::PhantomData;
use tch::{Device, Tensor};
mod base;
pub use base::{ReplayBuffer, TchBatch, TchBuffer};
pub use base::{ReplayBuffer, TchBatch, TchBuffer, TchBufferOnDevice};
use border_core::Shape;

/// Adds capability of constructing [Tensor] with a static method.
Expand Down Expand Up @@ -48,6 +48,7 @@ where
{
buf: Tensor,
capacity: i64,
model_device: tch::Device,
phantom: PhantomData<(D, S, T)>,
}

Expand All @@ -60,23 +61,25 @@ where
type Item = T;
type SubBatch = Tensor;

/// Creates a buffer.
///
/// Input argument `_n_proc` is not used.
/// TODO: remove n_procs
fn new(capacity: usize) -> Self {
/// Creates a buffer for observation or action.
fn new(capacity: usize, model_device: Device) -> Self {
let capacity = capacity as i64;
let mut shape: Vec<_> = S::shape().to_vec().iter().map(|e| *e as i64).collect();
shape.insert(0, capacity);
let buf = D::zeros(shape.as_slice());
let buf = D::zeros(shape.as_slice()).to(tch::Device::Cpu);

Self {
buf,
capacity,
model_device,
phantom: PhantomData,
}
}

/// Push data (`Into<Tensor>`) to the buffer.
///
/// The first dimension of the tensor is the number of samples,
/// which can be two or more in vectorized environments.
fn push(&mut self, index: i64, item: &Self::Item) {
let val: Tensor = item.clone().into();
let batch_size = val.size()[0];
Expand All @@ -85,12 +88,33 @@ where
for i_ in 0..batch_size {
let i = (i_ + index) % self.capacity;
self.buf.get(i).copy_(&val.get(i_));

}
}

/// Creates minibatch.
fn batch(&self, batch_indexes: &Tensor) -> Tensor {
self.buf.index_select(0, &batch_indexes)
let batch_indexes = batch_indexes.to(self.buf.device());
self.buf.index_select(0, &batch_indexes).to(self.model_device)
}
}

impl<D, S, T> TchBufferOnDevice for TchTensorBuffer<D, S, T>
where
    D: 'static + Copy + tch::kind::Element + ZeroTensor,
    S: Shape,
    T: Clone + Into<Tensor>,
{
    /// Creates a zero-initialized buffer whose storage is placed on `device`.
    ///
    /// The backing tensor has `capacity` as its first dimension, followed by the
    /// dimensions reported by `S::shape()`.
    ///
    /// * `capacity` - Maximum number of samples the buffer can hold.
    /// * `device` - The device the backing tensor is moved to.
    /// * `model_device` - Stored in the buffer's `model_device` field.
    fn new_on_device(capacity: usize, device: tch::Device, model_device: tch::Device) -> Self {
        let capacity = capacity as i64;

        // Build `[capacity, dims...]` in a single pass instead of copying the
        // shape into a temporary vector and then shifting it with `insert(0, ..)`.
        let shape: Vec<i64> = std::iter::once(capacity)
            .chain(S::shape().iter().map(|e| *e as i64))
            .collect();
        let buf = D::zeros(shape.as_slice()).to(device);

        Self {
            buf,
            capacity,
            model_device,
            phantom: PhantomData,
        }
    }
}
78 changes: 56 additions & 22 deletions border-tch-agent/src/replay_buffer/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
use border_core::Env;
use log::{info, trace};
use std::marker::PhantomData;
use tch::{
kind::{FLOAT_CPU, INT64_CPU},
Tensor,
};
use tch::Tensor;

// /// Return binary tensor, one where reward is not zero.
// ///
Expand All @@ -26,7 +23,7 @@ pub trait TchBuffer {
type SubBatch;

/// Constructs a [TchBuffer].
fn new(capacity: usize) -> Self;
fn new(capacity: usize, model_device: tch::Device) -> Self;

/// Push a sample of an item (observations or actions).
/// Note that each item may consist of values from multiple environments.
Expand All @@ -36,6 +33,12 @@ pub trait TchBuffer {
fn batch(&self, batch_indexes: &Tensor) -> Self::SubBatch;
}

/// Generic buffer on a device.
///
/// Extends [TchBuffer] with a constructor that takes the storage device
/// explicitly, in addition to the model device.
pub trait TchBufferOnDevice: TchBuffer {
/// Constructs a [TchBuffer] on a device.
///
/// * `capacity` - Number of samples the buffer is created for.
/// * `device` - The device on which the buffer is constructed.
/// * `model_device` - The device of the model; presumably the destination of
///   generated minibatches (NOTE(review): confirm against each implementor).
fn new_on_device(capacity: usize, device: tch::Device, model_device: tch::Device) -> Self;
}

/// Batch object, generic wrt observation and action.
pub struct TchBatch<E: Env, O, A>
where
Expand Down Expand Up @@ -85,16 +88,18 @@ where
A: TchBuffer<Item = E::Act>,
{
/// Constructs a replay buffer.
pub fn new(capacity: usize) -> Self {
pub fn new(capacity: usize, model_device: tch::Device) -> Self {
info!("Construct replay buffer with capacity = {}", capacity);
let capacity = capacity;

let float_type = (tch::Kind::Float, tch::Device::Cpu);

Self {
obs: O::new(capacity),
next_obs: O::new(capacity),
actions: A::new(capacity),
rewards: Tensor::zeros(&[capacity as _], FLOAT_CPU),
not_dones: Tensor::zeros(&[capacity as _], FLOAT_CPU),
obs: O::new(capacity, model_device),
next_obs: O::new(capacity, model_device),
actions: A::new(capacity, model_device),
rewards: Tensor::zeros(&[capacity as _], float_type),
not_dones: Tensor::zeros(&[capacity as _], float_type),
returns: None,
capacity,
len: 0,
Expand All @@ -104,14 +109,6 @@ where
}
}

// /// If set to `True`, non-zero reward is considered as the end of episodes.
// #[deprecated]
// pub fn nonzero_reward_as_done(mut self, _v: bool) -> Self {
// unimplemented!();
// // self.nonzero_reward_as_done = v;
// // self
// }

/// Clears the buffer.
pub fn clear(&mut self) {
self.len = 0;
Expand Down Expand Up @@ -142,6 +139,7 @@ where
for j in 0..batch_size {
self.rewards.get(self.i as _).copy_(&reward.get(j));

// TODO: Consider removing this block
if !self.nonzero_reward_as_done {
self.not_dones.get(self.i as _).copy_(&not_done.get(j));
} else {
Expand All @@ -162,12 +160,17 @@ where
/// Constructs random samples.
pub fn random_batch(&self, batch_size: usize) -> Option<TchBatch<E, O, A>> {
let batch_size = batch_size.min(self.len - 1);
let batch_indexes = Tensor::randint((self.len - 2) as _, &[batch_size as _], INT64_CPU);
let _no_grad = tch::no_grad_guard();
let batch_indexes = Tensor::randint(
(self.len - 2) as _,
&[batch_size as _],
(tch::Kind::Int64, self.rewards.device()),
);
let obs = self.obs.batch(&batch_indexes);
let next_obs = self.next_obs.batch(&batch_indexes);
let actions = self.actions.batch(&batch_indexes);
let rewards = self.rewards.index_select(0, &batch_indexes).unsqueeze(-1); //.flatten(0, 1);
let not_dones = self.not_dones.index_select(0, &batch_indexes).unsqueeze(-1); //.flatten(0, 1);
let rewards = self.rewards.index_select(0, &batch_indexes).unsqueeze(-1);
let not_dones = self.not_dones.index_select(0, &batch_indexes).unsqueeze(-1);
let returns = match self.returns.as_ref() {
Some(r) => Some(r.index_select(0, &batch_indexes).flatten(0, 1)),
None => None,
Expand Down Expand Up @@ -235,3 +238,34 @@ where
self.len
}
}

impl<E, O, A> ReplayBuffer<E, O, A>
where
    E: Env,
    O: TchBufferOnDevice<Item = E::Obs>,
    A: TchBufferOnDevice<Item = E::Act>,
{
    /// Constructs a replay buffer on a `device`.
    ///
    /// When generating a minibatch, the data is transferred to `model_device`.
    ///
    /// * `capacity` - Number of transitions the buffer is created for.
    /// * `device` - The device on which the observation, action, reward and
    ///   not-done storage is allocated.
    /// * `model_device` - Forwarded to the observation and action buffers.
    pub fn new_on_device(capacity: usize, device: tch::Device, model_device: tch::Device) -> Self {
        info!("Construct replay buffer with capacity = {}", capacity);

        // Kind/device pair: the reward and not-done tensors are allocated
        // directly on `device`, so no follow-up `.to(device)` is needed.
        let float_type = (tch::Kind::Float, device);

        Self {
            obs: O::new_on_device(capacity, device, model_device),
            next_obs: O::new_on_device(capacity, device, model_device),
            actions: A::new_on_device(capacity, device, model_device),
            rewards: Tensor::zeros(&[capacity as _], float_type),
            not_dones: Tensor::zeros(&[capacity as _], float_type),
            returns: None,
            capacity,
            len: 0,
            i: 0,
            nonzero_reward_as_done: false,
            phandom: PhantomData,
        }
    }
}
4 changes: 2 additions & 2 deletions border-tch-agent/src/sac/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ where

let losses = {
let o = &batch.obs;
let a = &batch.actions.to(self.device);
let a = &batch.actions;
let next_o = &batch.next_obs;
let r = &batch.rewards.to(self.device).squeeze();
let not_done = &batch.not_dones.to(self.device).squeeze();
Expand Down Expand Up @@ -224,7 +224,7 @@ where
P: SubModel<Input = O::SubBatch, Output = (ActMean, ActStd)>,
E::Obs: Into<O::SubBatch>,
E::Act: From<Tensor>,
O: TchBuffer<Item = E::Obs>,
O: TchBuffer<Item = E::Obs, SubBatch = Tensor>,
A: TchBuffer<Item = E::Act, SubBatch = Tensor>,
{
fn train(&mut self) {
Expand Down
Loading