iot-logic/src/preprocessor/soundpreprocessor.py

130 lines
4.7 KiB
Python

#!/usr/bin/env python3
import requests
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper
from utils import config
from .abcpreprocessor import AbcPreProcessor
import tempfile
import os
import logging
from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import MidTermFeatures
import numpy
# NOTE(review): this string comes after the imports, so Python does not treat
# it as the module docstring (__doc__); move it to the top of the file if the
# docstring should be introspectable.
"""
Sound preprocessor: downloads a pyAudioAnalysis classification model from the
configured API (config.API_URL) and decides whether a sound sample belongs to
the configured target class (config.SVM_TARGET_CLASS_NAME).
"""
__author__ = "@tormakris"
__copyright__ = "Copyright 2020, Birbnetes Team"
__module_name__ = "soundpreprocessor"
__version__text__ = "1"
class SoundPreProcessor(AbcPreProcessor):
    """
    SoundPreProcessor class, responsible for detecting birb chirps in sound sample.

    On construction the model selected by the configuration is downloaded from
    ``config.API_URL`` (model file plus its "MEANS" companion file) into
    temporary files and loaded with pyAudioAnalysis. The temporary files are
    removed (best effort) when the instance is finalized.
    """

    # Timeout (seconds) passed to every API request. requests applies it to the
    # connection attempt and to each wait for data, not to the whole download;
    # without it a stalled API would block construction forever.
    _REQUEST_TIMEOUT = 30

    def __init__(self):
        """
        Download and load the classification model selected by the configuration.

        ``config.SVM_MODEL_ID`` is used when set, otherwise the API's
        ``$default`` model is requested.

        :raises requests.HTTPError: if an API request returns an error status.
        :raises requests.Timeout: if the API does not respond within ``_REQUEST_TIMEOUT``.
        :raises Exception: if ``config.SVM_TARGET_CLASS_NAME`` is not one of the model's classes.
        """
        logging.info("Downloading current model...")
        # mkstemp returns an already-open OS-level descriptor; close it right
        # away so it is not leaked (the file is rewritten by name below).
        fd, self._temp_model_name = tempfile.mkstemp()
        os.close(fd)
        # Companion file expected next to the model as "<model path>MEANS".
        # NOTE(review): presumably read by pyAudioAnalysis' load_model — confirm
        # whether the knn loader needs it too.
        self._temp_means_name = self._temp_model_name + "MEANS"

        logging.debug("Fetching model info...")
        model_id_to_get = config.SVM_MODEL_ID or '$default'
        self._model_details = self._api_get(f"/model/svm/{model_id_to_get}/details").json()
        model_path = f"/model/svm/{self._model_details['id']}"

        logging.debug("Downloading model...")
        self._download(model_path, self._temp_model_name)
        logging.debug("Downloading MEANS...")
        self._download(model_path + "?means", self._temp_means_name)

        logging.info("Loading current model...")
        # NOTE(review): pyAudioAnalysis model files are pickles, so loading them
        # trusts whatever the API serves — make sure the API is trusted.
        loader = load_model_knn if self._model_details['type'] == 'knn' else load_model
        (self._classifier, self._mean, self._std, self._classes,
         self._mid_window, self._mid_step, self._short_window,
         self._short_step, self._compute_beat) = loader(self._temp_model_name)

        logging.info("The loaded model contains the following classes: %s", ", ".join(self._classes))
        if config.SVM_TARGET_CLASS_NAME not in self._classes:
            raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes")

    def _api_get(self, path: str) -> requests.Response:
        """GET ``config.API_URL + path`` and raise requests.HTTPError on an error status."""
        response = requests.get(f"{config.API_URL}{path}", timeout=self._REQUEST_TIMEOUT)
        response.raise_for_status()
        return response

    def _download(self, path: str, destination: str) -> None:
        """Fetch ``config.API_URL + path`` and write the raw response body to ``destination``."""
        # Fetch before opening so a failed request does not truncate the file.
        content = self._api_get(path).content
        with open(destination, 'wb') as f:
            f.write(content)

    def preprocesssignal(self, file_path: str) -> bool:
        """
        Classify a sound sample.

        :param file_path: Access path of the sound sample up for processing.
        :return: True if the sample is classified as ``config.SVM_TARGET_CLASS_NAME``.
        :raises Exception: if the file cannot be read or is not longer than one mid-term window.
        :raises ValueError: if the target class is no longer among the model's classes.
        """
        logging.debug("Running extraction...")
        sampling_rate, signal = audioBasicIO.read_audio_file(file_path)
        signal = audioBasicIO.stereo_to_mono(signal)
        if sampling_rate == 0:
            raise Exception("Could not read the file properly: Sampling rate zero")
        # Feature extraction needs a signal longer than one mid-term window.
        if signal.shape[0] / float(sampling_rate) <= self._mid_window:
            raise Exception("Could not read the file properly: Signal shape is not good")

        # feature extraction:
        mid_features, short_features, _ = MidTermFeatures.mid_feature_extraction(
            signal, sampling_rate,
            self._mid_window * sampling_rate,
            self._mid_step * sampling_rate,
            round(sampling_rate * self._short_window),
            round(sampling_rate * self._short_step))
        # long term averaging of mid-term statistics
        mid_features = mid_features.mean(axis=1)
        if self._compute_beat:
            beat, beat_conf = MidTermFeatures.beat_extraction(short_features, self._short_step)
            mid_features = numpy.append(mid_features, beat)
            mid_features = numpy.append(mid_features, beat_conf)

        logging.debug("Running classification...")
        target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME)  # Might raise ValueError
        feature_vector = (mid_features - self._mean) / self._std
        class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'], feature_vector)
        # The predicted label may come back as a numpy float (e.g. 1.0) for
        # scikit-learn based models; a float cannot index a Python list.
        class_id = int(class_id)
        logging.debug("Sample %s identified as %s with the probability of %s",
                      file_path, self._classes[class_id], probability)
        return class_id == target_id

    def __del__(self):
        """Best-effort removal of the temporary model files created by ``__init__``."""
        # getattr: __init__ may have failed before these attributes were set.
        for path in (getattr(self, '_temp_model_name', None),
                     getattr(self, '_temp_means_name', None)):
            if path is None:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass