
Commit 0dd9d42

jbrukh and claude committed
Fix numerical underflow/overflow vulnerabilities
- Add Laplace smoothing to getPriors to avoid log(0) for zero priors
- Handle division by zero in ProbScores when all scores underflow
- Improve SafeProbScores with log-sum-exp trick for probability recovery
- Add logScoresToProbs helper for numerically stable probability conversion
- Update tests for new numerical behavior

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent fe2c4d4 commit 0dd9d42

2 files changed: +95 -30 lines changed


bayesian.go

Lines changed: 82 additions & 27 deletions
@@ -150,22 +150,22 @@ func NewClassifierFromReader(r io.Reader) (c *Classifier, err error) {
 
 // getPriors returns the prior probabilities for the
 // classes provided -- P(C_j).
-//
-// TODO: There is a way to smooth priors, currently
-// not implemented here.
+// Uses Laplace smoothing to ensure no prior is zero:
+// P(C_j) = (count_j + 1) / (total + num_classes)
 func (c *Classifier) getPriors() (priors []float64) {
 	n := len(c.Classes)
-	priors = make([]float64, n, n)
+	priors = make([]float64, n)
 	sum := 0
 	for index, class := range c.Classes {
 		total := c.datas[class].Total
 		priors[index] = float64(total)
 		sum += total
 	}
-	if sum != 0 {
-		for i := 0; i < n; i++ {
-			priors[i] /= float64(sum)
-		}
+	// Apply Laplace smoothing to priors to avoid log(0)
+	floatN := float64(n)
+	floatSum := float64(sum)
+	for i := 0; i < n; i++ {
+		priors[i] = (priors[i] + 1) / (floatSum + floatN)
 	}
 	return
 }
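
A minimal standalone sketch of what the smoothed priors look like (illustrative only; the class counts below are made up, not taken from this repository):

package main

import (
	"fmt"
	"math"
)

// laplacePriors applies the same formula the diff adds to getPriors:
// P(C_j) = (count_j + 1) / (total + num_classes).
func laplacePriors(counts []int) []float64 {
	total := 0
	for _, c := range counts {
		total += c
	}
	n := len(counts)
	priors := make([]float64, n)
	for i, c := range counts {
		priors[i] = (float64(c) + 1) / (float64(total) + float64(n))
	}
	return priors
}

func main() {
	// Hypothetical counts: the third class has never been observed.
	priors := laplacePriors([]int{3, 1, 0})
	fmt.Println(priors)              // [0.571... 0.285... 0.142...]
	fmt.Println(math.Log(priors[2])) // finite -- no more log(0) = -Inf
}

With the old code a zero-count class produced a zero prior, and taking its logarithm later yielded -Inf; the smoothed version keeps every prior strictly positive.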
@@ -339,30 +339,41 @@ func (c *Classifier) Classify(document []string) (class Class, scores []float64,
 // never seen before. Depending on the application, this
 // may or may not be a concern. Consider using SafeProbScores()
 // instead.
+//
+// If all scores underflow to zero, returns equal probabilities
+// for all classes (1/n each).
 func (c *Classifier) ProbScores(doc []string) (scores []float64, inx int, strict bool) {
 	if c.tfIdf && !c.DidConvertTfIdf {
 		panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling ProbScores.")
 	}
 	n := len(c.Classes)
-	scores = make([]float64, n, n)
+	scores = make([]float64, n)
 	priors := c.getPriors()
 	sum := float64(0)
 	// calculate the score for each class
 	for index, class := range c.Classes {
 		data := c.datas[class]
-		// c is the sum of the logarithms
-		// as outlined in the refresher
 		score := priors[index]
 		for _, word := range doc {
 			score *= data.getWordProb(word)
 		}
 		scores[index] = score
 		sum += score
 	}
-	for i := 0; i < n; i++ {
-		scores[i] /= sum
+	// Handle underflow: if sum is 0, all scores underflowed
+	// Return equal probabilities to avoid NaN
+	if sum == 0 {
+		equal := 1.0 / float64(n)
+		for i := 0; i < n; i++ {
+			scores[i] = equal
+		}
+		strict = false
+	} else {
+		for i := 0; i < n; i++ {
+			scores[i] /= sum
+		}
+		inx, strict = findMax(scores)
 	}
-	inx, strict = findMax(scores)
 	atomic.AddInt32(&c.seen, 1)
 	return scores, inx, strict
 }
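
For context on the sum == 0 branch, a standalone sketch of the underflow it guards against (the probability values here are arbitrary, not the library's actual defaultProb):

package main

import "fmt"

func main() {
	// Roughly what ProbScores builds per class: a prior multiplied by
	// many small per-word probabilities.
	score := 0.3
	for i := 0; i < 200; i++ {
		score *= 0.00001 // stand-in for a tiny word probability
	}
	fmt.Println(score) // 0 -- the float64 product has underflowed

	// If every class underflows, the old normalization divided 0 by 0.
	sum := 0.0
	sum += score
	fmt.Println(score / sum) // NaN
}

Returning 1/n per class instead keeps the scores finite, at the cost of losing any ranking information, which is why the new branch also leaves strict set to false.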
@@ -383,26 +394,28 @@ func (c *Classifier) ClassifyProb(document []string) (class Class, scores []floa
 // this method returns an ErrUnderflow, allowing the user to deal with it as
 // necessary. Note that underflow, under certain rare circumstances,
 // may still result in incorrect probabilities being returned,
-// but this method guarantees that all error-less invokations
+// but this method guarantees that all error-less invocations
 // are properly classified.
 //
 // Underflow detection is more costly because it also
 // has to make additional log score calculations.
+//
+// When underflow is detected, the returned scores are computed from
+// log-domain scores using the log-sum-exp trick for numerical stability.
 func (c *Classifier) SafeProbScores(doc []string) (scores []float64, inx int, strict bool, err error) {
 	if c.tfIdf && !c.DidConvertTfIdf {
 		panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling SafeProbScores.")
 	}
 
 	n := len(c.Classes)
-	scores = make([]float64, n, n)
-	logScores := make([]float64, n, n)
+	scores = make([]float64, n)
+	logScores := make([]float64, n)
 	priors := c.getPriors()
 	sum := float64(0)
+
 	// calculate the score for each class
 	for index, class := range c.Classes {
 		data := c.datas[class]
-		// c is the sum of the logarithms
-		// as outlined in the refresher
 		score := priors[index]
 		logScore := math.Log(priors[index])
 		for _, word := range doc {
@@ -414,22 +427,64 @@ func (c *Classifier) SafeProbScores(doc []string) (scores []float64, inx int, st
 		logScores[index] = logScore
 		sum += score
 	}
-	for i := 0; i < n; i++ {
-		scores[i] /= sum
-	}
-	inx, strict = findMax(scores)
+
+	// Get the winner from log-domain (always reliable)
 	logInx, logStrict := findMax(logScores)
 
-	// detect underflow -- the size
-	// relation between scores and logScores
-	// must be preserved or something is wrong
-	if inx != logInx || strict != logStrict {
+	// Check for underflow: if sum is 0 or prob-domain disagrees with log-domain
+	if sum == 0 {
+		// Complete underflow - use log-sum-exp to recover probabilities
 		err = ErrUnderflow
+		scores = logScoresToProbs(logScores)
+		inx, strict = logInx, logStrict
+	} else {
+		for i := 0; i < n; i++ {
+			scores[i] /= sum
+		}
+		inx, strict = findMax(scores)
+
+		// Detect partial underflow - when prob and log domains disagree
+		if inx != logInx || strict != logStrict {
+			err = ErrUnderflow
+			// Use log-domain results as they're more reliable
+			scores = logScoresToProbs(logScores)
+			inx, strict = logInx, logStrict
+		}
 	}
+
 	atomic.AddInt32(&c.seen, 1)
 	return scores, inx, strict, err
 }
 
+// logScoresToProbs converts log-domain scores to probabilities
+// using the log-sum-exp trick for numerical stability.
+func logScoresToProbs(logScores []float64) []float64 {
+	n := len(logScores)
+	probs := make([]float64, n)
+
+	// Find max for numerical stability
+	maxLog := logScores[0]
+	for i := 1; i < n; i++ {
+		if logScores[i] > maxLog {
+			maxLog = logScores[i]
+		}
+	}
+
+	// Compute exp(log - max) and sum
+	sum := 0.0
+	for i := 0; i < n; i++ {
+		probs[i] = math.Exp(logScores[i] - maxLog)
+		sum += probs[i]
+	}
+
+	// Normalize
+	for i := 0; i < n; i++ {
+		probs[i] /= sum
+	}
+
+	return probs
+}
+
 // ClassifySafe returns the most likely class for the given document
 // along with the probability scores, whether the classification is strict,
 // and an error if underflow is detected.
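
To see why logScoresToProbs subtracts the max before exponentiating, a standalone sketch with made-up log scores:

package main

import (
	"fmt"
	"math"
)

func main() {
	// Hypothetical log-domain scores, too negative to exponentiate directly.
	logScores := []float64{-1200, -1202, -1210}

	// Naive conversion: every exp underflows to 0, so nothing can be normalized.
	naive := 0.0
	for _, ls := range logScores {
		naive += math.Exp(ls)
	}
	fmt.Println(naive) // 0

	// Log-sum-exp: shift by the max first, as the new helper does.
	maxLog := logScores[0]
	for _, ls := range logScores[1:] {
		if ls > maxLog {
			maxLog = ls
		}
	}
	probs := make([]float64, len(logScores))
	sum := 0.0
	for i, ls := range logScores {
		probs[i] = math.Exp(ls - maxLog)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	fmt.Println(probs) // approx [0.88 0.12 0.00004] -- the ratios survive
}

The shift cancels during normalization, so the result matches the true softmax of the log scores whenever it is representable, while avoiding the 0/0 of the naive conversion.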

bayesian_test.go

Lines changed: 13 additions & 3 deletions
@@ -18,8 +18,10 @@ func Assert(t *testing.T, condition bool, args ...interface{}) {
 func TestEmpty(t *testing.T) {
 	c := NewClassifier("Good", "Bad", "Neutral")
 	priors := c.getPriors()
+	// With Laplace smoothing, empty classifier should have uniform priors
+	expected := 1.0 / float64(len(priors))
 	for _, item := range priors {
-		Assert(t, item == 0)
+		Assert(t, item == expected, "expected uniform prior", expected, "got", item)
 	}
 }
 
@@ -200,10 +202,18 @@ func TestInduceUnderflow(t *testing.T) {
 	for i := 0; i < docSize; i++ {
 		document[i] = "word"
 	}
-	// should induce overflow, because each word
+	// should induce underflow, because each word
 	// will have "defaultProb", which is small
-	scores, _, _, err := c.SafeProbScores(document)
+	scores, inx, _, err := c.SafeProbScores(document)
 	Assert(t, err == ErrUnderflow, "Underflow error not detected")
+	// Verify log-sum-exp recovery produces valid probabilities
+	sum := 0.0
+	for _, s := range scores {
+		Assert(t, s >= 0 && s <= 1, "score out of range [0,1]:", s)
+		sum += s
+	}
+	Assert(t, sum > 0.999 && sum < 1.001, "scores don't sum to 1:", sum)
+	Assert(t, inx >= 0 && inx < len(scores), "index out of range:", inx)
 	println(scores)
 }
 
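Finally, a caller-side sketch of how the changed SafeProbScores behavior can be consumed (the classes, training data, and document below are invented; the import path is assumed to be the repository's usual one):

package main

import (
	"fmt"
	"log"

	"github.com/jbrukh/bayesian" // assumed import path for this repository
)

const (
	Good bayesian.Class = "Good"
	Bad  bayesian.Class = "Bad"
)

func main() {
	c := bayesian.NewClassifier(Good, Bad)
	c.Learn([]string{"tall", "handsome", "rich"}, Good)
	c.Learn([]string{"poor", "smelly", "ugly"}, Bad)

	// A long document of unseen words is an easy way to force underflow.
	doc := make([]string, 1000)
	for i := range doc {
		doc[i] = "unseen-word"
	}

	scores, inx, strict, err := c.SafeProbScores(doc)
	if err == bayesian.ErrUnderflow {
		// After this commit the scores are recovered via log-sum-exp,
		// so they still sum to ~1 and inx points at the log-domain winner.
		log.Println("underflow detected; scores recovered from log domain")
	}
	fmt.Println(scores, inx, strict)
}

Before this commit the same call produced NaN scores on complete underflow; the ErrUnderflow sentinel itself is unchanged, so callers that check for it keep working.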