|
@@ -0,0 +1,131 @@
|
|
|
+#!/usr/bin/python3
|
|
|
+"""
|
|
|
+This script calculates the SNR, assuming that the input has, at a specific position, a single specific tone.
|
|
|
+
|
|
|
+So everything else is assumed to be noise.
|
|
|
+"""
|
|
|
+
|
|
|
+import math
|
|
|
+from itertools import count
|
|
|
+import argparse
|
|
|
+from logging import getLogger, basicConfig
|
|
|
+from typing import Tuple
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import scipy.io.wavfile
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+
|
|
|
+LOW_CUT = 100 # ignore anything below 100Hz
|
|
|
+HIGH_CUT = 4500
|
|
|
+
|
|
|
+def window_rms(a, window_size=2):
|
|
|
+ return np.sqrt(
|
|
|
+ sum([a[window_size - i - 1 : len(a) - i] ** 2 for i in range(window_size - 1)])
|
|
|
+ / window_size
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def rms_except(seq, except_key=lambda x: False) -> Tuple[float, float]:
|
|
|
+ sum1 = np.ulonglong(0)
|
|
|
+ cnt1 = 0
|
|
|
+ sum2 = np.ulonglong(0)
|
|
|
+ cnt2 = 0
|
|
|
+ for freq, val in enumerate(seq):
|
|
|
+ if freq < LOW_CUT:
|
|
|
+ continue
|
|
|
+ v = np.ulonglong(val)
|
|
|
+ if except_key(freq):
|
|
|
+ sum2 += v**2
|
|
|
+ cnt2 += 1
|
|
|
+ else:
|
|
|
+ sum1 += v**2
|
|
|
+ cnt1 += 1
|
|
|
+ return (np.sqrt(sum1 / cnt1), np.sqrt(sum2 / cnt2))
|
|
|
+
|
|
|
+
|
|
|
+class ToneAnalyzer:
|
|
|
+ """
|
|
|
+ this class checks how "clean" some tone is.
|
|
|
+ give it the offset, a reference frequency, and it will help you calculate a SNR
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, frequency, offset, duration, framerate: int, wave, threshold=0):
|
|
|
+ self.freq = frequency
|
|
|
+ self.offset = offset
|
|
|
+ self.duration = duration
|
|
|
+ self.framerate = framerate
|
|
|
+ self.wave = wave
|
|
|
+ self.log = getLogger(self.__class__.__name__)
|
|
|
+
|
|
|
+ def analyze(self, freq2=None):
|
|
|
+ total = len(self.wave)
|
|
|
+ self.log.debug("tot frames = %s", total)
|
|
|
+ self.log.debug("framerate = %s", self.framerate)
|
|
|
+ self.log.debug("duration = %.2fs", (total / self.framerate))
|
|
|
+
|
|
|
+ start = math.ceil(self.framerate * (self.offset / 1000.0))
|
|
|
+ end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
|
|
|
+ seq = self.wave[start:end]
|
|
|
+ r = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
|
|
|
+
|
|
|
+ x = np.arange(HIGH_CUT)
|
|
|
+ plt.plot(x, r)
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ def to_int(val):
|
|
|
+ return int(np.absolute(val))
|
|
|
+
|
|
|
+ fft = [to_int(val) for val in r]
|
|
|
+ if freq2 is not None:
|
|
|
+ freq_val = fft[self.freq]
|
|
|
+ second_best_val = fft[freq2]
|
|
|
+ if second_best_val == 0:
|
|
|
+ return np.inf
|
|
|
+ return freq_val / second_best_val
|
|
|
+
|
|
|
+ noise, signal = rms_except(seq, except_key=lambda f: abs(f - self.freq) < 5)
|
|
|
+ self.log.debug("noise=%.2f signal=%.2f", noise, signal)
|
|
|
+ return signal / noise
|
|
|
+ # self.log.debug("len %d", len(rms))
|
|
|
+ # self.log.debug("max %.2f", np.nanmax(rms))
|
|
|
+ # self.log.debug("RMS %.2f", rms)
|
|
|
+
|
|
|
+
|
|
|
+def get_parser():
|
|
|
+ p = argparse.ArgumentParser()
|
|
|
+ p.add_argument("fname")
|
|
|
+ p.add_argument("--freq", type=int, default=440)
|
|
|
+ p.add_argument("--freq2", type=int)
|
|
|
+ p.add_argument("--offset", type=int, default=0, help="In milliseconds")
|
|
|
+ p.add_argument("--duration", type=int, default=100, help="In milliseconds")
|
|
|
+
|
|
|
+ p.add_argument(
|
|
|
+ "--log-level",
|
|
|
+ default="WARNING",
|
|
|
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
|
|
+ )
|
|
|
+ return p
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ args = get_parser().parse_args()
|
|
|
+ basicConfig(level=args.log_level)
|
|
|
+ if args.log_level in ['DEBUG', 'INFO']:
|
|
|
+ np.seterr(all='warn')
|
|
|
+ framerate, data = scipy.io.wavfile.read(args.fname)
|
|
|
+ if len(data.shape) > 1:
|
|
|
+ raise ValueError("audio not mono maybe?")
|
|
|
+ data = data[:, 0]
|
|
|
+
|
|
|
+ analyzer = ToneAnalyzer(
|
|
|
+ args.freq,
|
|
|
+ args.offset,
|
|
|
+ args.duration,
|
|
|
+ framerate,
|
|
|
+ data,
|
|
|
+ )
|
|
|
+ print('%.4f' % (analyzer.analyze(args.freq2)))
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|