Closed
Description
If I instantiate an `LGBMClassifier` object and fit it to multiclass (non-binary) data, refitting the same object to binary data raises an error:
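```python
import numpy as np
from lightgbm import LGBMClassifier

rng = np.random.default_rng(seed=123)
nrows = 1000
ncols = 20
X = rng.standard_normal(size=(nrows, ncols))
# Binary target: roughly 30% positives, stored as 0.0/1.0 floats
y_bin = (rng.random(size=nrows) <= 0.3).astype(np.float64)
# Multiclass target with 4 classes (0, 1, 2, 3)
y_multi = rng.integers(4, size=nrows)

model = LGBMClassifier()
model.fit(X, y_multi)  # first fit, on multiclass data, succeeds
model.fit(X, y_bin)    # refitting the same object on binary data raises LightGBMError
```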
```text
---------------------------------------------------------------------------
LightGBMError Traceback (most recent call last)
Untitled-1 in <module>
----> 14 model.fit(X, y_bin)
~/anaconda3/lib/python3.9/site-packages/lightgbm/sklearn.py in fit(self, X, y, sample_weight, init_score, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
965 valid_sets[i] = (valid_x, self._le.transform(valid_y))
966
--> 967 super().fit(X, _y, sample_weight=sample_weight, init_score=init_score, eval_set=valid_sets,
968 eval_names=eval_names, eval_sample_weight=eval_sample_weight,
969 eval_class_weight=eval_class_weight, eval_init_score=eval_init_score,
~/anaconda3/lib/python3.9/site-packages/lightgbm/sklearn.py in fit(self, X, y, sample_weight, init_score, group, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_group, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
746 callbacks.append(record_evaluation(evals_result))
747
--> 748 self._Booster = train(
749 params=params,
750 train_set=train_set,
~/anaconda3/lib/python3.9/site-packages/lightgbm/engine.py in train(params, train_set, num_boost_round, valid_sets, valid_names, fobj, feval, init_model, feature_name, categorical_feature, early_stopping_rounds, evals_result, verbose_eval, learning_rates, keep_training_booster, callbacks)
269 # construct booster
270 try:
--> 271 booster = Booster(params=params, train_set=train_set)
272 if is_valid_contain_train:
273 booster.set_train_data_name(train_data_name)
...
--> 125 raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8'))
126
127
LightGBMError: Number of classes should be specified and greater than 1 for multiclass training
```
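The error message suggests that the first `fit` call leaves a multiclass objective stored on the estimator, and the second `fit` then reuses it even though the new target has only two classes. This is an inference from the message, not something confirmed against the LightGBM source. The sketch below is a hypothetical, pure-Python model of that suspected pattern: the `ToyClassifier` class and its `reset_on_fit` flag are invented for illustration and are not LightGBM's code or API. It shows how stale derived state breaks a refit, and how re-deriving that state from the constructor parameter on every `fit` avoids it.

```python
# Hypothetical toy estimator (NOT LightGBM's implementation) that mimics the
# suspected bug: derived state set by one fit() leaks into the next fit().


class ToyClassifier:
    def __init__(self, objective=None, reset_on_fit=False):
        self.objective = objective        # user-supplied constructor parameter
        self._objective = objective       # derived state that fit() mutates
        self.reset_on_fit = reset_on_fit  # hypothetical switch for the fixed behavior

    def fit(self, y):
        if self.reset_on_fit:
            # Fix: re-derive internal state from the constructor parameter each time.
            self._objective = self.objective
        n_classes = len(set(y))
        if n_classes > 2:
            # Switch to a multiclass objective, as the first fit on y_multi would.
            self._objective = "multiclass"
        elif self._objective is None:
            self._objective = "binary"
        # Only multiclass training gets an explicit class count; otherwise it stays 1.
        num_class = n_classes if n_classes > 2 else 1
        if self._objective == "multiclass" and num_class <= 1:
            raise ValueError(
                "Number of classes should be specified and greater than 1 "
                "for multiclass training"
            )
        self.n_classes_ = n_classes
        return self


y_multi = [0, 1, 2, 3, 1, 2]
y_bin = [0, 1, 1, 0, 1, 0]

buggy = ToyClassifier().fit(y_multi)
try:
    buggy.fit(y_bin)  # stale "multiclass" objective with only 2 classes
except ValueError as err:
    print("refit failed:", err)

fixed = ToyClassifier(reset_on_fit=True).fit(y_multi).fit(y_bin)
print(fixed._objective, fixed.n_classes_)  # binary 2
```

With the real library, a practical workaround should be to fit a fresh copy of the estimator instead of refitting the same object, e.g. `sklearn.base.clone(model).fit(X, y_bin)`. `clone` builds a new, unfitted estimator from the same constructor parameters, so no state from the multiclass fit carries over.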