2020-08-24 20:38:38 +02:00
|
|
|
#!/usr/bin/env python3
|
2020-08-25 01:40:09 +02:00
|
|
|
from utils import config
|
2020-08-24 20:38:38 +02:00
|
|
|
from .abcpreprocessor import AbcPreProcessor
|
2021-11-18 21:51:50 +01:00
|
|
|
|
2020-09-25 13:48:11 +02:00
|
|
|
import logging
|
|
|
|
|
2021-11-18 21:51:50 +01:00
|
|
|
if not config.DISABLE_AI:
|
|
|
|
import tempfile
|
|
|
|
import requests
|
|
|
|
from urllib.parse import urljoin
|
|
|
|
import os
|
|
|
|
|
|
|
|
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper
|
|
|
|
from pyAudioAnalysis import audioBasicIO
|
|
|
|
from pyAudioAnalysis import MidTermFeatures
|
|
|
|
import numpy
|
2020-08-24 20:38:38 +02:00
|
|
|
|
2020-10-01 20:15:24 +02:00
|
|
|
from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver
|
|
|
|
|
2020-08-24 20:38:38 +02:00
|
|
|
"""
|
|
|
|
Sound pre-processor: detects birb chirps in recorded sound samples
|
|
|
|
"""
|
|
|
|
|
|
|
|
__author__ = "@tormakris"
|
|
|
|
__copyright__ = "Copyright 2020, Birbnetes Team"
|
|
|
|
__module_name__ = "soundpreprocessor"
|
|
|
|
__version__text__ = "1"
|
|
|
|
|
|
|
|
|
2021-11-18 21:51:50 +01:00
|
|
|
class SoundPreProcessorLegit(AbcPreProcessor):
    """
    SoundPreProcessor class, responsible for detecting birb chirps in sound sample.

    On construction the model description, the model file and its MEANS file are
    downloaded from ``config.API_URL`` into temporary files and loaded with
    pyAudioAnalysis. The temporary files are removed again in ``__del__``.
    """

    def __init__(self) -> None:
        """
        Download the configured model and load it.

        The model id is taken from ``config.SVM_MODEL_ID``; when that is empty the
        ``$default`` model is requested.

        :raises requests.HTTPError: if any of the downloads answers with an error status.
        :raises ValueError: if the target class named in the model info is not among
            the classes of the loaded model.
        """
        logging.info("Downloading current model...")
        temp_model_handle, self._temp_model_name = tempfile.mkstemp()
        # Only the path is needed: the file is re-opened by name once the download
        # has succeeded. Closing the descriptor right away keeps it from leaking
        # when one of the requests below raises.
        os.close(temp_model_handle)
        self._temp_means_name = self._temp_model_name + "MEANS"

        logging.debug("Fetching model info...")
        BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
        model_id_to_get = config.SVM_MODEL_ID or '$default'

        model_root_url = urljoin(config.API_URL, f"/model/svm/{model_id_to_get}")

        # NOTE(review): no timeout is passed to requests.get here or in
        # _download_to_file, so a stalled server blocks start-up indefinitely.
        # Consider adding one once an acceptable value is known.
        r = requests.get(model_root_url)
        r.raise_for_status()

        self._model_details = r.json()
        # Normalise once so that the loader chosen below and the classifier kind
        # handed to classifier_wrapper() in preprocesssignal() always agree.
        self._model_type = self._model_details['type'].lower()

        logging.debug("Downloading model...")
        BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
        self._download_to_file(
            urljoin(model_root_url, self._model_details['files']['model']),
            self._temp_model_name
        )

        logging.debug("Downloading MEANS...")
        BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
        self._download_to_file(
            urljoin(model_root_url, self._model_details['files']['means']),
            self._temp_means_name
        )

        logging.info("Loading current model...")

        # kNN models have their own loader; every other type goes through load_model
        loader = load_model_knn if self._model_type == 'knn' else load_model
        self._classifier, self._mean, self._std, self._classes, \
            self._mid_window, self._mid_step, self._short_window, \
            self._short_step, self._compute_beat = loader(self._temp_model_name)

        target_class_name = self._model_details['target_class_name']

        logging.info("The loaded model contains the following classes: " + ", ".join(self._classes))
        if target_class_name not in self._classes:
            raise ValueError(
                f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)"
            )

        self._target_id = self._classes.index(target_class_name)

    @staticmethod
    def _download_to_file(url: str, destination: str) -> None:
        """
        Fetch a URL and store the response body in a file.

        :param url: Address to download.
        :param destination: Path of the file that receives the response body.
        :raises requests.HTTPError: if the server answers with an error status.
        """
        response = requests.get(url)
        response.raise_for_status()

        with open(destination, 'wb') as f:
            f.write(response.content)

    def preprocesssignal(self, file_path: str) -> bool:
        """
        Classify a sound sample.

        :param file_path: Access path of the sound sample up for processing.
        :return: True if the sample is classified as the target class with a
            probability above 0.5, False otherwise.
        :raises AssertionError: if the sampling rate reads as zero, or the sample is
            not longer than the mid-term window of the model.
        """
        logging.debug("Running extraction...")

        sampling_rate, signal = audioBasicIO.read_audio_file(file_path)
        signal = audioBasicIO.stereo_to_mono(signal)

        if sampling_rate == 0:
            raise AssertionError("Could not read the file properly: Sampling rate zero")

        # the sample has to be longer than one mid-term window
        if signal.shape[0] / float(sampling_rate) <= self._mid_window:
            raise AssertionError("Could not read the file properly: Signal shape is not good")

        # feature extraction:
        mid_features, s, _ = \
            MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
                                                   self._mid_window * sampling_rate,
                                                   self._mid_step * sampling_rate,
                                                   round(sampling_rate * self._short_window),
                                                   round(sampling_rate * self._short_step))

        # long term averaging of mid-term statistics
        mid_features = mid_features.mean(axis=1)
        if self._compute_beat:
            beat, beat_conf = MidTermFeatures.beat_extraction(s, self._short_step)
            mid_features = numpy.append(mid_features, beat)
            mid_features = numpy.append(mid_features, beat_conf)

        logging.debug("Running classification...")

        # normalise with the mean / std values that were loaded together with the model
        feature_vector = (mid_features - self._mean) / self._std
        class_id, probability = classifier_wrapper(
            self._classifier, self._model_type, feature_vector
        )
        # cast so the id can be used as an index below
        class_id = int(class_id)

        logging.debug(
            f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}"
        )

        return bool((class_id == self._target_id) and (probability[class_id] > 0.5))

    def __del__(self):
        """
        Remove the temporary model files, ignoring the ones that are already gone.
        """
        # __init__ may have failed before these attributes were assigned, so they
        # are looked up defensively instead of raising AttributeError here.
        for attribute in ('_temp_model_name', '_temp_means_name'):
            path = getattr(self, attribute, None)
            if path is None:
                continue

            try:
                os.remove(path)
            except FileNotFoundError:
                pass
|
2021-11-18 21:51:50 +01:00
|
|
|
|
|
|
|
|
|
|
|
class SoundPreProcessorDummy(AbcPreProcessor):
    """
    Stand-in pre-processor that performs no classification and accepts every sample.
    """

    def __init__(self) -> None:
        """
        Announce on standard output that the dummy pre-processor is being used.
        """
        print("AI is disabled! Initializing dummy sound pre-processor...")

    def preprocesssignal(self, file_path) -> bool:
        """
        Accept the sample without looking at it.

        :param file_path: Access path of the sound sample; not used.
        :return: Always True.
        """
        return True
|
|
|
|
|
|
|
|
|
|
|
|
# Export the implementation selected by the configuration under a single name:
# the dummy when AI is disabled, the model-backed pre-processor otherwise.
SoundPreProcessor = SoundPreProcessorDummy if config.DISABLE_AI else SoundPreProcessorLegit
|