[FEATURE] Add stratified train_valid_split similar to sklearn.model_selection.train_test_split #933
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master     #933    +/-   ##
=========================================
- Coverage   89.98%   88.38%    -1.6%
=========================================
  Files          67       67
  Lines        6372     6296      -76
=========================================
- Hits         5734     5565     -169
- Misses        638      731      +93
Job PR-933/1 is complete.
Thank you for the contribution, @colinkyle. To move forward, let's:
- fix lint errors
- add a unittest for the new functionality
Thanks for your patience. I don't see any tests for the functions within gluonnlp.data.utils. Should I create a new file under the unittest folder, or is there already a file I should modify to add a test for train_valid_split?
Job PR-933/2 is complete.
@colinkyle let's put the test in tests/test_utils.py for now. Or if you prefer to create a new file for it, we can do that too.
Job PR-933/3 is complete.
Job PR-933/4 is complete.
LGTM. The PR will be merged once the mxnet website intersphinx is fixed.
# Excerpt under review: relabel arbitrary class labels to 0..n_classes-1 and count samples per class.
classes, digitized = np.unique(stratify, return_inverse=True)
n_classes = len(classes)
num_class = np.bincount(digitized)
One problem with using bincount is that len(num_class) != n_classes in some cases, e.g. labels = [0, 1, 2, 4], where 3 is missing.
I lifted that logic directly from sklearn's implementation, and I believe it handles that case: "digitized" contains the labels renumbered to start at zero (e.g. labels = [1, 2, 4, 4, 2, 1, 1, 1, 2, 2, 0, 0] gives digitized = [1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 0, 0]), so len(num_class) does equal n_classes.
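A quick numpy check of that relabeling behaviour (illustrative snippet, not part of the PR): np.unique with return_inverse=True maps arbitrary labels to the contiguous range 0..n_classes-1, so np.bincount on the inverse always has exactly n_classes entries, even when a label value such as 3 is missing from the original set.

import numpy as np

labels = np.array([1, 2, 4, 4, 2, 1, 1, 1, 2, 2, 0, 0])
classes, digitized = np.unique(labels, return_inverse=True)
print(classes)     # [0 1 2 4]
print(digitized)   # [1 2 3 3 2 1 1 1 2 2 0 0]
print(len(np.bincount(digitized)) == len(classes))  # True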
Is this going to move forward? I'm not sure where we are with the review.
@colinkyle It will be merged once it passes CI.
@szhengac thanks for the update.
@leezu the CI error is related to downloading a dataset. It seems we have seen it in another PR before?
I think we should just rebase and merge. |
Job PR-933/16 is complete.
Job PR-933/17 is complete.
Description
I added the ability to perform a stratified split in train_valid_split; a usage sketch is shown below.
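A hypothetical usage sketch (the stratify parameter name, its accepted input, and the return order are assumed from the description, not taken verbatim from the PR): split a labelled dataset so that the class proportions are preserved in both the train and validation parts, analogous to sklearn.model_selection.train_test_split(..., stratify=labels).

import gluonnlp as nlp

dataset = [('great movie', 1), ('terrible plot', 0), ('loved it', 1),
           ('not for me', 0), ('brilliant acting', 1), ('boring', 0)]
labels = [label for _, label in dataset]

# stratify is assumed here to accept per-sample labels; valid_ratio keeps
# the same meaning as in the existing train_valid_split.
train, valid = nlp.data.train_valid_split(dataset, valid_ratio=0.5,
                                          stratify=labels)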
Checklist
Essentials
Changes
Comments
Backwards compatible. The only edge case I can think of is if someone tries to use a float array to stratify their data and ends up with nonsensical results (illustrated below).
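For context, a small numpy illustration (not from the PR) of why continuous float targets make poor strata: almost every value is unique, so each sample lands in its own singleton "class" and the stratified split is effectively meaningless.

import numpy as np

y = np.random.rand(10)
classes, digitized = np.unique(y, return_inverse=True)
print(len(classes))  # ~10: roughly one stratum per sample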
cc @dmlc/gluon-nlp-team