123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- #!/usr/bin/python3
- """
This script calculates an SNR for a WAV file, assuming that the analysis window
(selected by --offset and --duration) contains one or more known tones (--freq).
Everything else in the analyzed part of the spectrum is treated as noise.
- """
- import math
- from itertools import count
- import argparse
- from logging import getLogger, basicConfig
- from typing import Tuple
- import numpy as np
- import scipy.io.wavfile
- from matplotlib import pyplot as plt
# FFT bins below this index are excluded from the SNR computation. Bins are
# 1 Hz wide because ToneAnalyzer.get_fft uses an FFT length equal to the
# framerate, so this ignores anything below 100 Hz.
LOW_CUT = 100
# Number of FFT bins kept by ToneAnalyzer.get_fft, i.e. an upper cut-off of
# 4500 Hz.
HIGH_CUT = 4500
# NOTE(review): this module-level logger appears unused; the analyzer classes
# log through their own per-class loggers.
log = getLogger("snr")
def rms_partition(
    seq, partition_key=lambda x: False, ignore=lambda x: False
) -> Tuple[float, float]:
    """
    Split ``seq`` into "noise" and "signal" groups by index and return each RMS.

    Elements are addressed by their position in ``seq``. For a magnitude
    spectrum, that position is the frequency bin.

    :param seq: iterable of real numbers.
    :param partition_key: called with an element's index. A truthy result puts
        the element in the signal group; a falsy result puts it in the noise group.
    :param ignore: called with an element's index. A truthy result skips the
        element entirely.
    :returns: ``(rms_noise, rms_signal)``. A group with no elements yields ``nan``.
    """
    # Accumulate in Python floats. Squaring np.ulonglong values can silently
    # overflow for large FFT magnitudes, and it cannot represent negative
    # inputs at all.
    sums = [0.0, 0.0]  # index 0: noise, index 1: signal
    counts = [0, 0]
    for index, val in enumerate(seq):
        if ignore(index):
            continue
        group = 1 if partition_key(index) else 0
        sums[group] += float(val) ** 2
        counts[group] += 1

    def rms(total, n):
        # An empty group has no meaningful RMS. Report nan explicitly instead
        # of dividing 0 by 0.
        return math.sqrt(total / n) if n else math.nan

    return rms(sums[0], counts[0]), rms(sums[1], counts[1])
class ToneAnalyzer:
    """
    Check how "clean" a tone is inside a window of a waveform.

    Give it the reference frequency, the window position (offset and duration,
    both in milliseconds), the sample rate and the samples. It computes an SNR
    from the window's magnitude spectrum.
    """

    def __init__(self, frequency, offset, duration, framerate: int, wave, threshold=0):
        """
        :param frequency: reference tone frequency in Hz. It is compared against
            FFT bin indices, which are 1 Hz wide (see ``get_fft``).
        :param offset: start of the analysis window, in milliseconds.
        :param duration: length of the analysis window, in milliseconds.
        :param framerate: sample rate of ``wave``, in frames per second.
        :param wave: sliceable sequence of samples (assumed single-channel).
        :param threshold: accepted for backward compatibility; currently unused.
        """
        self.freq = frequency
        self.offset = offset
        self.duration = duration
        self.framerate = framerate
        self.wave = wave
        self.log = getLogger(self.__class__.__name__)

    def get_wave(self):
        """Return the samples of the analysis window selected by offset/duration."""
        total = len(self.wave)
        self.log.debug("tot frames = %s", total)
        self.log.debug("framerate = %s", self.framerate)
        self.log.debug("duration = %.2fs", (total / self.framerate))
        start = math.ceil(self.framerate * (self.offset / 1000.0))
        end = math.floor(start + (self.framerate * (self.duration / 1000.0)))
        return self.wave[start:end]

    def get_fft(self, seq):
        """
        Return the magnitude spectrum of ``seq`` for bins ``0 .. HIGH_CUT - 1``.

        The FFT length equals ``framerate``, so bin ``k`` corresponds to ``k`` Hz.
        Each magnitude is truncated to a Python ``int``.
        """
        # NOTE(review): with n=framerate, a window longer than one second
        # (duration > 1000 ms) is silently truncated to its first `framerate`
        # samples; shorter windows are zero-padded.
        spectrum = np.fft.fft(seq, self.framerate)[:HIGH_CUT]
        return [int(mag) for mag in np.absolute(spectrum)]

    def is_signal(self, f):
        """Return True when bin ``f`` is closer than 5 Hz to the reference frequency."""
        return abs(f - self.freq) < 5

    def _reference_freq(self):
        """
        Return the single reference frequency used for a ``freq2`` comparison.

        ``self.freq`` may be a list or tuple (e.g. for multi-tone analyzers). A
        one-element list is unwrapped. More than one frequency is rejected,
        because a single-bin ratio needs a single reference bin.
        """
        ref = self.freq
        if isinstance(ref, (list, tuple)):
            if len(ref) != 1:
                raise ValueError(
                    "freq2 comparison needs exactly one reference frequency"
                )
            ref = ref[0]
        return ref

    def analyze(self, freq2=None) -> float:
        """
        Compute the SNR of the analysis window.

        :param freq2: optional comparison frequency in Hz. When given, return the
            magnitude at the reference frequency divided by the magnitude at
            ``freq2`` (``inf`` if the latter is 0).
        :returns: without ``freq2``, the RMS of the signal bins (see
            ``is_signal``) divided by the RMS of the other bins in
            ``[LOW_CUT, HIGH_CUT)``. Returns ``inf`` if the noise RMS is 0.
        :raises ValueError: if ``freq2`` is given and ``self.freq`` holds more
            than one frequency.
        """
        # Validate before doing any FFT work.
        ref = self._reference_freq() if freq2 is not None else None
        seq = self.get_wave()
        fft = self.get_fft(seq)
        if freq2 is not None:
            freq_val = fft[ref]
            second_best_val = fft[freq2]
            if second_best_val == 0:
                return np.inf
            return freq_val / second_best_val
        # Partition the *spectrum* by frequency bin. Both predicates below take
        # bin indices, so passing the time-domain samples here would be wrong.
        noise, signal = rms_partition(
            fft,
            partition_key=self.is_signal,
            ignore=lambda f: f < LOW_CUT,
        )
        self.log.debug("noise=%.2f signal=%.2f", noise, signal)
        if noise == 0:
            # Same convention as the freq2 branch: a perfectly clean tone.
            return np.inf
        return signal / noise
class MultiToneAnalyzer(ToneAnalyzer):
    """ToneAnalyzer whose ``freq`` is a collection of tone frequencies in Hz."""

    def is_signal(self, f):
        """Return True when bin ``f`` is closer than 30 Hz to any of the tones."""
        return any(abs(f - tone) < 30 for tone in self.freq)
def get_parser():
    """Build the command-line parser for the SNR script."""
    p = argparse.ArgumentParser()
    p.add_argument("fname", help="WAV file to analyze")
    # nargs="+" rather than "*": a bare `--freq` with no values would leave no
    # signal bins at all, so the SNR would be meaningless.
    p.add_argument(
        "--freq",
        type=int,
        default=[440],
        nargs="+",
        help="tone frequencies in Hz (one or more)",
    )
    p.add_argument(
        "--freq2",
        type=int,
        help="instead of an RMS SNR, print the ratio of the magnitude at --freq "
        "to the magnitude at this frequency (Hz)",
    )
    p.add_argument("--offset", type=int, default=0, help="In milliseconds")
    p.add_argument("--duration", type=int, default=100, help="In milliseconds")
    p.add_argument(
        "--log-level",
        default="WARNING",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
    )
    return p
def main():
    """Command-line entry point: parse arguments, load the WAV file, print the SNR."""
    args = get_parser().parse_args()
    basicConfig(level=args.log_level)
    if args.log_level in ["DEBUG", "INFO"]:
        np.seterr(all="warn")
    framerate, data = scipy.io.wavfile.read(args.fname)
    # scipy returns a 2-D (frames, channels) array for multi-channel files.
    # Only mono input is supported. The former `data = data[:, 0]` after this
    # raise was unreachable and has been removed.
    if data.ndim > 1:
        raise ValueError("audio not mono maybe?")
    analyzer = MultiToneAnalyzer(
        args.freq,
        args.offset,
        args.duration,
        framerate,
        data,
    )
    print("%.4f" % (analyzer.analyze(args.freq2)))
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|