iot-logic/src/preprocessor/soundpreprocessor.py

117 lines
4.1 KiB
Python
Raw Normal View History

2020-08-24 20:38:38 +02:00
#!/usr/bin/env python3
2020-08-25 01:40:09 +02:00
import requests
2020-09-25 13:48:11 +02:00
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper
2020-08-25 01:40:09 +02:00
from utils import config
2020-08-24 20:38:38 +02:00
from .abcpreprocessor import AbcPreProcessor
2020-09-25 13:48:11 +02:00
import tempfile
import os
import logging
from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import MidTermFeatures
import numpy
2020-08-24 20:38:38 +02:00
"""
Abstract base class for Sender
"""
__author__ = "@tormakris"
__copyright__ = "Copyright 2020, Birbnetes Team"
__module_name__ = "soundpreprocessor"
__version__text__ = "1"
class SoundPreProcessor(AbcPreProcessor):
    """
    SoundPreProcessor class, responsible for detecting birb chirps in sound sample.

    On construction the current default model (plus its companion "MEANS"
    file) is fetched from the ``/model/svm`` endpoints under
    ``config.API_URL``, stored in temporary files and loaded with
    pyAudioAnalysis. The temporary files are deleted in ``__del__``.
    """

    def __init__(self):
        logging.info("Downloading current model...")
        fd, self._temp_model_name = tempfile.mkstemp()
        # mkstemp() hands back an already-open OS-level descriptor. The file is
        # (re)opened by name further down, so close the descriptor right away;
        # discarding it without closing leaks one fd per instance.
        os.close(fd)
        # NOTE(review): the "<model path>MEANS" name presumably matches the path
        # pyAudioAnalysis' load_model() reads normalisation data from — confirm.
        self._temp_means_name = self._temp_model_name + "MEANS"

        logging.debug("Fetching model info...")
        details_url = f"{config.API_URL}/model/svm/$default/details"
        self._model_details = self._get(details_url).json()
        model_url = f"{config.API_URL}/model/svm/{self._model_details['id']}"

        logging.debug("Downloading model...")
        self._download_to(model_url, self._temp_model_name)

        logging.debug("Downloading MEANS...")
        self._download_to(f"{model_url}?means", self._temp_means_name)

        logging.info("Loading current model...")
        # Both loaders are unpacked into the same nine attributes; only the
        # loader function depends on the model type reported by the API.
        loader = load_model_knn if self._model_details['type'] == 'knn' else load_model
        (self._classifier, self._mean, self._std, self._classes,
         self._mid_window, self._mid_step, self._short_window,
         self._short_step, self._compute_beat) = loader(self._temp_model_name)

    @staticmethod
    def _get(url: str) -> requests.Response:
        """GET ``url`` and raise (via ``raise_for_status``) on an HTTP error status."""
        response = requests.get(url)
        response.raise_for_status()
        return response

    @classmethod
    def _download_to(cls, url: str, path: str) -> None:
        """Fetch ``url`` and write the raw response body to the file at ``path``."""
        # Fetch first so the file is only opened once the request succeeded.
        content = cls._get(url).content
        with open(path, 'wb') as f:
            f.write(content)

    def preprocesssignal(self, file_path: str) -> bool:
        """
        Classify a sound sample.

        :param file_path: Access path of the sound sample up for processing.
        :return: ``True`` if the loaded model assigns the sample to the
            ``'chirp'`` class, ``False`` otherwise.
        :raises Exception: if the file could not be read (sampling rate is
            zero) or the signal is not longer than the model's mid-term window.
        :raises ValueError: if the model's class list has no ``'chirp'`` entry.
        """
        logging.info("Running extraction...")

        sampling_rate, signal = audioBasicIO.read_audio_file(file_path)
        signal = audioBasicIO.stereo_to_mono(signal)

        if sampling_rate == 0:
            raise Exception("Could not read the file properly: Sampling rate zero")

        # The sample must be strictly longer than one mid-term window
        # (duration in seconds = samples / sampling rate).
        if signal.shape[0] / float(sampling_rate) <= self._mid_window:
            raise Exception("Could not read the file properly: Signal shape is not good")

        # feature extraction:
        mid_features, s, _ = \
            MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
                                                   self._mid_window * sampling_rate,
                                                   self._mid_step * sampling_rate,
                                                   round(sampling_rate * self._short_window),
                                                   round(sampling_rate * self._short_step))
        # long term averaging of mid-term statistics
        mid_features = mid_features.mean(axis=1)
        if self._compute_beat:
            beat, beat_conf = MidTermFeatures.beat_extraction(s, self._short_step)
            mid_features = numpy.append(mid_features, beat)
            mid_features = numpy.append(mid_features, beat_conf)

        logging.info("Running classification...")
        target_id = self._classes.index('chirp')  # Might raise ValueError
        # Normalise with the statistics loaded alongside the model.
        feature_vector = (mid_features - self._mean) / self._std
        class_id, _ = classifier_wrapper(self._classifier, self._model_details['type'], feature_vector)
        return bool(class_id == target_id)

    def __del__(self):
        # __init__ may have failed before the temp paths were assigned (e.g.
        # mkstemp() raised), so look them up defensively rather than letting
        # an AttributeError escape from the finaliser.
        for path in (getattr(self, '_temp_model_name', None),
                     getattr(self, '_temp_means_name', None)):
            if path is None:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                # Best effort: the file may never have been written.
                pass