@@ -150,22 +150,22 @@ func NewClassifierFromReader(r io.Reader) (c *Classifier, err error) {
150150
151151// getPriors returns the prior probabilities for the
152152// classes provided -- P(C_j).
153- //
154- // TODO: There is a way to smooth priors, currently
155- // not implemented here.
153+ // Uses Laplace smoothing to ensure no prior is zero:
154+ // P(C_j) = (count_j + 1) / (total + num_classes)
156155func (c * Classifier ) getPriors () (priors []float64 ) {
157156 n := len (c .Classes )
158- priors = make ([]float64 , n , n )
157+ priors = make ([]float64 , n )
159158 sum := 0
160159 for index , class := range c .Classes {
161160 total := c .datas [class ].Total
162161 priors [index ] = float64 (total )
163162 sum += total
164163 }
165- if sum != 0 {
166- for i := 0 ; i < n ; i ++ {
167- priors [i ] /= float64 (sum )
168- }
164+ // Apply Laplace smoothing to priors to avoid log(0)
165+ floatN := float64 (n )
166+ floatSum := float64 (sum )
167+ for i := 0 ; i < n ; i ++ {
168+ priors [i ] = (priors [i ] + 1 ) / (floatSum + floatN )
169169 }
170170 return
171171}
@@ -339,30 +339,41 @@ func (c *Classifier) Classify(document []string) (class Class, scores []float64,
339339// never seen before. Depending on the application, this
340340// may or may not be a concern. Consider using SafeProbScores()
341341// instead.
342+ //
343+ // If all scores underflow to zero, returns equal probabilities
344+ // for all classes (1/n each).
342345func (c * Classifier ) ProbScores (doc []string ) (scores []float64 , inx int , strict bool ) {
343346 if c .tfIdf && ! c .DidConvertTfIdf {
344347 panic ("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling ProbScores." )
345348 }
346349 n := len (c .Classes )
347- scores = make ([]float64 , n , n )
350+ scores = make ([]float64 , n )
348351 priors := c .getPriors ()
349352 sum := float64 (0 )
350353 // calculate the score for each class
351354 for index , class := range c .Classes {
352355 data := c .datas [class ]
353- // c is the sum of the logarithms
354- // as outlined in the refresher
355356 score := priors [index ]
356357 for _ , word := range doc {
357358 score *= data .getWordProb (word )
358359 }
359360 scores [index ] = score
360361 sum += score
361362 }
362- for i := 0 ; i < n ; i ++ {
363- scores [i ] /= sum
363+ // Handle underflow: if sum is 0, all scores underflowed
364+ // Return equal probabilities to avoid NaN
365+ if sum == 0 {
366+ equal := 1.0 / float64 (n )
367+ for i := 0 ; i < n ; i ++ {
368+ scores [i ] = equal
369+ }
370+ strict = false
371+ } else {
372+ for i := 0 ; i < n ; i ++ {
373+ scores [i ] /= sum
374+ }
375+ inx , strict = findMax (scores )
364376 }
365- inx , strict = findMax (scores )
366377 atomic .AddInt32 (& c .seen , 1 )
367378 return scores , inx , strict
368379}
@@ -383,26 +394,28 @@ func (c *Classifier) ClassifyProb(document []string) (class Class, scores []floa
383394// this method returns an ErrUnderflow, allowing the user to deal with it as
384395// necessary. Note that underflow, under certain rare circumstances,
385396// may still result in incorrect probabilities being returned,
386- // but this method guarantees that all error-less invokations
397+ // but this method guarantees that all error-less invocations
387398// are properly classified.
388399//
389400// Underflow detection is more costly because it also
390401// has to make additional log score calculations.
402+ //
403+ // When underflow is detected, the returned scores are computed from
404+ // log-domain scores using the log-sum-exp trick for numerical stability.
391405func (c * Classifier ) SafeProbScores (doc []string ) (scores []float64 , inx int , strict bool , err error ) {
392406 if c .tfIdf && ! c .DidConvertTfIdf {
393407 panic ("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling SafeProbScores." )
394408 }
395409
396410 n := len (c .Classes )
397- scores = make ([]float64 , n , n )
398- logScores := make ([]float64 , n , n )
411+ scores = make ([]float64 , n )
412+ logScores := make ([]float64 , n )
399413 priors := c .getPriors ()
400414 sum := float64 (0 )
415+
401416 // calculate the score for each class
402417 for index , class := range c .Classes {
403418 data := c .datas [class ]
404- // c is the sum of the logarithms
405- // as outlined in the refresher
406419 score := priors [index ]
407420 logScore := math .Log (priors [index ])
408421 for _ , word := range doc {
@@ -414,22 +427,64 @@ func (c *Classifier) SafeProbScores(doc []string) (scores []float64, inx int, st
414427 logScores [index ] = logScore
415428 sum += score
416429 }
417- for i := 0 ; i < n ; i ++ {
418- scores [i ] /= sum
419- }
420- inx , strict = findMax (scores )
430+
431+ // Get the winner from log-domain (always reliable)
421432 logInx , logStrict := findMax (logScores )
422433
423- // detect underflow -- the size
424- // relation between scores and logScores
425- // must be preserved or something is wrong
426- if inx != logInx || strict != logStrict {
434+ // Check for underflow: if sum is 0 or prob-domain disagrees with log-domain
435+ if sum == 0 {
436+ // Complete underflow - use log-sum-exp to recover probabilities
427437 err = ErrUnderflow
438+ scores = logScoresToProbs (logScores )
439+ inx , strict = logInx , logStrict
440+ } else {
441+ for i := 0 ; i < n ; i ++ {
442+ scores [i ] /= sum
443+ }
444+ inx , strict = findMax (scores )
445+
446+ // Detect partial underflow - when prob and log domains disagree
447+ if inx != logInx || strict != logStrict {
448+ err = ErrUnderflow
449+ // Use log-domain results as they're more reliable
450+ scores = logScoresToProbs (logScores )
451+ inx , strict = logInx , logStrict
452+ }
428453 }
454+
429455 atomic .AddInt32 (& c .seen , 1 )
430456 return scores , inx , strict , err
431457}
432458
459+ // logScoresToProbs converts log-domain scores to probabilities
460+ // using the log-sum-exp trick for numerical stability.
461+ func logScoresToProbs (logScores []float64 ) []float64 {
462+ n := len (logScores )
463+ probs := make ([]float64 , n )
464+
465+ // Find max for numerical stability
466+ maxLog := logScores [0 ]
467+ for i := 1 ; i < n ; i ++ {
468+ if logScores [i ] > maxLog {
469+ maxLog = logScores [i ]
470+ }
471+ }
472+
473+ // Compute exp(log - max) and sum
474+ sum := 0.0
475+ for i := 0 ; i < n ; i ++ {
476+ probs [i ] = math .Exp (logScores [i ] - maxLog )
477+ sum += probs [i ]
478+ }
479+
480+ // Normalize
481+ for i := 0 ; i < n ; i ++ {
482+ probs [i ] /= sum
483+ }
484+
485+ return probs
486+ }
487+
433488// ClassifySafe returns the most likely class for the given document
434489// along with the probability scores, whether the classification is strict,
435490// and an error if underflow is detected.
0 commit comments