Skip to content

Commit 903c0ad

Browse files
committed
update multiclass tests
1 parent 8abb932 commit 903c0ad

File tree

1 file changed

+30
-17
lines changed

1 file changed

+30
-17
lines changed

tests/python_package_test/test_dask.py

Lines changed: 30 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -493,15 +493,21 @@ def test_classifier_custom_objective(output, task, cluster):
493493
with Client(cluster) as client:
494494
X, y, w, _, dX, dy, dw, _ = _create_data(
495495
objective=task,
496-
output=output
496+
output=output,
497497
)
498498

499+
# + + + + + + + + + +
500+
# + + + + + + + + + +
501+
# + + + + + + + + + +
502+
# + + + + + + + + + +
503+
# + _ _ _ _ _ _ _ _ _
499504
params = {
500505
"n_estimators": 50,
501506
"num_leaves": 31,
502-
"min_data": 1,
503507
"verbose": -1,
504-
"learning_rate": 0.01,
508+
"seed": 708,
509+
"deterministic": True,
510+
"force_col_wise": True
505511
}
506512

507513
if task == 'binary-classification':
@@ -522,25 +528,26 @@ def test_classifier_custom_objective(output, task, cluster):
522528
)
523529
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
524530
dask_classifier_local = dask_classifier.to_local()
525-
p1_proba = dask_classifier.predict_proba(dX).compute()
526-
p1_proba_local = dask_classifier_local.predict_proba(X)
531+
p1_raw = dask_classifier.predict(dX, raw_score=True).compute()
532+
p1_raw_local = dask_classifier_local.predict(X, raw_score=True)
527533

528534
# with a custom objective, prediction result is a raw score instead of predicted class
529-
p1_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5
530-
p1_class = p1_class.astype(np.int64)
531-
p1_class_local = (1.0 / (1.0 + np.exp(-p1_proba_local))) > 0.5
532-
p1_class_local = p1_class_local.astype(np.int64)
535+
p1_proba = 1.0 / (1.0 + np.exp(-p1_raw))
536+
p1_proba_local = 1.0 / (1.0 + np.exp(-p1_raw_local))
533537

534538
local_classifier = lgb.LGBMClassifier(**params)
535539
local_classifier.fit(X, y, sample_weight=w)
536-
p2_proba = local_classifier.predict_proba(X)
537-
p2_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5
538-
p2_class = p2_class.astype(np.int64)
540+
p2_raw = local_classifier.predict(X, raw_score=True)
541+
p2_proba = 1.0 / (1.0 + np.exp(-p2_raw))
539542

540-
if task == 'multiclass-classification':
541-
p1_class = p1_class.argmax(axis=1)
542-
p1_class_local = p1_class_local.argmax(axis=1)
543-
p2_class = p2_class.argmax(axis=1)
543+
if task == 'binary-classification':
544+
p1_class = (p1_proba > 0.5).astype(np.int64)
545+
p1_class_local = (p1_proba_local > 0.5).astype(np.int64)
546+
p2_class = (p2_proba > 0.5).astype(np.int64)
547+
elif task == 'multiclass-classification':
548+
p1_class = p1_proba.argmax(axis=1)
549+
p1_class_local = p1_proba_local.argmax(axis=1)
550+
p2_class = p2_proba.argmax(axis=1)
544551

545552
# function should have been preserved
546553
assert callable(dask_classifier.objective_)
@@ -552,7 +559,13 @@ def test_classifier_custom_objective(output, task, cluster):
552559
assert_eq(p2_class, y)
553560

554561
# probability estimates should be similar
555-
assert_eq(p1_proba, p2_proba, atol=0.03)
562+
assert_eq(p1_proba, p2_proba, atol=0.04)
563+
# try:
564+
# assert_eq(p1_proba, p2_proba, atol=0.04)
565+
# except Exception as err:
566+
# max_diff = np.max(np.abs(p1_proba - p2_proba))
567+
# num_samples = np.sum(np.abs(p1_proba - p2_proba) > 0.04)
568+
# raise RuntimeError(f"max diff: {max_diff} | n: {num_samples}")
556569

557570

558571
def test_group_workers_by_host():

0 commit comments

Comments
 (0)