Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Commit c3937ac

Browse files
committed
add criterion splitter
1 parent dca892d commit c3937ac

File tree

2 files changed

+253
-209
lines changed

2 files changed

+253
-209
lines changed

src/learning/tree/criterion.rs

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
use std::collections::BTreeMap;
2+
3+
use linalg::Vector;
4+
5+
/// Computes `x * ln(y)`, returning `0` whenever `x` is exactly zero so that
/// the logarithm of `y` is never evaluated in that case.
fn xlogy(x: f64, y: f64) -> f64 {
    if x != 0. {
        x * y.ln()
    } else {
        0.
    }
}
12+
13+
/// Count target label frequencies
14+
fn freq(labels: &Vector<usize>) -> (Vector<usize>, Vector<usize>) {
15+
let mut map: BTreeMap<usize, usize> = BTreeMap::new();
16+
for l in labels {
17+
let e = map.entry(*l).or_insert(0);
18+
*e += 1;
19+
}
20+
21+
let mut uniques: Vec<usize> = Vec::with_capacity(map.len());
22+
let mut counts: Vec<usize> = Vec::with_capacity(map.len());
23+
for (&k, &v) in map.iter() {
24+
uniques.push(k);
25+
counts.push(v);
26+
}
27+
(Vector::new(uniques), Vector::new(counts))
28+
}
29+
30+
pub fn label_counts(labels: &Vector<usize>, n_classes: usize) -> Vector<f64> {
31+
// ToDo: make this private
32+
debug_assert!(n_classes >= 1);
33+
debug_assert!(*labels.iter().max().unwrap() <= n_classes - 1);
34+
35+
let mut counts: Vec<f64> = vec![0.0f64; n_classes];
36+
37+
unsafe {
38+
for &label in labels.iter() {
39+
*counts.get_unchecked_mut(label) += 1.;
40+
}
41+
}
42+
Vector::new(counts)
43+
}
44+
45+
/// Split criteria
// ToDo: remove clone
#[derive(Debug, Clone)]
pub enum Metrics {
    /// Gini impurity: one minus the sum of squared class probabilities
    Gini,
    /// Entropy (basis of information gain): negated sum of `p * ln(p)`
    Entropy,
}
55+
56+
impl Metrics {
57+
58+
/// calculate metrics from target labels
59+
pub fn from_labels(&self, labels: &Vector<usize>, n_classes: usize) -> f64 {
60+
let counts = label_counts(labels, n_classes);
61+
let sum: f64 = labels.size() as f64;
62+
let probas: Vector<f64> = counts / sum;
63+
self.from_probas(&probas.data())
64+
}
65+
66+
/// calculate metrics from label probabilities
67+
pub fn from_probas(&self, probas: &[f64]) -> f64 {
68+
match self {
69+
&Metrics::Entropy => {
70+
let res: f64 = probas.iter().map(|&x| xlogy(x, x)).sum();
71+
- res
72+
},
73+
&Metrics::Gini => {
74+
let res: f64 = probas.iter().map(|&x| x * x).sum();
75+
1.0 - res
76+
}
77+
}
78+
}
79+
}
80+
81+
/// Evaluates candidate split points along a single feature.
pub struct Splitter {
    /// Per-class sample counts over all samples, indexed by label
    total_counts: Vec<f64>,
    /// `(feature value, label)` pairs, kept sorted by feature value
    sorter: Vec<(f64, usize)>,
}
85+
86+
impl Splitter {
87+
pub fn new(features: &Vec<f64>, target: &Vector<usize>,
88+
total_counts: &Vec<f64>) -> Self {
89+
90+
debug_assert!(features.len() == target.size());
91+
debug_assert!(features.len() > 0);
92+
93+
let mut sorter: Vec<(f64, usize)> = Vec::with_capacity(features.len());
94+
for (&f, &t) in features.iter().zip(target.iter()) {
95+
sorter.push((f, t));
96+
}
97+
sorter.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap());
98+
99+
Splitter {
100+
total_counts: total_counts.clone(),
101+
sorter: sorter
102+
}
103+
}
104+
105+
pub fn get_max_splits(&self, metric: &Metrics) -> Vec<(f64, f64)> {
106+
let (mut prev_val, prev_label) = unsafe { *self.sorter.get_unchecked(0) };
107+
let mut left_counts = vec![0.0f64; self.total_counts.len()];
108+
unsafe {
109+
*left_counts.get_unchecked_mut(prev_label) += 1.;
110+
}
111+
112+
// ToDo: compare perf whether to store total as f64
113+
let mut left_total: f64 = 1.0f64;
114+
let mut right_counts: Vec<f64> = self.total_counts.iter()
115+
.zip(left_counts.iter())
116+
.map(|(&t, &c)| t - c)
117+
.collect();
118+
let mut right_total: f64 = (self.sorter.len() - 1) as f64;
119+
120+
// stores tuple of split value and criterion
121+
let mut res: Vec<(f64, f64)> = Vec::with_capacity(self.sorter.len());
122+
123+
for &(current_val, current_label) in self.sorter.iter().skip(1) {
124+
if prev_val != current_val {
125+
let split = (prev_val + current_val) / 2.0f64;
126+
let lp: Vec<f64> = left_counts.iter().map(|&x| x / left_total).collect();
127+
let rp: Vec<f64> = right_counts.iter().map(|&x| x / right_total).collect();
128+
let lc = metric.from_probas(&lp) * left_total;
129+
let rc = metric.from_probas(&rp) * right_total;
130+
res.push((split, lc + rc));
131+
}
132+
133+
unsafe {
134+
*left_counts.get_unchecked_mut(current_label) += 1.0f64;
135+
*right_counts.get_unchecked_mut(current_label) -= 1.0f64;
136+
}
137+
left_total += 1.0f64;
138+
right_total -= 1.0f64;
139+
140+
prev_val = current_val;
141+
}
142+
res
143+
}
144+
}
145+
146+
#[cfg(test)]
mod tests {

    use linalg::Vector;

    use super::{xlogy, freq, Metrics, Splitter};

    /// `xlogy` multiplies by the natural log and returns zero for a zero factor.
    #[test]
    fn test_xlogy() {
        let nonzero = xlogy(3., 8.);
        assert_eq!(nonzero, 6.2383246250395068);

        let zero_factor = xlogy(0., 100.);
        assert_eq!(zero_factor, 0.);
    }

    /// `freq` returns the sorted distinct labels with their occurrence counts.
    #[test]
    fn test_freq() {
        let mixed = Vector::new(vec![1, 2, 3, 1, 2, 4]);
        let (uniques, counts) = freq(&mixed);
        assert_eq!(uniques, Vector::new(vec![1, 2, 3, 4]));
        assert_eq!(counts, Vector::new(vec![2, 2, 1, 1]));

        let skewed = Vector::new(vec![1, 2, 2, 2, 2]);
        let (uniques, counts) = freq(&skewed);
        assert_eq!(uniques, Vector::new(vec![1, 2]));
        assert_eq!(counts, Vector::new(vec![1, 4]));
    }

    /// Entropy computed directly from class probabilities.
    #[test]
    fn test_entropy() {
        let entropy = Metrics::Entropy;
        assert_eq!(entropy.from_probas(&[1.]), 0.);
        assert_eq!(entropy.from_probas(&[1., 0., 0.]), 0.);
        assert_eq!(entropy.from_probas(&[0.5, 0.5]), 0.69314718055994529);
        assert_eq!(entropy.from_probas(&[1. / 3., 1. / 3., 1. / 3.]), 1.0986122886681096);
        assert_eq!(entropy.from_probas(&[0.4, 0.3, 0.3]), 1.0888999753452238);
    }

    /// Gini impurity computed directly from class probabilities.
    #[test]
    fn test_gini_from_probas() {
        let gini = Metrics::Gini;
        assert_eq!(gini.from_probas(&[1., 0., 0.]), 0.);
        assert_eq!(gini.from_probas(&[1. / 3., 1. / 3., 1. / 3.]), 0.6666666666666667);
        assert_eq!(gini.from_probas(&[0., 1. / 46., 45. / 46.]), 0.04253308128544431);
        assert_eq!(gini.from_probas(&[0., 49. / 54., 5. / 54.]), 0.16803840877914955);
    }

    /// Entropy computed from raw labels.
    #[test]
    fn test_entropy_from_labels() {
        let entropy = Metrics::Entropy;

        let three_way = Vector::new(vec![0, 1, 2]);
        assert_eq!(entropy.from_labels(&three_way, 3), 1.0986122886681096);

        let balanced = Vector::new(vec![0, 0, 1, 1]);
        assert_eq!(entropy.from_labels(&balanced, 2), 0.69314718055994529);
    }

    /// Gini impurity computed from raw labels.
    #[test]
    fn test_gini_from_labels() {
        let gini = Metrics::Gini;

        let all_ones = Vector::new(vec![1, 1, 1]);
        assert_eq!(gini.from_labels(&all_ones, 2), 0.);

        let all_zeros = Vector::new(vec![0, 0, 0]);
        assert_eq!(gini.from_labels(&all_zeros, 2), 0.);

        let uniform = Vector::new(vec![0, 0, 1, 1, 2, 2]);
        assert_eq!(gini.from_labels(&uniform, 3), 0.6666666666666667);
    }

    /// The splitter reports one weighted score per gap between distinct values.
    #[test]
    fn test_splitter() {
        let features: Vec<f64> = vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0];
        let labels: Vector<usize> = Vector::new(vec![0, 1, 1, 1, 0, 0]);

        let s = Splitter::new(&features, &labels, &vec![3., 3.]);
        let res = s.get_max_splits(&Metrics::Gini);
        assert_eq!(res.len(), 3);

        // Two-class Gini impurity of one side of a split.
        let gini = |side: Vec<usize>| Metrics::Gini.from_labels(&Vector::new(side), 2);

        let exp = gini(vec![0, 1]) * 2. + gini(vec![0, 0, 1, 1]) * 4.;
        assert_eq!(res[0], (1.5, exp));

        let exp = gini(vec![0, 1, 1, 1]) * 4. + gini(vec![0, 0]) * 2.;
        assert_eq!(res[1], (2.5, exp));

        let exp = gini(vec![0, 0, 1, 1, 1]) * 5. + gini(vec![0]) * 1.;
        assert_eq!(res[2], (3.5, exp));
    }
}

0 commit comments

Comments
 (0)