Skip to content

Commit 45eb062

Browse files
Merge pull request #466 from sathvikbhagavan/sb/moe
refactor: use train data if test data is empty for a cluster
2 parents 4f069df + 12a0529 commit 45eb062

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

lib/SurrogatesMOE/src/SurrogatesMOE.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ function _find_best_model(clustered_train_values, clustered_test_values, dim,
241241
xtest_mat = _vector_of_tuples_to_matrix(x_test_vec)
242242
end
243243

244-
X = vcat(xtrain_mat, xtest_mat)
244+
X = !isnothing(xtest_mat) ? vcat(xtrain_mat, xtest_mat) : xtrain_mat
245+
x_test_vec = !isnothing(xtest_mat) ? x_test_vec : x_vec
246+
y_test_vec = !isnothing(xtest_mat) ? y_test_vec : y_vec
245247
lb, ub = _find_upper_lower_bounds(X)
246248

247249
# call on _surrogate_builder with clustered_train_vals, enabled expert types, lb, ub
@@ -385,15 +387,18 @@ end
385387
takes in a vector of tuples or vector of vectors and converts it into a matrix
386388
"""
387389
function _vector_of_tuples_to_matrix(v)
388-
num_rows = length(v)
389-
num_cols = length(first(v))
390-
K = zeros(num_rows, num_cols)
391-
for row in 1:num_rows
392-
for col in 1:num_cols
393-
K[row, col] = v[row][col]
390+
if !isempty(v)
391+
num_rows = length(v)
392+
num_cols = length(first(v))
393+
K = zeros(num_rows, num_cols)
394+
for row in 1:num_rows
395+
for col in 1:num_cols
396+
K[row, col] = v[row][col]
397+
end
394398
end
399+
return K
395400
end
396-
return K
401+
return nothing
397402
end
398403

399404
end #module

0 commit comments

Comments
 (0)