snr.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #!/usr/bin/python3
  2. """
This script calculates the signal-to-noise ratio (SNR) of an audio file, assuming that the analyzed
window (at a given offset and duration) contains a single tone of a known frequency; everything else is treated as noise.
  5. """
  6. import math
  7. from itertools import count
  8. import argparse
  9. from logging import getLogger, basicConfig
  10. from typing import Tuple
  11. import numpy as np
  12. import scipy.io.wavfile
  13. from matplotlib import pyplot as plt
  14. LOW_CUT = 100 # ignore anything below 100Hz
  15. HIGH_CUT = 4500
  16. def rms_partition(
  17. seq, partition_key=lambda x: False, ignore=lambda x: False
  18. ) -> Tuple[float, float]:
  19. """
  20. We want to separate what is signal and what is noise.
  21. Please provide a partition_key function, and you'll get two numbers: RMS of noise and RMS of signal
  22. """
  23. sum1 = np.ulonglong(0)
  24. cnt1 = 0
  25. sum2 = np.ulonglong(0)
  26. cnt2 = 0
  27. for freq, val in enumerate(seq):
  28. if ignore(freq):
  29. continue
  30. v = np.ulonglong(val)
  31. if partition_key(freq):
  32. sum2 += v ** 2
  33. cnt2 += 1
  34. else:
  35. sum1 += v ** 2
  36. cnt1 += 1
  37. return (np.sqrt(sum1 / cnt1), np.sqrt(sum2 / cnt2))
  38. class ToneAnalyzer:
  39. """
  40. this class checks how "clean" some tone is.
  41. give it the offset, a reference frequency, and it will help you calculate a SNR
  42. """
  43. def __init__(self, frequency, offset, duration, framerate: int, wave, threshold=0):
  44. self.freq = frequency
  45. self.offset = offset
  46. self.duration = duration
  47. self.framerate = framerate
  48. self.wave = wave
  49. self.log = getLogger(self.__class__.__name__)
  50. def get_wave(self):
  51. total = len(self.wave)
  52. self.log.debug("tot frames = %s", total)
  53. self.log.debug("framerate = %s", self.framerate)
  54. self.log.debug("duration = %.2fs", (total / self.framerate))
  55. start = math.ceil(self.framerate * (self.offset / 1000.0))
  56. end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
  57. seq = self.wave[start:end]
  58. return seq
  59. def get_fft(self, seq):
  60. r = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
  61. # x = np.arange(HIGH_CUT)
  62. # plt.plot(x, r)
  63. # plt.show()
  64. return r
  65. def analyze(self, freq2=None) -> float:
  66. seq = self.get_wave()
  67. r = self.get_fft(seq)
  68. def to_int(val):
  69. return int(np.absolute(val))
  70. fft = [to_int(val) for val in r]
  71. if freq2 is not None:
  72. freq_val = fft[self.freq]
  73. second_best_val = fft[freq2]
  74. if second_best_val == 0:
  75. return np.inf
  76. return freq_val / second_best_val
  77. noise, signal = rms_partition(
  78. seq,
  79. partition_key=lambda f: abs(f - self.freq) < 5,
  80. ignore=lambda f: f < LOW_CUT,
  81. )
  82. self.log.debug("noise=%.2f signal=%.2f", noise, signal)
  83. return signal / noise
  84. # self.log.debug("len %d", len(rms))
  85. # self.log.debug("max %.2f", np.nanmax(rms))
  86. # self.log.debug("RMS %.2f", rms)
  87. def get_parser():
  88. p = argparse.ArgumentParser()
  89. p.add_argument("fname")
  90. p.add_argument("--freq", type=int, default=440)
  91. p.add_argument("--freq2", type=int)
  92. p.add_argument("--offset", type=int, default=0, help="In milliseconds")
  93. p.add_argument("--duration", type=int, default=100, help="In milliseconds")
  94. p.add_argument(
  95. "--log-level",
  96. default="WARNING",
  97. choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
  98. )
  99. return p
  100. def main():
  101. args = get_parser().parse_args()
  102. basicConfig(level=args.log_level)
  103. if args.log_level in ["DEBUG", "INFO"]:
  104. np.seterr(all="warn")
  105. framerate, data = scipy.io.wavfile.read(args.fname)
  106. if len(data.shape) > 1:
  107. raise ValueError("audio not mono maybe?")
  108. data = data[:, 0]
  109. analyzer = ToneAnalyzer(
  110. args.freq,
  111. args.offset,
  112. args.duration,
  113. framerate,
  114. data,
  115. )
  116. print("%.4f" % (analyzer.analyze(args.freq2)))
  117. if __name__ == "__main__":
  118. main()