 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/new_executor/interpretercore.h"
+#include "paddle/fluid/framework/executor_gc_helper.h"
+#include "paddle/fluid/framework/new_executor/interpretercore_gc_helper.h"
+
+#if defined(PADDLE_WITH_CUDA)
+using ::paddle::platform::kCUDA;
+USE_EVENT(kCUDA);
+#endif
 
 #include <unordered_set>
 
@@ -145,6 +152,12 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       d2h_ctx_pool_({place}),
       h2d_ctx_pool_({place}) {
   is_build_ = false;
+
+  garbages_.reset(new GarbageQueue());
+  max_memory_size_ = static_cast<size_t>(GetEagerDeletionThreshold());
+  cur_memory_size_ = 0;
+  gc_queue_ = CreateSingleThreadedWorkQueue();
+
   feed_names_ = feed_names;
 
   // Step1: add feedop and fetchop to main_program
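The constructor now owns the eager-deletion state: a queue of dead allocations (garbages_), a byte threshold from GetEagerDeletionThreshold() (which, as in the existing executors, appears to be derived from FLAGS_eager_delete_tensor_gb), a running byte counter, and a single-threaded work queue that performs the actual frees off the main thread. A minimal standalone sketch of that accumulate-then-flush policy follows; the Allocation type and the GarbageBuffer class are simplified stand-ins for this illustration, not the Paddle types used by the patch.

#include <cstddef>
#include <deque>
#include <memory>

// Simplified stand-in for an allocation holding `bytes` bytes of device/host memory.
struct Allocation {
  explicit Allocation(size_t n) : bytes(n) {}
  size_t size() const { return bytes; }
  size_t bytes;
};

class GarbageBuffer {
 public:
  explicit GarbageBuffer(size_t max_bytes) : max_bytes_(max_bytes) {}

  // Take ownership of a dead allocation and flush once enough bytes pile up.
  void Collect(std::shared_ptr<Allocation> dead) {
    if (dead) cur_bytes_ += dead->size();
    garbages_.push_back(std::move(dead));
    // A threshold <= 1 means "free after every instruction"; otherwise batch
    // deletions until the accumulated size reaches the threshold.
    if (max_bytes_ <= 1 || cur_bytes_ >= max_bytes_) {
      garbages_.clear();  // dropping the shared_ptrs releases the memory
      cur_bytes_ = 0;
    }
  }

 private:
  std::deque<std::shared_ptr<Allocation>> garbages_;
  size_t cur_bytes_ = 0;
  size_t max_bytes_;
};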
@@ -215,11 +228,24 @@ void InterpreterCore::Convert() {
     temp_inst.input_index_ = vec_func_list_[i].input_index;
     temp_inst.output_index_ = vec_func_list_[i].output_index;
 
+    OpInOutInfo info;
+
     std::vector<size_t> gc_check_input_list;
     for (auto& item : vec_func_list_[i].input_index) {
       for (auto id : item.second) {
         input_var2op_info_[id].push_back(i);
-        gc_check_input_list.push_back(id);
+        // var can be gc-ed
+        if (!info.IsBuilt()) {
+          info.Build(op_list_[i]);
+        }
+        if (global_scope_->vec_meta_info_[id].vardesc_) {
+          if (info.IsInArgBufferNeeded(
+                  global_scope_->vec_meta_info_[id].vardesc_->Name())) {
+            gc_check_input_list.push_back(id);
+          }
+        } else {
+          gc_check_input_list.push_back(id);
+        }
       }
     }
     std::sort(gc_check_input_list.begin(), gc_check_input_list.end());
@@ -236,6 +262,13 @@ void InterpreterCore::Convert() {
   }
 
   for (size_t i = 0; i < vec_instruction_.size(); ++i) {
+#if defined(PADDLE_WITH_CUDA)
+    int device_type = static_cast<int>(paddle::platform::DeviceType::CUDA);
+    paddle::platform::DeviceOption dev_opt(
+        device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device);
+    gc_event_.emplace_back(dev_opt);
+#endif
+
     std::vector<size_t> vec_temp;
     for (auto& item : vec_instruction_[i].output_index_) {
       for (auto id : item.second) {
@@ -365,11 +398,8 @@ void InterpreterCore::ExecuteInstructionList(
     }
 
     // GC information
-
-    auto& gc_check_list = instr_node.gc_check_var_list;
-    for (auto var_id : gc_check_list) {
-      --working_var_ref[var_id].var_ref_count_;
-    }
+    CheckGC(instr_id, instr_node.gc_check_var_list, var_scope, place,
+            working_var_ref);
   }
 
   for (size_t i = 0; i < working_var_ref.size(); ++i) {
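The CheckGC call above replaces the inline decrement: liveness in the new executor is plain static reference counting, where each variable's count starts at the number of instructions that will read it and the variable becomes collectible once the count reaches zero after an instruction finishes. A small illustrative sketch of that bookkeeping (the names here are made up for the example, not InterpreterCore members):

#include <cstddef>
#include <vector>

struct VarRef {
  int ref_count = 0;         // remaining instructions that still read this variable
  bool persistable = false;  // parameters and other persistable vars are never freed
};

// Decrement the counts for one finished instruction and return the variables
// whose last reader it was; those are now safe to hand to the garbage queue.
std::vector<size_t> DecrementAndCollect(const std::vector<size_t>& gc_check_list,
                                        std::vector<VarRef>* refs) {
  std::vector<size_t> dead;
  for (size_t var_id : gc_check_list) {
    VarRef& r = (*refs)[var_id];
    if (--r.ref_count == 0 && !r.persistable) {
      dead.push_back(var_id);
    }
  }
  return dead;
}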
@@ -379,6 +409,87 @@ void InterpreterCore::ExecuteInstructionList(
   }
 }
 
+void InterpreterCore::CheckGC(size_t instr_id,
+                              const std::vector<size_t>& gc_check_list,
+                              const VariableScope& var_scope,
+                              const platform::Place& place,
+                              std::vector<VariableMetaInfo>& working_var_ref) {
+  for (auto var_id : gc_check_list) {
+    --working_var_ref[var_id].var_ref_count_;
+    if (var_scope.vec_meta_info_[var_id].vardesc_ &&
+        !var_scope.vec_meta_info_[var_id].vardesc_->Persistable() &&
+        working_var_ref[var_id].var_ref_count_ == 0) {
+      Variable* var = var_scope.var_list[var_id];
+      if (var->IsType<LoDTensor>()) {
+        garbages_->emplace_back(
+            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
+        if (garbages_->back()) {
+          cur_memory_size_ += garbages_->back()->size();
+        }
+      } else if (var->IsType<SelectedRows>()) {
+        garbages_->emplace_back(var->GetMutable<SelectedRows>()
+                                    ->mutable_value()
+                                    ->MoveMemoryHolder());
+        if (garbages_->back()) {
+          cur_memory_size_ += garbages_->back()->size();
+        }
+      } else if (var->IsType<LoDTensorArray>()) {
+        auto* tensor_arr = var->GetMutable<LoDTensorArray>();
+        for (auto& t : *tensor_arr) {
+          garbages_->emplace_back(t.MoveMemoryHolder());
+          if (garbages_->back()) {
+            cur_memory_size_ += garbages_->back()->size();
+          }
+        }
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "The variable(%s) is not supported in eager deletion.",
+            framework::ToTypeName(var->Type())));
+      }
+    }
+  }
+
+  if (!garbages_->empty()) {
+    if (max_memory_size_ <= 1) {
+#if defined(PADDLE_WITH_CUDA)
+      auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
+          platform::DeviceContextPool::Instance().Get(place));
+      gc_event_[instr_id].Record(place, dev_ctx);
+      gc_queue_->AddTask(
+          [container = garbages_.release(), event = &gc_event_[instr_id]]() {
+            while (!event->Query()) {
+              continue;
+            }
+            delete container;
+          });
+      garbages_.reset(new GarbageQueue());
+#else
+      delete garbages_.release();
+      garbages_.reset(new GarbageQueue());
+#endif
+    } else if (cur_memory_size_ >= max_memory_size_) {
+#if defined(PADDLE_WITH_CUDA)
+      auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
+          platform::DeviceContextPool::Instance().Get(place));
+      gc_event_[instr_id].Record(place, dev_ctx);
+      gc_queue_->AddTask(
+          [container = garbages_.release(), event = &gc_event_[instr_id]]() {
+            while (!event->Query()) {
+              continue;
+            }
+            delete container;
+          });
+      garbages_.reset(new GarbageQueue());
+      cur_memory_size_ = 0;
+#else
+      delete garbages_.release();
+      garbages_.reset(new GarbageQueue());
+      cur_memory_size_ = 0;
+#endif
+    }
+  }
+}
+
 std::vector<size_t> InterpreterCore::MergeVector(
     const std::vector<size_t>& first, const std::vector<size_t>& second) {
   std::vector<size_t> out(first.size() + second.size());
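On CUDA builds, CheckGC cannot release the buffers on the spot: the kernels just launched for this instruction may still be reading them asynchronously. Instead it records a device event for the instruction, hands the accumulated batch to the single-threaded gc_queue_, and the worker spins on event->Query() before dropping the last references. Below is a standalone sketch of that hand-off pattern, using a std::atomic<bool> as a stand-in for the device event and a plain std::thread as a stand-in for the work queue; it is an illustration of the technique, not the Paddle API.

#include <atomic>
#include <deque>
#include <memory>
#include <thread>
#include <utility>

// Stand-in for a batch of allocation handles; destroying it frees the memory.
using Garbage = std::deque<std::shared_ptr<int>>;

// Hand the batch to a worker that waits for `done` (the stand-in for the CUDA
// event) before dropping the references, mirroring the gc_queue_->AddTask body.
std::thread DeferredFree(std::unique_ptr<Garbage> batch,
                         std::shared_ptr<std::atomic<bool>> done) {
  return std::thread([batch = std::move(batch), done]() mutable {
    while (!done->load()) {
      // spin, like `while (!event->Query()) continue;` in CheckGC
    }
    batch.reset();  // last references dropped here; the buffers are freed
  });
}

int main() {
  auto done = std::make_shared<std::atomic<bool>>(false);
  auto batch = std::make_unique<Garbage>();
  batch->push_back(std::make_shared<int>(42));
  std::thread worker = DeferredFree(std::move(batch), done);
  done->store(true);  // in the interpreter this is the recorded event completing
  worker.join();
}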
@@ -407,6 +518,11 @@ void InterpreterCore::BuildVariableScope(const framework::ProgramDesc& pdesc,
       auto v = new Variable();
       InitializeVariable(v, var->GetType());
       var_scope->var_list.push_back(v);
+
+      VariableMetaInfo info;
+      info.var_ref_count_ = 0;
+      info.vardesc_ = var;
+      var_scope->vec_meta_info_.push_back(info);
     }
   }
 }
@@ -419,6 +535,7 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
   auto& global_block = pdesc.Block(0);
   auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
 
+  std::vector<OperatorBase*> ops;
   for (auto& op : global_block.AllOps()) {
     VLOG(3) << "Build OpFuncNode from : " << op->Type();
 
@@ -434,6 +551,20 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
     // step 1. Prepare VariableValueMap of input/output
     auto op_base =
         info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
+    ops.push_back(op_base);
+  }
+
+  auto unused_var_map = get_unused_vars(global_block, ops);
+
+  size_t ops_index = 0;
+  for (auto& op : global_block.AllOps()) {
+    VLOG(3) << op->Type();
+    // << op->Type() << endl;
+
+    auto op_base = ops[ops_index++];
+
+    auto inputs_names = op->Inputs();
+    auto outputs_names = op->Outputs();
 
     VariableValueMap ins_map;
     std::map<std::string, std::vector<int>> ins_name2id;
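BuildOpFuncList is now split into two passes over the block: the first pass only materializes the OperatorBase objects, then get_unused_vars (presumably supplied by the newly included GC helper headers) maps each op to the variables whose last use it is, and the second pass runs the ops and can erase those variables right after the op that last touches them. A rough sketch of such a last-use analysis is shown below; FakeOp and LastUseMap are illustrative names, not the actual helper's signature.

#include <cstddef>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// For every op index, list the variables whose last read or write happens in
// that op; after the op runs, those variables can be deleted eagerly.
std::map<size_t, std::vector<std::string>> LastUseMap(const std::vector<FakeOp>& ops) {
  std::unordered_map<std::string, size_t> last_use;
  for (size_t i = 0; i < ops.size(); ++i) {
    for (const auto& name : ops[i].inputs) last_use[name] = i;
    for (const auto& name : ops[i].outputs) last_use[name] = i;
  }
  std::map<size_t, std::vector<std::string>> result;
  for (const auto& kv : last_use) {
    result[kv.second].push_back(kv.first);
  }
  return result;
}

The real helper additionally has to skip persistable variables (parameters must outlive every step), which this sketch ignores for brevity.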
@@ -542,6 +673,11 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
             var_scope->name2id[new_var_name] = var_scope->var_list.size();
             var_scope->var_list.push_back(v);
 
+            VariableMetaInfo info;
+            info.var_ref_count_ = 0;
+            info.vardesc_ = nullptr;
+            var_scope->vec_meta_info_.push_back(info);
+
             VariableNameMap copy_in_map;
             auto x_iter = inputs_names.find(var_name_item.first);
             copy_in_map["X"] = {x_iter->second[i]};
@@ -647,6 +783,47 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
     op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
     op_func_node.kernel_func_(exec_ctx);
     vec_func_list->push_back(op_func_node);
+
+    // gc---------------------------------------------------------------------------
+    auto iter = unused_var_map.find(op_base);
+    if (iter == unused_var_map.end()) {
+      continue;
+    }
+
+    auto& delete_vars = iter->second;
+    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
+        new std::deque<std::shared_ptr<memory::Allocation>>();
+
+    for (auto& var_name : delete_vars) {
+      auto it = var_scope->name2id.find(var_name);
+      assert(it != var_scope->name2id.end());
+      auto* var = var_scope->var_list[it->second];
+      if (var == nullptr) {
+        continue;
+      }
+
+      VLOG(2) << "Erase variable " << var_name;
+      if (var->IsType<LoDTensor>()) {
+        garbages->emplace_back(
+            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
+      } else if (var->IsType<SelectedRows>()) {
+        garbages->emplace_back(var->GetMutable<SelectedRows>()
+                                   ->mutable_value()
+                                   ->MoveMemoryHolder());
+      } else if (var->IsType<LoDTensorArray>()) {
+        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
+        for (auto& t : *lod_tensor_arr) {
+          garbages->emplace_back(t.MoveMemoryHolder());
+        }
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "Type %s of variable %s is not supported in eager deletion.",
+            framework::ToTypeName(var->Type()), var_name));
+      }
+    }
+
+    delete garbages;  // dropping the references frees the memory right away
+
     VLOG(3) << "run " << op_base->Type() << " done.";
   }
 }