snr.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. #!/usr/bin/python3
  2. """
  3. This script calculates the SNR, assuming that the input has, at a specific position, a single specific tone.
  4. So everything else is assumed to be noise.
  5. """
  6. import math
  7. from itertools import count
  8. import argparse
  9. from logging import getLogger, basicConfig
  10. from typing import Tuple
  11. import numpy as np
  12. import scipy.io.wavfile
  13. from matplotlib import pyplot as plt
# Positions below this value are skipped when splitting signal from noise
# (see the ``ignore`` callback built in ToneAnalyzer.analyze).
LOW_CUT = 100  # ignore anything below 100Hz
# The FFT result is truncated to this many leading bins in ToneAnalyzer.get_fft.
HIGH_CUT = 4500
# Module-level logger named "snr".
log = getLogger("snr")
  17. def rms_partition(
  18. seq, partition_key=lambda x: False, ignore=lambda x: False
  19. ) -> Tuple[float, float]:
  20. """
  21. We want to separate what is signal and what is noise.
  22. Please provide a partition_key function, and you'll get two numbers: RMS of noise and RMS of signal
  23. """
  24. sum1 = np.ulonglong(0)
  25. cnt1 = 0
  26. sum2 = np.ulonglong(0)
  27. cnt2 = 0
  28. for freq, val in enumerate(seq):
  29. if ignore(freq):
  30. continue
  31. v = np.ulonglong(val)
  32. if partition_key(freq):
  33. sum2 += v ** 2
  34. cnt2 += 1
  35. else:
  36. sum1 += v ** 2
  37. cnt1 += 1
  38. def rm(sum_, cnt):
  39. return np.sqrt(sum_ / cnt)
  40. return (rm(sum1, cnt1), rm(sum2, cnt2))
  41. class ToneAnalyzer:
  42. """
  43. this class checks how "clean" some tone is.
  44. give it the offset, a reference frequency, and it will help you calculate a SNR
  45. """
  46. def __init__(self, frequency, offset, duration, framerate: int, wave, threshold=0):
  47. self.freq = frequency
  48. self.offset = offset
  49. self.duration = duration
  50. self.framerate = framerate
  51. self.wave = wave
  52. self.log = getLogger(self.__class__.__name__)
  53. def get_wave(self):
  54. total = len(self.wave)
  55. self.log.debug("tot frames = %s", total)
  56. self.log.debug("framerate = %s", self.framerate)
  57. self.log.debug("duration = %.2fs", (total / self.framerate))
  58. start = math.ceil(self.framerate * (self.offset / 1000.0))
  59. end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
  60. seq = self.wave[start:end]
  61. return seq
  62. def get_fft(self, seq):
  63. r = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
  64. # x = np.arange(HIGH_CUT)
  65. # plt.plot(x, r)
  66. # plt.show()
  67. def to_int(val):
  68. return int(np.absolute(val))
  69. fft = [to_int(val) for val in r]
  70. return fft
  71. def is_signal(self, f):
  72. return abs(f - self.freq) < 5
  73. def analyze(self, freq2=None) -> float:
  74. seq = self.get_wave()
  75. fft = self.get_fft(seq)
  76. if freq2 is not None:
  77. freq_val = fft[self.freq]
  78. second_best_val = fft[freq2]
  79. if second_best_val == 0:
  80. return np.inf
  81. return freq_val / second_best_val
  82. noise, signal = rms_partition(
  83. seq,
  84. partition_key=self.is_signal,
  85. ignore=lambda f: f < LOW_CUT,
  86. )
  87. self.log.debug("noise=%.2f signal=%.2f", noise, signal)
  88. return signal / noise
  89. # self.log.debug("len %d", len(rms))
  90. # self.log.debug("max %.2f", np.nanmax(rms))
  91. # self.log.debug("RMS %.2f", rms)
  92. class MultiToneAnalyzer(ToneAnalyzer):
  93. def is_signal(self, f):
  94. for freq_i in self.freq:
  95. if abs(f - freq_i) < 30:
  96. return True
  97. return False
  98. def get_parser():
  99. p = argparse.ArgumentParser()
  100. p.add_argument("fname")
  101. p.add_argument("--freq", type=int, default=[440], nargs="*")
  102. p.add_argument("--freq2", type=int)
  103. p.add_argument("--offset", type=int, default=0, help="In milliseconds")
  104. p.add_argument("--duration", type=int, default=100, help="In milliseconds")
  105. p.add_argument(
  106. "--log-level",
  107. default="WARNING",
  108. choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
  109. )
  110. return p
  111. def main():
  112. args = get_parser().parse_args()
  113. basicConfig(level=args.log_level)
  114. if args.log_level in ["DEBUG", "INFO"]:
  115. np.seterr(all="warn")
  116. framerate, data = scipy.io.wavfile.read(args.fname)
  117. if len(data.shape) > 1:
  118. raise ValueError("audio not mono maybe?")
  119. data = data[:, 0]
  120. analyzer = MultiToneAnalyzer(
  121. args.freq,
  122. args.offset,
  123. args.duration,
  124. framerate,
  125. data,
  126. )
  127. print("%.4f" % (analyzer.analyze(args.freq2)))
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()