
Determine the lifetime of a variable's gradient #11416

@tonyyang-svail

Description


A variable should be able to access its gradient. Ideally, that access should go through a smart pointer.

Question: should a variable hold

  1. a shared_ptr to its gradient?
  2. a weak_ptr to its gradient?

Case 0: forward

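while (true) {
  reset_global_tape();         // discard the ops recorded in the previous iteration
  loss = model.Forward(data);  // record this iteration's forward ops on the tape
}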

Case 1: forward, backward

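while (true) {
  reset_global_tape();         // discard the ops recorded in the previous iteration
  loss = model.Forward(data);  // record this iteration's forward ops on the tape
  loss.Backward();             // compute gradients by walking the tape in reverse
}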

Case 2: forward, backward, optimize

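while (true) {
  reset_global_tape();         // discard the ops recorded in the previous iteration
  loss = model.Forward(data);  // record this iteration's forward ops on the tape
  loss.Backward();             // compute gradients by walking the tape in reverse
  sgd(model.Params());         // update each parameter using its gradient
}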

Case 3: Release Memory on Backward

We can release memory aggressively: during backward, we can delete each OpHandle as soon as its backward computation has finished. Since all variables are managed by smart pointers, a variable's memory is released automatically when its reference count drops to 0.

In the current implementation, shared_ptr can't handle case 1, while weak_ptr can't handle case 3.
