Skip to content

Commit cf87d13

Browse files
Swap to using median value with outliers removed for deciding on best solution to run (#3515)
* Swap to using modified Z when deciding on best solution to run * Update comments * Code review comments swap to using mean, and removing only positive-z scores * Code review comments and warnings clean-up * Code review comments change names * Code review comments add static_assert * Update copyright date
1 parent 1bc4707 commit cf87d13

File tree

4 files changed

+377
-25
lines changed

4 files changed

+377
-25
lines changed

src/conv/solver_finders.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <miopen/perf_field.hpp>
3333
#include <miopen/conv/problem_description.hpp>
3434
#include <miopen/solution.hpp>
35+
#include <miopen/utility/modified_z.hpp>
3536

3637
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_GEMM)
3738
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_DIRECT)
@@ -230,6 +231,7 @@ static std::vector<Solution> EvaluateInvokers(const Handle& handle,
230231
auto best = std::numeric_limits<float>::max();
231232
auto best_invoker = Invoker{};
232233
auto ret = std::vector<Solution>{};
234+
std::vector<float> samples;
233235

234236
for(const auto& sol : solutions)
235237
{
@@ -261,31 +263,40 @@ static std::vector<Solution> EvaluateInvokers(const Handle& handle,
261263

262264
try
263265
{
264-
// Run invoker max 6 times, with ~5 sec time limit.
266+
// Run invoker max 8 times, with ~5 sec time limit.
265267
using elapsed_t = decltype(handle.GetKernelTime());
266268
constexpr elapsed_t TIME_MS_MAX = 5000.0;
267269
constexpr int N_RUNS_MAX = 8;
268-
constexpr int N_RUNS_DISCARD = 3;
269270
auto elapsed = static_cast<elapsed_t>(0);
270271
auto first_elapsed = static_cast<elapsed_t>(0);
271272
int i = 0;
273+
samples.clear();
272274
while(i < N_RUNS_MAX && elapsed < TIME_MS_MAX)
273275
{
274276
invoker(handle, invoke_ctx);
275-
elapsed += handle.GetKernelTime();
276-
if(i == (N_RUNS_DISCARD - 1))
277-
first_elapsed = elapsed;
277+
278+
// don't include warm-up run in our samples.
279+
if(i > 0)
280+
{
281+
samples.push_back(handle.GetKernelTime());
282+
}
283+
else
284+
{
285+
// Keep first run just in case we go over the limit, and have no samples.
286+
first_elapsed = handle.GetKernelTime();
287+
}
278288
++i;
279289
}
280-
// If the execution time was not too long,
281-
// then the 1st run is not counted (assume it's warm-up):
282-
if(i > N_RUNS_DISCARD)
290+
291+
if(samples.size() > 0)
283292
{
284-
elapsed = (elapsed - first_elapsed) / static_cast<elapsed_t>(i - N_RUNS_DISCARD);
293+
// Remove outliers that are more than 2 positive modified z-scores away, and get
294+
// the mean.
295+
elapsed = miopen::removeHighOutliersAndGetMean(samples, 2.0f);
285296
}
286-
else if(i > 0)
297+
else
287298
{
288-
elapsed /= i;
299+
elapsed = first_elapsed;
289300
}
290301

291302
MIOPEN_THROW_IF(elapsed <= 0, "Invalid elapsed time detected in EvaluateInvokers");

src/include/miopen/generic_search.hpp

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <miopen/timer.hpp>
3939
#include <miopen/mt_queue.hpp>
4040
#include <miopen/generic_search_controls.hpp>
41+
#include <miopen/utility/modified_z.hpp>
4142

4243
#include <algorithm>
4344
#include <vector>
@@ -400,10 +401,11 @@ auto GenericSearch(const Solver s,
400401
}
401402
}
402403

403-
bool is_passed = false; // left false only if all iterations failed.
404-
float best_time = std::numeric_limits<float>::max();
405-
size_t n_failed = 0;
406-
size_t n_best = 0;
404+
bool is_passed = false; // left false only if all iterations failed.
405+
float best_time = std::numeric_limits<float>::max();
406+
float worst_time = std::numeric_limits<float>::max();
407+
size_t n_failed = 0;
408+
size_t n_best = 0;
407409
HeartBeat<PerformanceConfig> heartbeat;
408410
heartbeat.Start();
409411

@@ -429,6 +431,7 @@ auto GenericSearch(const Solver s,
429431
size_t n_current = 0;
430432
size_t last_imprv = 0;
431433
auto threads_remaining = total_threads;
434+
std::vector<float> samples;
432435
while(true)
433436
{
434437
if(n_current >= n_runs_total)
@@ -461,6 +464,7 @@ auto GenericSearch(const Solver s,
461464
}
462465
}
463466

467+
samples.clear();
464468
float elapsed_time = 0.0f;
465469
int ret = 0;
466470
MIOPEN_LOG_I2('#' << n_current << '/' << n_failed << '/' << n_runs_total << ' '
@@ -481,8 +485,18 @@ auto GenericSearch(const Solver s,
481485

482486
invoker = profile_h.PrepareInvoker(*current_solution.invoker_factory,
483487
current_solution.construction_params);
488+
489+
// Warm-up run for first time invoker is used
490+
if(n_current == 0)
491+
{
492+
invoker(profile_h, invoke_ctx);
493+
profile_h.ResetKernelTime();
494+
}
495+
484496
invoker(profile_h, invoke_ctx);
485497
elapsed_time = profile_h.GetKernelTime();
498+
samples.push_back(elapsed_time);
499+
profile_h.ResetKernelTime();
486500
}
487501
catch(const std::exception& e)
488502
{
@@ -503,11 +517,11 @@ auto GenericSearch(const Solver s,
503517
if(ret == 0)
504518
{
505519
// Smooth the jitter of measurements:
506-
// If the 1st probe is NOT too bad (measured time <= 1.10 * best known time),
507-
// then re-run it 9 times more and compute average time,
508-
// and decide using average of all 10 attempts vs. the best.
520+
// If the 1st probe is NOT too bad (measured time <= 1.10 * worst sample of the best
521+
// config), then gather 9 more samples, and remove positive z-score outliers. Use
522+
// the mean value with outliers removed for calculating best config.
509523
constexpr int N_RUNS = 10;
510-
if(elapsed_time / best_time < 1.10f)
524+
if(elapsed_time / worst_time < 1.10f)
511525
{
512526
MIOPEN_LOG_I2("Finding average for: " << elapsed_time << " / " << best_time
513527
<< " = " << (elapsed_time / best_time));
@@ -517,7 +531,8 @@ auto GenericSearch(const Solver s,
517531
for(int i = 1; i < N_RUNS; ++i)
518532
{
519533
invoker(profile_h, invoke_ctx);
520-
elapsed_time += profile_h.GetKernelTime();
534+
samples.push_back(profile_h.GetKernelTime());
535+
profile_h.ResetKernelTime();
521536
}
522537
}
523538
catch(...)
@@ -528,21 +543,28 @@ auto GenericSearch(const Solver s,
528543
if(ret == 0)
529544
{
530545
is_passed = true;
531-
elapsed_time /= N_RUNS;
546+
547+
// Remove outliers that are more than 2 positive modified z-scores away,
548+
// and get the mean.
549+
elapsed_time = miopen::removeHighOutliersAndGetMean(samples, 2.0f);
532550
if(elapsed_time < best_time)
533551
{
534552
MIOPEN_LOG_I('#' << n_current << '/' << n_failed << '/' << n_runs_total
535553
<< ' ' << elapsed_time << " < " << best_time << ' '
536554
<< current_config);
537555
best_config = current_config;
538556
best_time = elapsed_time;
539-
n_best = n_current;
540-
last_imprv = 0;
557+
558+
// Samples get sorted by the removeHighOutliersAndGetMean call, so the last element
559+
// will be the slowest.
560+
worst_time = samples.back();
561+
n_best = n_current;
562+
last_imprv = 0;
541563
}
542564
else
543565
{
544-
MIOPEN_LOG_I2("Average is not better: " << elapsed_time
545-
<< " >= " << best_time);
566+
MIOPEN_LOG_I2("Mean is not better: " << elapsed_time
567+
<< " >= " << best_time);
546568
}
547569
}
548570
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*******************************************************************************
2+
*
3+
* MIT License
4+
*
5+
* Copyright (c) 2025 Advanced Micro Devices, Inc.
6+
*
7+
* Permission is hereby granted, free of charge, to any person obtaining a copy
8+
* of this software and associated documentation files (the "Software"), to deal
9+
* in the Software without restriction, including without limitation the rights
10+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
* copies of the Software, and to permit persons to whom the Software is
12+
* furnished to do so, subject to the following conditions:
13+
*
14+
* The above copyright notice and this permission notice shall be included in all
15+
* copies or substantial portions of the Software.
16+
*
17+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23+
* SOFTWARE.
24+
*
25+
*******************************************************************************/
26+
27+
#pragma once
28+
29+
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <vector>

#include <miopen/errors.hpp>
33+
34+
namespace miopen {
35+
36+
// Returns the arithmetic mean of `data`.
//
// T must be a floating-point type (enforced at compile time).
// An empty `data` is rejected through MIOPEN_THROW_IF.
//
// The sum is accumulated in std::common_type_t<T, double>: `double` for float and
// double inputs, `long double` for long double inputs. This keeps the extra
// precision for float samples without truncating long double samples to double.
template <typename T>
T mean(const std::vector<T>& data)
{
    static_assert(std::is_floating_point_v<T>);
    MIOPEN_THROW_IF(data.size() == 0, "Cannot find Mean of 0 length data");

    using accum_t = std::common_type_t<T, double>;

    // The sum is narrowed back to T before dividing, so the division is done in T.
    const T sumOfValues = static_cast<T>(std::accumulate(data.begin(), data.end(), accum_t{0}));
    return sumOfValues / static_cast<T>(data.size());
}
45+
46+
// Returns the median of `sortedData`.
//
// The caller must pass data that is already sorted; this function does not sort
// and does not verify the ordering. For an odd element count the middle element
// is returned; for an even count the two middle elements are averaged.
//
// T must be a floating-point type (enforced at compile time).
// An empty `sortedData` is rejected through MIOPEN_THROW_IF.
template <typename T>
T medianOfSortedData(const std::vector<T>& sortedData)
{
    static_assert(std::is_floating_point_v<T>);
    MIOPEN_THROW_IF(sortedData.size() == 0, "Cannot find Median of 0 length data");

    const size_t count  = sortedData.size();
    const size_t middle = count / 2;

    if(count % 2 != 0)
    {
        return sortedData[middle];
    }

    // Even count: average the two elements straddling the centre.
    return (sortedData[middle - 1] + sortedData[middle]) / 2.0;
}
59+
60+
// Sorts `data` in ascending order IN PLACE and returns its median.
//
// The sort is a deliberate side effect: the vector is taken by non-const
// reference and is left sorted when the function returns.
//
// T must be a floating-point type (enforced at compile time).
// An empty `data` is rejected inside medianOfSortedData.
template <typename T>
T median(std::vector<T>& data)
{
    static_assert(std::is_floating_point_v<T>);

    auto first = data.begin();
    auto last  = data.end();
    std::sort(first, last);

    return medianOfSortedData(data);
}
69+
70+
// Returns, for every element of `sortedData`, its absolute distance from the
// median of `sortedData`.
//
// Despite the name, the result is the vector of per-element absolute deviations,
// not the single "median absolute deviation" value. The result has the same
// length and element order as the input.
//
// `sortedData` must already be sorted (it is handed to medianOfSortedData).
// T must be a floating-point type (enforced at compile time).
// An empty `sortedData` is rejected inside medianOfSortedData.
template <typename T>
std::vector<T> medianAbsoluteDeviation(const std::vector<T>& sortedData)
{
    static_assert(std::is_floating_point_v<T>);

    const T center = medianOfSortedData(sortedData);

    std::vector<T> deviations(sortedData.size());
    std::transform(sortedData.begin(),
                   sortedData.end(),
                   deviations.begin(),
                   [&center](const T& sample) { return std::abs(sample - center); });

    return deviations;
}
86+
87+
// Computes the modified z-score of every element of `sortedData`:
//
//     score[i] = 0.6745 * (sortedData[i] - median) / MAD
//
// where MAD is the median of the absolute deviations from the median.
// Scores are signed: elements above the median get positive scores, elements
// below it get negative scores. The result has the same length and element
// order as the input.
//
// If MAD is exactly 0 the score is undefined, so every score is reported as 0.
//
// `sortedData` must already be sorted (it is handed to medianOfSortedData).
// T must be a floating-point type (enforced at compile time).
// An empty `sortedData` is rejected inside medianOfSortedData.
template <typename T>
std::vector<T> modifiedZScores(const std::vector<T>& sortedData)
{
    static_assert(std::is_floating_point_v<T>);

    const T center = medianOfSortedData(sortedData);

    // median() sorts its argument in place, so it gets its own local vector.
    std::vector<T> deviations = medianAbsoluteDeviation(sortedData);
    const T mad               = median(deviations);

    std::vector<T> scores(sortedData.size(), 0);

    // With a zero MAD the modified z-score cannot be calculated; leave all scores at 0.
    if(mad == T{0})
    {
        return scores;
    }

    std::transform(sortedData.begin(),
                   sortedData.end(),
                   scores.begin(),
                   [&](const T& sample) { return 0.6745 * (sample - center) / mad; });

    return scores;
}
114+
115+
// Drops abnormally high samples from `data` and returns the mean of the rest.
//
// A sample is kept when its modified z-score is <= `z_threshold`. Only the upper
// side is filtered: samples far BELOW the median have negative scores and are
// always kept.
//
// Side effect: `data` is sorted in ascending order IN PLACE and stays sorted on
// return, so afterwards data.back() is the largest sample, including any that
// were excluded from the mean. No element is removed from `data` itself.
//
// T must be a floating-point type (enforced at compile time).
// An empty `data` is rejected inside medianOfSortedData. If no sample passes the
// filter, the empty selection is rejected inside mean().
template <typename T>
T removeHighOutliersAndGetMean(std::vector<T>& data, T z_threshold)
{
    static_assert(std::is_floating_point_v<T>);
    std::sort(data.begin(), data.end());

    const std::vector<T> modZScores = modifiedZScores(data);

    std::vector<T> filteredData;
    // At most every sample survives; reserve once instead of growing per push_back.
    filteredData.reserve(data.size());

    for(size_t i = 0; i < data.size(); ++i)
    {
        if(modZScores[i] <= z_threshold)
        {
            filteredData.push_back(data[i]);
        }
    }

    return mean(filteredData);
}
134+
} // namespace miopen

0 commit comments

Comments
 (0)