Browse Source

snr.py refactored

boyska 2 years ago
parent
commit
101a79d129
1 changed files with 31 additions and 18 deletions
  1. 31 18
      barker/snr.py

+ 31 - 18
barker/snr.py

@@ -18,27 +18,28 @@ from matplotlib import pyplot as plt
 LOW_CUT = 100  # ignore anything below 100Hz
 HIGH_CUT = 4500
 
-def window_rms(a, window_size=2):
-    return np.sqrt(
-        sum([a[window_size - i - 1 : len(a) - i] ** 2 for i in range(window_size - 1)])
-        / window_size
-    )
 
+def rms_partition(
+    seq, partition_key=lambda x: False, ignore=lambda x: False
+) -> Tuple[float, float]:
+    """
+    We want to separate what is signal and what is noise.
 
-def rms_except(seq, except_key=lambda x: False) -> Tuple[float, float]:
+    Please provide a partition_key function, and you'll get two numbers: RMS of noise and RMS of signal
+    """
     sum1 = np.ulonglong(0)
     cnt1 = 0
     sum2 = np.ulonglong(0)
     cnt2 = 0
     for freq, val in enumerate(seq):
-        if freq < LOW_CUT:
+        if ignore(freq):
             continue
         v = np.ulonglong(val)
-        if except_key(freq):
-            sum2 += v**2
+        if partition_key(freq):
+            sum2 += v ** 2
             cnt2 += 1
         else:
-            sum1 += v**2
+            sum1 += v ** 2
             cnt1 += 1
     return (np.sqrt(sum1 / cnt1), np.sqrt(sum2 / cnt2))
 
@@ -57,7 +58,7 @@ class ToneAnalyzer:
         self.wave = wave
         self.log = getLogger(self.__class__.__name__)
 
-    def analyze(self, freq2=None):
+    def get_wave(self):
         total = len(self.wave)
         self.log.debug("tot frames = %s", total)
         self.log.debug("framerate = %s", self.framerate)
@@ -66,11 +67,19 @@ class ToneAnalyzer:
         start = math.ceil(self.framerate * (self.offset / 1000.0))
         end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
         seq = self.wave[start:end]
+        return seq
+
+    def get_fft(self, seq):
         r = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
+        # x = np.arange(HIGH_CUT)
+        # plt.plot(x, r)
+        # plt.show()
+        return r
+
+    def analyze(self, freq2=None) -> float:
+        seq = self.get_wave()
+        r = self.get_fft(seq)
 
-        x = np.arange(HIGH_CUT)
-        plt.plot(x, r)
-        plt.show()
 
         def to_int(val):
             return int(np.absolute(val))
@@ -83,7 +92,11 @@ class ToneAnalyzer:
                 return np.inf
             return freq_val / second_best_val
 
-        noise, signal = rms_except(seq, except_key=lambda f: abs(f - self.freq) < 5)
+        noise, signal = rms_partition(
+            seq,
+            partition_key=lambda f: abs(f - self.freq) < 5,
+            ignore=lambda f: f < LOW_CUT,
+        )
         self.log.debug("noise=%.2f signal=%.2f", noise, signal)
         return signal / noise
         # self.log.debug("len %d", len(rms))
@@ -110,8 +123,8 @@ def get_parser():
 def main():
     args = get_parser().parse_args()
     basicConfig(level=args.log_level)
-    if args.log_level in ['DEBUG', 'INFO']:
-        np.seterr(all='warn')
+    if args.log_level in ["DEBUG", "INFO"]:
+        np.seterr(all="warn")
     framerate, data = scipy.io.wavfile.read(args.fname)
     if len(data.shape) > 1:
         raise ValueError("audio not mono maybe?")
@@ -124,7 +137,7 @@ def main():
         framerate,
         data,
     )
-    print('%.4f' % (analyzer.analyze(args.freq2)))
+    print("%.4f" % (analyzer.analyze(args.freq2)))
 
 
 if __name__ == "__main__":