sparse.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package multibayes
  2. type sparseMatrix struct {
  3. Tokens map[string]*sparseColumn `json:"tokens"` // []map[tokenindex]occurence
  4. Classes map[string]*sparseColumn `json:"classes"` // map[classname]classindex
  5. N int `json:"n"` // number of rows currently in the matrix
  6. }
  7. type sparseColumn struct {
  8. Data []int `json:"data"`
  9. }
  10. func newSparseColumn() *sparseColumn {
  11. return &sparseColumn{
  12. Data: make([]int, 0, 1000),
  13. }
  14. }
  15. func (s *sparseColumn) Add(index int) {
  16. s.Data = append(s.Data, index)
  17. }
  18. // return the number of rows that contain the column
  19. func (s *sparseColumn) Count() int {
  20. return len(s.Data)
  21. }
  22. // sparse to dense
  23. func (s *sparseColumn) Expand(n int) []float64 {
  24. expanded := make([]float64, n)
  25. for _, index := range s.Data {
  26. expanded[index] = 1.0
  27. }
  28. return expanded
  29. }
  30. func newSparseMatrix() *sparseMatrix {
  31. return &sparseMatrix{
  32. Tokens: make(map[string]*sparseColumn),
  33. Classes: make(map[string]*sparseColumn),
  34. N: 0,
  35. }
  36. }
  37. func (s *sparseMatrix) Add(ngrams []ngram, classes []string) {
  38. if len(ngrams) == 0 || len(classes) == 0 {
  39. return
  40. }
  41. for _, class := range classes {
  42. if _, ok := s.Classes[class]; !ok {
  43. s.Classes[class] = newSparseColumn()
  44. }
  45. s.Classes[class].Add(s.N)
  46. }
  47. // add ngrams uniquely
  48. added := make(map[string]int)
  49. for _, ngram := range ngrams {
  50. gramString := ngram.String()
  51. if _, ok := s.Tokens[gramString]; !ok {
  52. s.Tokens[gramString] = newSparseColumn()
  53. }
  54. // only add the document index once for the ngram
  55. if _, ok := added[gramString]; !ok {
  56. added[gramString] = 1
  57. s.Tokens[gramString].Add(s.N)
  58. }
  59. }
  60. // increment the row counter
  61. s.N++
  62. }