@@ -88,6 +88,7 @@ struct ReduceTreePattern {
8888 const FusionTrackerPtr& tracker)
8989 : childs_(childs), root_(root), tracker_(tracker) {
9090 id_ = UniqueId ();
91+ cur_id_ = id_;
9192 }
9293 const ReducePattern& GetRootPattern () const { return root_; }
9394 std::vector<pir::Operation*> ops () const {
@@ -118,44 +119,48 @@ struct ReduceTreePattern {
118119 std::string id () const { return id_; }
119120 std::string id_;
120121
122+ mutable std::string cur_id_;
123+ std::string cur_id () const { return cur_id_; }
124+ std::string new_tmp_id () const {
125+ if (cur_id_ == id_) {
126+ cur_id_ = id_ + " _tmp_0" ;
127+ } else {
128+ int ith = std::stoi (cur_id_.substr (cur_id_.size () - 1 ));
129+ cur_id_ = id_ + " _tmp_" + std::to_string (ith + 1 );
130+ }
131+ return cur_id_;
132+ }
133+
121134 FusionTrackerPtr tracker_;
122135
123136 void update_tracker () const {
124- int counter = 0 ;
125- std::function<std::string ()> gen_name = [&counter]() {
126- return " tmp_" + std::to_string (counter++);
127- };
128- const std::string& root_name = id ();
137+ const std::string& root_name = GetRootPattern ().id ();
129138 std::vector<std::string> names;
130- UpdateTrackerImpl (root_name,
131- *this ,
132- std::vector<size_t >(),
133- gen_name,
134- this ->tracker_ ,
135- &names);
136- tracker_->append (std::make_shared<CombineInstr>(names, root_name));
139+ UpdateTrackerImpl (
140+ root_name, *this , std::vector<size_t >(), this ->tracker_ , &names);
141+ tracker_->append (std::make_shared<CombineInstr>(names, cur_id ()));
137142 }
138143
139144 void UpdateTrackerImpl (const std::string root_name,
140145 const ReduceTreePattern& root,
141146 const std::vector<size_t >& fake_reduce_iter_idx,
142- const std::function<std::string()>& unique_tmp_name_fn,
143147 FusionTrackerPtr tracker,
144148 std::vector<std::string>* names) const {
145- // Apply a bunch of tracker to get a output_name of ReduceTreePattern.
149+ // Apply a brunch of tracker to get a output_name of ReduceTreePattern.
146150 // names and trackers collect all the needed fusion nodes.
147- for (const auto & child : childs_) {
148- const std::string& tmp_name = unique_tmp_name_fn ();
149- tracker->append (std::make_shared<TmpTransformInstr>(
150- tmp_name, root_name, tmp_name, root_name, fake_reduce_iter_idx));
151- UpdateTrackerImpl (tmp_name,
152- child,
153- fake_reduce_iter_idx,
154- unique_tmp_name_fn,
155- tracker,
156- names);
151+ for (const auto & child : root.childs ()) {
152+ auto origin_child_id = child.cur_id ();
153+ auto new_child_id = child.new_tmp_id ();
154+ tracker->append (
155+ std::make_shared<TmpTransformInstr>(origin_child_id,
156+ root_name,
157+ new_child_id,
158+ root.cur_id (),
159+ fake_reduce_iter_idx));
160+ UpdateTrackerImpl (
161+ new_child_id, child, fake_reduce_iter_idx, tracker, names);
157162 }
158- names->push_back (root_name );
163+ names->push_back (root. cur_id () );
159164 }
160165
161166 private:
@@ -190,29 +195,21 @@ struct ReduceTreePlusTrivialPattern {
190195 FusionTrackerPtr tracker_;
191196
192197 void update_tracker () const {
193- int counter = 0 ;
194- std::function<std::string ()> gen_name = [&counter]() {
195- return " tmp_" + std::to_string (counter++);
196- };
197198 const std::string& root_name = id ();
198- const std::string& tmp_name_for_tree = gen_name ();
199+ const std::string& origin_tree_id = tree.cur_id ();
200+ const std::string& new_tree_id = tree.new_tmp_id ();
199201 std::vector<std::string> names;
200- tracker_->append (
201- std::make_shared<TmpTransformInstr>(tree.GetRootPattern ().id (),
202- sink_trivial.id (),
203- tmp_name_for_tree,
204- root_name,
205- fake_reduce_iter_idx));
206- tree.UpdateTrackerImpl (tmp_name_for_tree,
207- tree,
208- fake_reduce_iter_idx,
209- gen_name,
210- this ->tracker_ ,
211- &names);
202+ tracker_->append (std::make_shared<TmpTransformInstr>(origin_tree_id,
203+ sink_trivial.id (),
204+ new_tree_id,
205+ root_name,
206+ fake_reduce_iter_idx));
207+ tree.UpdateTrackerImpl (
208+ new_tree_id, tree, fake_reduce_iter_idx, this ->tracker_ , &names);
212209 names.push_back (root_name);
213210 // optimize the loop range of R + T for speed up.
214211 tracker_->append (std::make_shared<TrivialLoopAlignInstr>(
215- tmp_name_for_tree , root_name, root_name, fake_reduce_iter_idx));
212+ new_tree_id , root_name, root_name, fake_reduce_iter_idx));
216213 // collect all the Expr and represent the root_name.
217214 tracker_->append (std::make_shared<CombineInstr>(names, root_name));
218215 }
0 commit comments