package main

import (
	"encoding/gob"
	"errors"
	"io"
	"math"
	"os"
	"path/filepath"
	"sync/atomic"
)

// defaultProb is the tiny non-zero probability that a word
// we have not seen before appears in the class.
const defaultProb = 0.00000000001

// ErrUnderflow is returned when an underflow is detected.
var ErrUnderflow = errors.New("possible underflow detected")

// Class defines a class that the classifier will filter:
// C = {C_1, ..., C_n}. You should define your classes as a
// set of constants, for example:
//
//	const (
//		Good Class = "Good"
//		Bad  Class = "Bad"
//	)
//
// Class values should be unique.
type Class string

// Classifier implements the Naive Bayesian Classifier.
type Classifier struct {
	Classes []Class
	learned int   // docs learned
	seen    int32 // docs seen
	datas   map[Class]*classData
	tfIdf   bool

	// DidConvertTfIdf reports whether ConvertTermsFreqToTfIdf has been
	// called; a TF-IDF classifier cannot classify until it has been.
	DidConvertTfIdf bool
}

// serializableClassifier represents a container for
// Classifier objects whose fields are modifiable by
// reflection and are therefore writeable by gob.
type serializableClassifier struct {
	Classes         []Class
	Learned         int
	Seen            int
	Datas           map[Class]*classData
	TfIdf           bool
	DidConvertTfIdf bool
}

// classData holds the frequency data for words in a
// particular class. In the future, we may replace this
// structure with a trie-like structure for more
// efficient storage.
type classData struct {
	Freqs   map[string]float64
	FreqTfs map[string][]float64
	Total   int
}

// newClassData creates a new empty classData node.
func newClassData() *classData {
	return &classData{
		Freqs:   make(map[string]float64),
		FreqTfs: make(map[string][]float64),
	}
}

// getWordProb returns P(W|C_j) -- the probability of seeing
// a particular word W in a document of this class.
func (d *classData) getWordProb(word string) float64 {
	value, ok := d.Freqs[word]
	if !ok {
		return defaultProb
	}
	return value / float64(d.Total)
}

// getWordsProb returns P(D|C_j) -- the probability of seeing
// this set of words in a document of this class.
//
// Note that words should not be empty, and this method of
// calculation is prone to underflow if there are many words
// and their individual probabilities are small.
func (d *classData) getWordsProb(words []string) (prob float64) {
	prob = 1
	for _, word := range words {
		prob *= d.getWordProb(word)
	}
	return
}

// NewClassifierTfIdf returns a new TF-IDF classifier. The classes
// provided should be at least 2 in number and unique, or this method
// will panic.
func NewClassifierTfIdf(classes ...Class) (c *Classifier) {
	n := len(classes)

	// check size
	if n < 2 {
		panic("provide at least two classes")
	}

	// check uniqueness
	check := make(map[Class]bool, n)
	for _, class := range classes {
		check[class] = true
	}
	if len(check) != n {
		panic("classes must be unique")
	}

	// create the classifier
	c = &Classifier{
		Classes: classes,
		datas:   make(map[Class]*classData, n),
		tfIdf:   true,
	}
	for _, class := range classes {
		c.datas[class] = newClassData()
	}
	return
}

// NewClassifier returns a new classifier. The classes provided
// should be at least 2 in number and unique, or this method will
// panic.
func NewClassifier(classes ...Class) (c *Classifier) {
	n := len(classes)

	// check size
	if n < 2 {
		panic("provide at least two classes")
	}

	// check uniqueness
	check := make(map[Class]bool, n)
	for _, class := range classes {
		check[class] = true
	}
	if len(check) != n {
		panic("classes must be unique")
	}

	// create the classifier
	c = &Classifier{
		Classes:         classes,
		datas:           make(map[Class]*classData, n),
		tfIdf:           false,
		DidConvertTfIdf: false,
	}
	for _, class := range classes {
		c.datas[class] = newClassData()
	}
	return
}
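
// A minimal construction sketch; the Good and Bad constants are
// hypothetical and not part of this file:
//
//	const (
//		Good Class = "Good"
//		Bad  Class = "Bad"
//	)
//	c := NewClassifier(Good, Bad)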

// NewClassifierFromFile loads an existing classifier from
// file. The classifier was previously saved with a call
// to c.WriteToFile(string).
func NewClassifierFromFile(name string) (c *Classifier, err error) {
	file, err := os.Open(name)
	if err != nil {
		return nil, err
	}
	defer file.Close()
	return NewClassifierFromReader(file)
}

// NewClassifierFromReader decodes a gob-encoded classifier from r,
// as previously written by c.WriteTo(io.Writer).
func NewClassifierFromReader(r io.Reader) (c *Classifier, err error) {
	dec := gob.NewDecoder(r)
	w := new(serializableClassifier)
	if err = dec.Decode(w); err != nil {
		return nil, err
	}
	return &Classifier{w.Classes, w.Learned, int32(w.Seen), w.Datas, w.TfIdf, w.DidConvertTfIdf}, nil
}

// getPriors returns the prior probabilities for the
// classes provided -- P(C_j).
//
// TODO: There is a way to smooth priors, currently
// not implemented here.
func (c *Classifier) getPriors() (priors []float64) {
	n := len(c.Classes)
	priors = make([]float64, n)
	sum := 0
	for index, class := range c.Classes {
		total := c.datas[class].Total
		priors[index] = float64(total)
		sum += total
	}
	if sum != 0 {
		for i := 0; i < n; i++ {
			priors[i] /= float64(sum)
		}
	}
	return
}

// Learned returns the number of documents ever learned
// in the lifetime of this classifier.
func (c *Classifier) Learned() int {
	return c.learned
}

// Seen returns the number of documents ever classified
// in the lifetime of this classifier.
func (c *Classifier) Seen() int {
	return int(atomic.LoadInt32(&c.seen))
}

// IsTfIdf returns true if this is a TF-IDF classifier.
func (c *Classifier) IsTfIdf() bool {
	return c.tfIdf
}

// WordCount returns the number of words counted for
// each class in the lifetime of the classifier.
func (c *Classifier) WordCount() (result []int) {
	result = make([]int, len(c.Classes))
	for inx, class := range c.Classes {
		data := c.datas[class]
		result[inx] = data.Total
	}
	return
}

// Observe should be used when word frequencies have already been
// learned externally (e.g., in a Hadoop job).
func (c *Classifier) Observe(word string, count int, which Class) {
	data := c.datas[which]
	data.Freqs[word] += float64(count)
	data.Total += count
}

// Learn will accept new training documents for
// supervised learning.
func (c *Classifier) Learn(document []string, which Class) {
	// If we are a TF-IDF classifier, we first need to store the term
	// frequencies for the document; the IDF part is worked out later
	// in ConvertTermsFreqToTfIdf().
	if c.tfIdf {
		if c.DidConvertTfIdf {
			panic("Cannot call Learn after ConvertTermsFreqToTfIdf. Reset and relearn to reconvert.")
		}

		// term frequency: word count in document / document length
		docTf := make(map[string]float64)
		for _, word := range document {
			docTf[word]++
		}
		docLen := float64(len(document))
		for wIndex, wCount := range docTf {
			docTf[wIndex] = wCount / docLen
			// add the TF sample; after training we can get IDF values
			c.datas[which].FreqTfs[wIndex] = append(c.datas[which].FreqTfs[wIndex], docTf[wIndex])
		}
	}

	data := c.datas[which]
	for _, word := range document {
		data.Freqs[word]++
		data.Total++
	}
	c.learned++
}
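
// A minimal training sketch (hypothetical classes and documents;
// tokenizing the input into words is up to the caller):
//
//	c := NewClassifier(Good, Bad)
//	c.Learn([]string{"tall", "handsome", "rich"}, Good)
//	c.Learn([]string{"bald", "poor", "ugly"}, Bad)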

// ConvertTermsFreqToTfIdf takes all the TF samples collected for each
// class and converts them to TF-IDF scores
// (https://en.wikipedia.org/wiki/Tf%E2%80%93idf) once we have finished
// learning all the classes and have the totals.
func (c *Classifier) ConvertTermsFreqToTfIdf() {
	if c.DidConvertTfIdf {
		panic("Cannot call ConvertTermsFreqToTfIdf more than once. Reset and relearn to reconvert.")
	}

	for className := range c.datas {
		for wIndex := range c.datas[className].FreqTfs {
			tfIdfAdder := float64(0)
			for tfSampleIndex := range c.datas[className].FreqTfs[wIndex] {
				// we always want a positive TF-IDF score
				tf := c.datas[className].FreqTfs[wIndex][tfSampleIndex]
				c.datas[className].FreqTfs[wIndex][tfSampleIndex] = math.Log1p(tf) * math.Log1p(float64(c.learned)/float64(c.datas[className].Total))
				tfIdfAdder += c.datas[className].FreqTfs[wIndex][tfSampleIndex]
			}
			// convert the 'counts' to TF-IDF scores
			c.datas[className].Freqs[wIndex] = tfIdfAdder
		}
	}

	// record that the conversion is complete; further calls to
	// Learn or to this method will panic
	c.DidConvertTfIdf = true
}
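
// A minimal TF-IDF workflow sketch (hypothetical classes and
// documents); note the required conversion step between learning and
// classification:
//
//	c := NewClassifierTfIdf(Good, Bad)
//	c.Learn([]string{"tall", "handsome", "rich"}, Good)
//	c.Learn([]string{"bald", "poor", "ugly"}, Bad)
//	c.ConvertTermsFreqToTfIdf() // classification panics without this
//	scores, inx, _ := c.LogScores([]string{"tall", "stranger"})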

// LogScores produces "log-likelihood"-like scores that can
// be used to classify documents into classes.
//
// The value of the score is proportional to the likelihood,
// as determined by the classifier, that the given document
// belongs to the given class. This is true even when scores
// returned are negative, which they will be (since we are
// taking logs of probabilities).
//
// The index j of the score corresponds to the class given
// by c.Classes[j].
//
// Additionally returned are "inx" and "strict" values. The
// inx corresponds to the maximum score in the array. If more
// than one of the scores holds the maximum value, then
// strict is false.
//
// Unlike c.ProbScores(), this function is not prone to
// floating point underflow and is relatively safe to use.
func (c *Classifier) LogScores(document []string) (scores []float64, inx int, strict bool) {
	if c.tfIdf && !c.DidConvertTfIdf {
		panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling LogScores.")
	}

	n := len(c.Classes)
	scores = make([]float64, n)
	priors := c.getPriors()

	// calculate the score for each class
	for index, class := range c.Classes {
		data := c.datas[class]
		// the score is the sum of the log prior and the log
		// probabilities of the document's words
		score := math.Log(priors[index])
		for _, word := range document {
			score += math.Log(data.getWordProb(word))
		}
		scores[index] = score
	}
	inx, strict = findMax(scores)
	atomic.AddInt32(&c.seen, 1)
	return scores, inx, strict
}
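
// A minimal classification sketch (assumes a trained classifier c;
// the tokens are hypothetical):
//
//	scores, inx, strict := c.LogScores([]string{"tall", "girl"})
//	likely := c.Classes[inx] // highest-scoring class; strict is false on ties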

// ProbScores works the same as LogScores, but delivers
// actual probabilities as discussed above. Note that float64
// underflow is possible if the word list contains too
// many words that have probabilities very close to 0.
//
// Notes on underflow: underflow is going to occur when you're
// trying to assess large numbers of words that you have
// never seen before. Depending on the application, this
// may or may not be a concern. Consider using SafeProbScores()
// instead.
func (c *Classifier) ProbScores(doc []string) (scores []float64, inx int, strict bool) {
	if c.tfIdf && !c.DidConvertTfIdf {
		panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling ProbScores.")
	}

	n := len(c.Classes)
	scores = make([]float64, n)
	priors := c.getPriors()
	sum := float64(0)

	// calculate the score for each class
	for index, class := range c.Classes {
		data := c.datas[class]
		// the score is the product of the prior and the
		// probabilities of the document's words
		score := priors[index]
		for _, word := range doc {
			score *= data.getWordProb(word)
		}
		scores[index] = score
		sum += score
	}
	for i := 0; i < n; i++ {
		scores[i] /= sum
	}
	inx, strict = findMax(scores)
	atomic.AddInt32(&c.seen, 1)
	return scores, inx, strict
}

// SafeProbScores works the same as ProbScores, but is
// able to detect underflow in those cases where underflow
// results in the reverse classification. If an underflow is detected,
// this method returns an ErrUnderflow, allowing the user to deal with it as
// necessary. Note that underflow, under certain rare circumstances,
// may still result in incorrect probabilities being returned,
// but this method guarantees that all error-less invocations
// are properly classified.
//
// Underflow detection is more costly because it also
// has to make additional log score calculations.
func (c *Classifier) SafeProbScores(doc []string) (scores []float64, inx int, strict bool, err error) {
	if c.tfIdf && !c.DidConvertTfIdf {
		panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling SafeProbScores.")
	}

	n := len(c.Classes)
	scores = make([]float64, n)
	logScores := make([]float64, n)
	priors := c.getPriors()
	sum := float64(0)

	// calculate the probability score and the log score for each class
	for index, class := range c.Classes {
		data := c.datas[class]
		score := priors[index]
		logScore := math.Log(priors[index])
		for _, word := range doc {
			p := data.getWordProb(word)
			score *= p
			logScore += math.Log(p)
		}
		scores[index] = score
		logScores[index] = logScore
		sum += score
	}
	for i := 0; i < n; i++ {
		scores[i] /= sum
	}
	inx, strict = findMax(scores)
	logInx, logStrict := findMax(logScores)

	// detect underflow -- the size relation between scores and
	// logScores must be preserved or something is wrong
	if inx != logInx || strict != logStrict {
		err = ErrUnderflow
	}
	atomic.AddInt32(&c.seen, 1)
	return scores, inx, strict, err
}
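
// A minimal underflow-handling sketch (hypothetical document tokens):
//
//	scores, inx, strict, err := c.SafeProbScores(doc)
//	if err == ErrUnderflow {
//		// the probabilities are unreliable; fall back to the
//		// underflow-safe LogScores
//		scores, inx, strict = c.LogScores(doc)
//	}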

// WordFrequencies returns a matrix of word frequencies that currently
// exist in the classifier for each class state for the given input
// words. In other words, if you obtain the frequencies
//
//	freqs := c.WordFrequencies(/* [j]string */)
//
// then the expression freqs[i][j] represents the frequency of the j-th
// word within the i-th class.
func (c *Classifier) WordFrequencies(words []string) (freqMatrix [][]float64) {
	n, l := len(c.Classes), len(words)
	freqMatrix = make([][]float64, n)
	for i := range freqMatrix {
		arr := make([]float64, l)
		data := c.datas[c.Classes[i]]
		for j := range arr {
			arr[j] = data.getWordProb(words[j])
		}
		freqMatrix[i] = arr
	}
	return
}

// WordsByClass returns a map of words and their probability of
// appearing in the given class.
func (c *Classifier) WordsByClass(class Class) (freqMap map[string]float64) {
	freqMap = make(map[string]float64)
	for word, cnt := range c.datas[class].Freqs {
		freqMap[word] = cnt / float64(c.datas[class].Total)
	}
	return freqMap
}

// WriteToFile serializes this classifier to a file.
func (c *Classifier) WriteToFile(name string) (err error) {
	file, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
	if err != nil {
		return err
	}
	defer file.Close()
	return c.WriteTo(file)
}
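
// A minimal persistence round-trip sketch (hypothetical file name):
//
//	if err := c.WriteToFile("classifier.gob"); err != nil {
//		// handle the write error
//	}
//	restored, err := NewClassifierFromFile("classifier.gob")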

// WriteClassesToFile writes all classes to files under rootPath.
func (c *Classifier) WriteClassesToFile(rootPath string) (err error) {
	for name := range c.datas {
		if err = c.WriteClassToFile(name, rootPath); err != nil {
			return err
		}
	}
	return
}

// WriteClassToFile writes a single class to file.
func (c *Classifier) WriteClassToFile(name Class, rootPath string) (err error) {
	data := c.datas[name]
	fileName := filepath.Join(rootPath, string(name))
	file, err := os.OpenFile(fileName, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
	if err != nil {
		return err
	}
	defer file.Close()
	enc := gob.NewEncoder(file)
	err = enc.Encode(data)
	return
}

// WriteTo serializes this classifier as gob and writes it to w.
func (c *Classifier) WriteTo(w io.Writer) (err error) {
	enc := gob.NewEncoder(w)
	err = enc.Encode(&serializableClassifier{c.Classes, c.learned, int(atomic.LoadInt32(&c.seen)), c.datas, c.tfIdf, c.DidConvertTfIdf})
	return
}

// ReadClassFromFile loads existing class data from a
// file.
func (c *Classifier) ReadClassFromFile(class Class, location string) (err error) {
	fileName := filepath.Join(location, string(class))
	file, err := os.Open(fileName)
	if err != nil {
		return err
	}
	defer file.Close()

	dec := gob.NewDecoder(file)
	w := new(classData)
	if err = dec.Decode(w); err != nil {
		return err
	}
	c.learned++
	c.datas[class] = w
	return
}

// findMax finds the maximum of a set of scores; if the
// maximum is strict -- that is, it is the single unique
// maximum from the set -- then strict has return value
// true. Otherwise it is false.
func findMax(scores []float64) (inx int, strict bool) {
	inx = 0
	strict = true
	for i := 1; i < len(scores); i++ {
		if scores[inx] < scores[i] {
			inx = i
			strict = true
		} else if scores[inx] == scores[i] {
			strict = false
		}
	}
	return
}