Skip to content

Commit 5f517cb

Browse files
committed
Merge pull request #23 from dmitry/naive_bayes_refactor
Reformatted naive bayes to more idiomatic ruby
2 parents 9d4278d + d45bdc8 commit 5f517cb

File tree

2 files changed

+40
-39
lines changed

2 files changed

+40
-39
lines changed

lib/ai4r/classifiers/naive_bayes.rb

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ module Classifiers
5757

5858
class NaiveBayes < Classifier
5959

60-
parameters_info :m => "Default value is set to 0. It may be set to a value greater than " +
61-
"0 when the size of the dataset is relatively small"
60+
parameters_info :m => 'Default value is set to 0. It may be set to a value greater than ' +
61+
'0 when the size of the dataset is relatively small'
6262

6363
def initialize
6464
@m = 0
@@ -75,7 +75,7 @@ def initialize
7575
# b.eval(["Red", "SUV", "Domestic"])
7676
# => 'No'
7777
def eval(data)
78-
prob = @class_prob.map {|cp| cp}
78+
prob = @class_prob.dup
7979
prob = calculate_class_probabilities_for_entry(data, prob)
8080
index_to_klass(prob.index(prob.max))
8181
end
@@ -90,27 +90,28 @@ def eval(data)
9090
# b.get_probability_map(["Red", "SUV", "Domestic"])
9191
# => {"Yes"=>0.4166666666666667, "No"=>0.5833333333333334}
9292
def get_probability_map(data)
93-
prob = @class_prob.map {|cp| cp}
93+
prob = @class_prob.dup
9494
prob = calculate_class_probabilities_for_entry(data, prob)
9595
prob = normalize_class_probability prob
9696
probability_map = {}
9797
prob.each_with_index { |p, i| probability_map[index_to_klass(i)] = p }
98-
return probability_map
98+
99+
probability_map
99100
end
100101

101102
# counts values of the attribute instances and calculates the probability of the classes
102103
# and the conditional probabilities
103104
# Parameter data has to be an instance of Ai4r::Data::DataSet
104105
def build(data)
105-
raise "Error instance must be passed" unless data.is_a?(Ai4r::Data::DataSet)
106-
raise "Data should not be empty" if data.data_items.length == 0
106+
raise 'Error instance must be passed' unless data.is_a?(Ai4r::Data::DataSet)
107+
raise 'Data should not be empty' if data.data_items.length == 0
107108

108109
initialize_domain_data(data)
109110
initialize_klass_index
110111
initialize_pc
111112
calculate_probabilities
112113

113-
return self
114+
self
114115
end
115116

116117
private
@@ -128,7 +129,7 @@ def initialize_domain_data(data)
128129
# probability of every attribute in condition to a specific class
129130
# this is repeated for every class
130131
def calculate_class_probabilities_for_entry(data, prob)
131-
prob.each_with_index do |prob_entry, prob_index|
132+
0.upto(prob.length - 1) do |prob_index|
132133
data.each_with_index do |att, index|
133134
next if value_index(att, index).nil?
134135
prob[prob_index] *= @pcp[index][value_index(att, index)][prob_index]
@@ -140,13 +141,13 @@ def calculate_class_probabilities_for_entry(data, prob)
140141
def normalize_class_probability(prob)
141142
prob_sum = sum(prob)
142143
prob_sum > 0 ?
143-
prob.map {|prob_entry| prob_entry / prob_sum } :
144+
prob.map { |prob_entry| prob_entry / prob_sum } :
144145
prob
145146
end
146147

147148
# sums an array up; returns a number of type Float
148149
def sum(array)
149-
array.inject(0.0){|b, i| b+i}
150+
array.inject(0.0) { |b, i| b + i }
150151
end
151152

152153
# returns the name of the class when the index is found
@@ -160,7 +161,7 @@ def initialize_klass_index
160161
@klass_index[dl] = index
161162
end
162163

163-
@data_labels.each_with_index do |dl, index|
164+
0.upto(@data_labels.length - 1) do |index|
164165
@values[index] = {}
165166
@domains[index].each_with_index do |d, d_index|
166167
@values[index][d] = d_index
@@ -180,27 +181,27 @@ def value_index(value, dl_index)
180181

181182
# builds an array of the form:
182183
# array[attributes][values][classes]
183-
def build_array(dl, index)
184+
def build_array(index)
184185
domains = Array.new(@domains[index].length)
185-
domains.map do |p1|
186-
pl = Array.new @klasses.length, 0
186+
domains.map do
187+
Array.new @klasses.length, 0
187188
end
188189
end
189190

190191
# initializes the two arrays for storing the count and conditional probabilities of
191192
# the attributes
192193
def initialize_pc
193-
@data_labels.each_with_index do |dl, index|
194-
@pcc << build_array(dl, index)
195-
@pcp << build_array(dl, index)
194+
0.upto(@data_labels.length - 1) do |index|
195+
@pcc << build_array(index)
196+
@pcp << build_array(index)
196197
end
197198
end
198199

199200
# calculates the occurrences of a class and the instances of a certain value of a
200201
# certain attribute and the assigned class.
201202
# In addition to that, it also calculates the conditional probabilities and values
202203
def calculate_probabilities
203-
@klasses.each {|dl| @class_counts[klass_index(dl)] = 0}
204+
@klasses.each { |dl| @class_counts[klass_index(dl)] = 0 }
204205

205206
calculate_class_probabilities
206207
count_instances
@@ -220,7 +221,7 @@ def calculate_class_probabilities
220221
# counts the instances of a certain value of a certain attribute and the assigned class
221222
def count_instances
222223
@data_items.each do |item|
223-
@data_labels.each_with_index do |dl, dl_index|
224+
0.upto(@data_labels.length - 1) do |dl_index|
224225
@pcc[dl_index][value_index(item[dl_index], dl_index)][klass_index(item.klass)] += 1
225226
end
226227
end
@@ -231,7 +232,7 @@ def calculate_conditional_probabilities
231232
@pcc.each_with_index do |attributes, a_index|
232233
attributes.each_with_index do |values, v_index|
233234
values.each_with_index do |klass, k_index|
234-
@pcp[a_index][v_index][k_index] = (klass.to_f + @m * @class_prob[k_index]) / (@class_counts[k_index] + @m).to_f
235+
@pcp[a_index][v_index][k_index] = (klass.to_f + @m * @class_prob[k_index]) / (@class_counts[k_index] + @m)
235236
end
236237
end
237238
end

test/classifiers/naive_bayes_test.rb

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,37 @@
77

88
class NaiveBayesTest < Test::Unit::TestCase
99

10-
@@data_labels = [ "Color","Type","Origin","Stolen?" ]
10+
@@data_labels = %w(Color Type Origin Stolen?)
1111

1212
@@data_items = [
13-
["Red", "Sports", "Domestic", "Yes"],
14-
["Red", "Sports", "Domestic", "No"],
15-
["Red", "Sports", "Domestic", "Yes"],
16-
["Yellow","Sports", "Domestic", "No"],
17-
["Yellow","Sports", "Imported", "Yes"],
18-
["Yellow","SUV", "Imported", "No"],
19-
["Yellow","SUV", "Imported", "Yes"],
20-
["Yellow","Sports", "Domestic", "No"],
21-
["Red", "SUV", "Imported", "No"],
22-
["Red", "Sports", "Imported", "Yes"]
23-
]
13+
%w(Red Sports Domestic Yes),
14+
%w(Red Sports Domestic No),
15+
%w(Red Sports Domestic Yes),
16+
%w(Yellow Sports Domestic No),
17+
%w(Yellow Sports Imported Yes),
18+
%w(Yellow SUV Imported No),
19+
%w(Yellow SUV Imported Yes),
20+
%w(Yellow Sports Domestic No),
21+
%w(Red SUV Imported No),
22+
%w(Red Sports Imported Yes)
23+
]
2424

2525
def setup
2626
@data_set = DataSet.new
2727
@data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
28-
@b = NaiveBayes.new.set_parameters({:m=>3}).build @data_set
28+
@b = NaiveBayes.new.set_parameters({:m => 3}).build @data_set
2929
end
3030

3131
def test_eval
32-
result = @b.eval(["Red", "SUV", "Domestic"])
33-
assert_equal "No", result
32+
result = @b.eval(%w(Red SUV Domestic))
33+
assert_equal 'No', result
3434
end
3535

3636
def test_get_probability_map
37-
map = @b.get_probability_map(["Red", "SUV", "Domestic"])
37+
map = @b.get_probability_map(%w(Red SUV Domestic))
3838
assert_equal 2, map.keys.length
39-
assert_in_delta 0.42, map["Yes"], 0.1
40-
assert_in_delta 0.58, map["No"], 0.1
39+
assert_in_delta 0.42, map['Yes'], 0.1
40+
assert_in_delta 0.58, map['No'], 0.1
4141
end
4242

4343
end

0 commit comments

Comments
 (0)