snr.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. #!/usr/bin/python3
  2. """
This script calculates the signal-to-noise ratio (SNR) of a recording, assuming that the input contains a single known tone inside a given time window.
Everything else in that window is treated as noise.
  5. """
  6. import math
  7. from itertools import count
  8. import argparse
  9. from logging import getLogger, basicConfig
  10. from typing import Tuple
  11. import numpy as np
  12. import scipy.io.wavfile
  13. from matplotlib import pyplot as plt
  14. LOW_CUT = 100 # ignore anything below 100Hz
  15. HIGH_CUT = 4500
  16. def window_rms(a, window_size=2):
  17. return np.sqrt(
  18. sum([a[window_size - i - 1 : len(a) - i] ** 2 for i in range(window_size - 1)])
  19. / window_size
  20. )
  21. def rms_except(seq, except_key=lambda x: False) -> Tuple[float, float]:
  22. sum1 = np.ulonglong(0)
  23. cnt1 = 0
  24. sum2 = np.ulonglong(0)
  25. cnt2 = 0
  26. for freq, val in enumerate(seq):
  27. if freq < LOW_CUT:
  28. continue
  29. v = np.ulonglong(val)
  30. if except_key(freq):
  31. sum2 += v**2
  32. cnt2 += 1
  33. else:
  34. sum1 += v**2
  35. cnt1 += 1
  36. return (np.sqrt(sum1 / cnt1), np.sqrt(sum2 / cnt2))
  37. class ToneAnalyzer:
  38. """
  39. this class checks how "clean" some tone is.
  40. give it the offset, a reference frequency, and it will help you calculate a SNR
  41. """
  42. def __init__(self, frequency, offset, duration, framerate: int, wave, threshold=0):
  43. self.freq = frequency
  44. self.offset = offset
  45. self.duration = duration
  46. self.framerate = framerate
  47. self.wave = wave
  48. self.log = getLogger(self.__class__.__name__)
  49. def analyze(self, freq2=None):
  50. total = len(self.wave)
  51. self.log.debug("tot frames = %s", total)
  52. self.log.debug("framerate = %s", self.framerate)
  53. self.log.debug("duration = %.2fs", (total / self.framerate))
  54. start = math.ceil(self.framerate * (self.offset / 1000.0))
  55. end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
  56. seq = self.wave[start:end]
  57. r = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
  58. x = np.arange(HIGH_CUT)
  59. plt.plot(x, r)
  60. plt.show()
  61. def to_int(val):
  62. return int(np.absolute(val))
  63. fft = [to_int(val) for val in r]
  64. if freq2 is not None:
  65. freq_val = fft[self.freq]
  66. second_best_val = fft[freq2]
  67. if second_best_val == 0:
  68. return np.inf
  69. return freq_val / second_best_val
  70. noise, signal = rms_except(seq, except_key=lambda f: abs(f - self.freq) < 5)
  71. self.log.debug("noise=%.2f signal=%.2f", noise, signal)
  72. return signal / noise
  73. # self.log.debug("len %d", len(rms))
  74. # self.log.debug("max %.2f", np.nanmax(rms))
  75. # self.log.debug("RMS %.2f", rms)
  76. def get_parser():
  77. p = argparse.ArgumentParser()
  78. p.add_argument("fname")
  79. p.add_argument("--freq", type=int, default=440)
  80. p.add_argument("--freq2", type=int)
  81. p.add_argument("--offset", type=int, default=0, help="In milliseconds")
  82. p.add_argument("--duration", type=int, default=100, help="In milliseconds")
  83. p.add_argument(
  84. "--log-level",
  85. default="WARNING",
  86. choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
  87. )
  88. return p
  89. def main():
  90. args = get_parser().parse_args()
  91. basicConfig(level=args.log_level)
  92. if args.log_level in ['DEBUG', 'INFO']:
  93. np.seterr(all='warn')
  94. framerate, data = scipy.io.wavfile.read(args.fname)
  95. if len(data.shape) > 1:
  96. raise ValueError("audio not mono maybe?")
  97. data = data[:, 0]
  98. analyzer = ToneAnalyzer(
  99. args.freq,
  100. args.offset,
  101. args.duration,
  102. framerate,
  103. data,
  104. )
  105. print('%.4f' % (analyzer.analyze(args.freq2)))
  106. if __name__ == "__main__":
  107. main()