#!/usr/bin/python3
"""
This script calculates the SNR, assuming that the input has, at a specific
position, a single specific tone. So everything else is assumed to be noise.
"""

import math
from itertools import count
import argparse
from logging import getLogger, basicConfig
from typing import Tuple

import numpy as np
import scipy.io.wavfile

try:
    # Only useful for ad-hoc debugging plots of the spectrum; the analysis
    # itself never touches it, so a missing matplotlib must not be fatal.
    from matplotlib import pyplot as plt
except ImportError:
    plt = None

LOW_CUT = 100  # ignore anything below 100Hz
HIGH_CUT = 4500  # the spectrum is truncated at this bin (see get_fft)

log = getLogger("snr")


def _rms(squares) -> float:
    """Return the root mean square given a list of already-squared values.

    An empty list yields NaN (there is nothing to average).
    """
    if not squares:
        return float("nan")
    return math.sqrt(math.fsum(squares) / len(squares))


def rms_partition(
    seq, partition_key=lambda x: False, ignore=lambda x: False
) -> Tuple[float, float]:
    """
    We want to separate what is signal and what is noise.

    Each element of ``seq`` is visited together with its index. Indices for
    which ``ignore(index)`` is true are skipped; the remaining values are
    split into "signal" (``partition_key(index)`` is true) and "noise"
    (everything else).

    Returns a tuple ``(noise_rms, signal_rms)``. A partition that ends up
    empty has an RMS of NaN.
    """
    noise_squares = []
    signal_squares = []
    for index, value in enumerate(seq):
        if ignore(index):
            continue
        # Square in floating point: unsigned integer arithmetic would wrap
        # on negative inputs and could overflow on large magnitudes.
        squared = float(value) ** 2
        if partition_key(index):
            signal_squares.append(squared)
        else:
            noise_squares.append(squared)
    return (_rms(noise_squares), _rms(signal_squares))


class ToneAnalyzer:
    """
    this class checks how "clean" some tone is.

    give it the offset, a reference frequency, and it will help you
    calculate a SNR
    """

    # A spectrum bin counts as signal when it is strictly closer than this
    # to the reference frequency.
    TOLERANCE_HZ = 5

    def __init__(self, frequency, offset, duration, framerate: int, wave, threshold=0):
        """
        frequency: reference tone (spectrum bin index, see get_fft)
        offset:    start of the analysed window, in milliseconds
        duration:  length of the analysed window, in milliseconds
        framerate: samples per second of ``wave``
        wave:      sequence of samples
        threshold: accepted for compatibility; currently unused
        """
        self.freq = frequency
        self.offset = offset
        self.duration = duration
        self.framerate = framerate
        self.wave = wave
        self.log = getLogger(self.__class__.__name__)

    def get_wave(self):
        """Return the slice of the wave selected by offset and duration."""
        total = len(self.wave)
        self.log.debug("tot frames = %s", total)
        self.log.debug("framerate = %s", self.framerate)
        self.log.debug("duration = %.2fs", (total / self.framerate))
        start = math.ceil(self.framerate * (self.offset / 1000.0))
        end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
        return self.wave[start:end]

    def get_fft(self, seq):
        """Return the integer magnitudes of the first HIGH_CUT spectrum bins.

        The transform length is the framerate, so bin ``i`` corresponds to
        ``i`` Hz regardless of how long ``seq`` is.
        """
        spectrum = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
        # Debugging hint: plt.plot(np.arange(len(spectrum)), np.absolute(spectrum))
        return [int(magnitude) for magnitude in np.absolute(spectrum)]

    def is_signal(self, f):
        """Tell whether spectrum bin ``f`` belongs to the reference tone."""
        return abs(f - self.freq) < self.TOLERANCE_HZ

    def _reference_bin(self):
        """Return the spectrum bin compared against ``freq2`` in analyze()."""
        return self.freq

    def analyze(self, freq2=None) -> float:
        """Measure how clean the tone is inside the selected window.

        With ``freq2`` given, return the ratio between the magnitude at the
        reference frequency and the magnitude at ``freq2`` (infinity when
        the latter is zero).

        Otherwise return signal RMS divided by noise RMS, both computed on
        the spectrum, ignoring every bin below LOW_CUT. The result is
        infinity when there is signal and no noise at all.
        """
        seq = self.get_wave()
        fft = self.get_fft(seq)
        if freq2 is not None:
            freq_val = fft[self._reference_bin()]
            second_best_val = fft[freq2]
            if second_best_val == 0:
                return np.inf
            return freq_val / second_best_val

        # The partition works on frequency bins, so it must be fed the
        # spectrum and not the time-domain samples.
        noise, signal = rms_partition(
            fft,
            partition_key=self.is_signal,
            ignore=lambda f: f < LOW_CUT,
        )
        self.log.debug("noise=%.2f signal=%.2f", noise, signal)
        if noise == 0:
            return np.inf if signal > 0 else float("nan")
        return signal / noise


class MultiToneAnalyzer(ToneAnalyzer):
    """Same as ToneAnalyzer, but ``frequency`` is a list of tones."""

    TOLERANCE_HZ = 30

    def is_signal(self, f):
        """Tell whether spectrum bin ``f`` belongs to any of the tones."""
        return any(abs(f - tone) < self.TOLERANCE_HZ for tone in self.freq)

    def _reference_bin(self):
        """Return the single configured tone.

        Comparing against ``freq2`` needs one unambiguous reference, so
        anything other than exactly one tone is rejected.
        """
        if len(self.freq) != 1:
            raise ValueError(
                "comparing against freq2 requires exactly one tone, got %d"
                % len(self.freq)
            )
        return self.freq[0]


def get_parser():
    """Build the command line parser."""
    p = argparse.ArgumentParser()
    p.add_argument("fname")
    p.add_argument(
        "--freq", type=int, default=[440], nargs="*", help="Tone(s) considered signal"
    )
    p.add_argument(
        "--freq2",
        type=int,
        help="Print the magnitude ratio between --freq and this frequency",
    )
    p.add_argument("--offset", type=int, default=0, help="In milliseconds")
    p.add_argument("--duration", type=int, default=100, help="In milliseconds")
    p.add_argument(
        "--log-level",
        default="WARNING",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
    )
    return p


def main():
    """Read the wav file named on the command line and print the result."""
    args = get_parser().parse_args()
    basicConfig(level=args.log_level)
    if args.log_level in ["DEBUG", "INFO"]:
        np.seterr(all="warn")
    framerate, data = scipy.io.wavfile.read(args.fname)
    # More than one dimension means more than one channel; a mono file is
    # already the flat sequence of samples the analyzer expects.
    if len(data.shape) > 1:
        raise ValueError("audio not mono maybe?")
    analyzer = MultiToneAnalyzer(
        args.freq,
        args.offset,
        args.duration,
        framerate,
        data,
    )
    print("%.4f" % (analyzer.analyze(args.freq2)))


if __name__ == "__main__":
    main()