Browse Source

snr.py measure signal and noise

it assumes there is a single tone going on
boyska 2 years ago
parent
commit
455987a403
1 changed files with 131 additions and 0 deletions
  1. 131 0
      barker/snr.py

+ 131 - 0
barker/snr.py

@@ -0,0 +1,131 @@
+#!/usr/bin/python3
+"""
+This script calculates the SNR, assuming that the input has, at a specific position, a single specific tone.
+
+So everything else is assumed to be noise.
+"""
+
+import math
+from itertools import count
+import argparse
+from logging import getLogger, basicConfig
+from typing import Tuple
+
+import numpy as np
+import scipy.io.wavfile
+from matplotlib import pyplot as plt
+
+LOW_CUT = 100  # ignore anything below 100Hz
+HIGH_CUT = 4500
+
+def window_rms(a, window_size=2):
+    return np.sqrt(
+        sum([a[window_size - i - 1 : len(a) - i] ** 2 for i in range(window_size - 1)])
+        / window_size
+    )
+
+
+def rms_except(seq, except_key=lambda x: False) -> Tuple[float, float]:
+    sum1 = np.ulonglong(0)
+    cnt1 = 0
+    sum2 = np.ulonglong(0)
+    cnt2 = 0
+    for freq, val in enumerate(seq):
+        if freq < LOW_CUT:
+            continue
+        v = np.ulonglong(val)
+        if except_key(freq):
+            sum2 += v**2
+            cnt2 += 1
+        else:
+            sum1 += v**2
+            cnt1 += 1
+    return (np.sqrt(sum1 / cnt1), np.sqrt(sum2 / cnt2))
+
+
+class ToneAnalyzer:
+    """
+    this class checks how "clean" some tone is.
+    give it the offset, a reference frequency, and it will help you calculate a SNR
+    """
+
+    def __init__(self, frequency, offset, duration, framerate: int, wave, threshold=0):
+        self.freq = frequency
+        self.offset = offset
+        self.duration = duration
+        self.framerate = framerate
+        self.wave = wave
+        self.log = getLogger(self.__class__.__name__)
+
+    def analyze(self, freq2=None):
+        total = len(self.wave)
+        self.log.debug("tot frames = %s", total)
+        self.log.debug("framerate = %s", self.framerate)
+        self.log.debug("duration = %.2fs", (total / self.framerate))
+
+        start = math.ceil(self.framerate * (self.offset / 1000.0))
+        end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
+        seq = self.wave[start:end]
+        r = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
+
+        x = np.arange(HIGH_CUT)
+        plt.plot(x, r)
+        plt.show()
+
+        def to_int(val):
+            return int(np.absolute(val))
+
+        fft = [to_int(val) for val in r]
+        if freq2 is not None:
+            freq_val = fft[self.freq]
+            second_best_val = fft[freq2]
+            if second_best_val == 0:
+                return np.inf
+            return freq_val / second_best_val
+
+        noise, signal = rms_except(seq, except_key=lambda f: abs(f - self.freq) < 5)
+        self.log.debug("noise=%.2f signal=%.2f", noise, signal)
+        return signal / noise
+        # self.log.debug("len %d", len(rms))
+        # self.log.debug("max %.2f", np.nanmax(rms))
+        # self.log.debug("RMS %.2f", rms)
+
+
+def get_parser():
+    p = argparse.ArgumentParser()
+    p.add_argument("fname")
+    p.add_argument("--freq", type=int, default=440)
+    p.add_argument("--freq2", type=int)
+    p.add_argument("--offset", type=int, default=0, help="In milliseconds")
+    p.add_argument("--duration", type=int, default=100, help="In milliseconds")
+
+    p.add_argument(
+        "--log-level",
+        default="WARNING",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+    )
+    return p
+
+
+def main():
+    args = get_parser().parse_args()
+    basicConfig(level=args.log_level)
+    if args.log_level in ['DEBUG', 'INFO']:
+        np.seterr(all='warn')
+    framerate, data = scipy.io.wavfile.read(args.fname)
+    if len(data.shape) > 1:
+        raise ValueError("audio not mono maybe?")
+        data = data[:, 0]
+
+    analyzer = ToneAnalyzer(
+        args.freq,
+        args.offset,
+        args.duration,
+        framerate,
+        data,
+    )
+    print('%.4f' % (analyzer.analyze(args.freq2)))
+
+
+if __name__ == "__main__":
+    main()