@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include " fast_tokenizer/core/base.h"
16+
1617#include < thread>
1718
1819namespace paddlenlp {
@@ -28,16 +29,22 @@ int GetThreadNum() { return fast_tokenizer_thread_num; }
2829void RunMultiThread (std::function<void (size_t , size_t )> func,
2930 size_t batch_size) {
3031 int thread_num = GetThreadNum ();
31- std::vector<std::thread> vectorOfThread;
32- size_t start_index = 0 ;
33- size_t step_index = ceil (batch_size / float (thread_num));
34-
35- for (size_t thread_index = 0 ; thread_index < thread_num; thread_index++) {
36- vectorOfThread.emplace_back (std::thread (func, start_index, step_index));
37- start_index = start_index + step_index;
38- }
39- for (size_t thread_index = 0 ; thread_index < thread_num; thread_index++) {
40- vectorOfThread[thread_index].join ();
32+ if (thread_num == 1 ) {
33+ // Note(zhoushunjie): No need to create threads when
34+ // thread_num equals to 1.
35+ func (0 , batch_size);
36+ } else {
37+ std::vector<std::thread> vectorOfThread;
38+ size_t start_index = 0 ;
39+ size_t step_index = ceil (batch_size / float (thread_num));
40+
41+ for (size_t thread_index = 0 ; thread_index < thread_num; thread_index++) {
42+ vectorOfThread.emplace_back (std::thread (func, start_index, step_index));
43+ start_index = start_index + step_index;
44+ }
45+ for (size_t thread_index = 0 ; thread_index < thread_num; thread_index++) {
46+ vectorOfThread[thread_index].join ();
47+ }
4148 }
4249}
4350
0 commit comments