from threading import Lock
from typing import Optional
import json
import time

from flask import Flask
import opentracing
from opentracing.ext import tags
from opentracing.propagation import Format
import pika
import pika.exceptions


class MagicAMQP:
    """
    Flask extension that keeps a single blocking RabbitMQ (pika) connection
    alive and publishes JSON messages through it.

    All access to the underlying connection/channel is serialized with an
    internal lock, because the same connection object is used both by
    :meth:`loop` and by the publish methods.

    Configuration keys read from ``app.config``:

    * ``FLASK_PIKA_PARAMS`` -- dict with ``host``, ``username`` and ``password``
    * ``EXCHANGE_NAME_META`` -- name of the direct exchange used by :meth:`publish_meta`
    * ``EXCHANGE_NAME_CACHE`` -- name of the direct exchange used by :meth:`publish_cache`
    """

    # Retry policy used by _publish(): give up once more than _MAX_TRIES
    # failures were seen, and start sleeping between reconnect attempts once
    # more than _BACKOFF_AFTER_TRIES failures were seen.
    _MAX_TRIES = 30
    _BACKOFF_AFTER_TRIES = 10
    _BACKOFF_SECONDS = 2

    def __init__(self, app: Optional[Flask] = None):
        """
        Create the extension, optionally binding it to ``app`` right away.

        :param app: Flask application to initialize with; when omitted,
            :meth:`init_app` must be called later.
        """
        self.app = app

        # State must exist *before* init_app() runs: init_app() stores the
        # credentials, so resetting them afterwards would break every later
        # reconnect attempt.
        self._lock = Lock()
        self._credentials = None
        self._pika_connection = None
        self._pika_channel = None

        if app is not None:
            self.init_app(app)

    def init_app(self, app: Flask):
        """
        Bind the extension to ``app``, read its configuration and open the
        initial connection.

        :raises KeyError: if ``FLASK_PIKA_PARAMS`` lacks ``username``,
            ``password`` or ``host``.
        """
        self.app = app

        self.app.config.setdefault('FLASK_PIKA_PARAMS', {})
        self.app.config.setdefault('EXCHANGE_NAME_META', None)
        self.app.config.setdefault('EXCHANGE_NAME_CACHE', None)

        self._credentials = pika.PlainCredentials(
            app.config['FLASK_PIKA_PARAMS']['username'],
            app.config['FLASK_PIKA_PARAMS']['password']
        )

        self._reconnect_ampq()

    def _reconnect_ampq(self):
        """
        Open a fresh connection and channel, then declare both exchanges.

        Callers other than :meth:`init_app` invoke this while holding
        ``self._lock``.
        """
        self._pika_connection = pika.BlockingConnection(
            pika.ConnectionParameters(
                host=self.app.config['FLASK_PIKA_PARAMS']['host'],
                credentials=self._credentials,
                heartbeat=10,
                socket_timeout=5)
        )
        self._pika_channel = self._pika_connection.channel()

        for exchange_key in ('EXCHANGE_NAME_META', 'EXCHANGE_NAME_CACHE'):
            self._pika_channel.exchange_declare(
                exchange=self.app.config[exchange_key],
                exchange_type='direct'
            )

    def loop(self):
        """
        This method should be called periodically to keep up the connection
        (it lets pika process pending frames such as heartbeats).

        On a connection error a single reconnect is attempted; if that
        reconnect fails, its exception propagates to the caller.
        """
        lock_start = time.time()
        with self._lock:
            lock_acquire_time = time.time() - lock_start
            if lock_acquire_time >= 0.5:
                self.app.logger.warning(f"Loop: Lock acquire took {lock_acquire_time:5f} sec")

            try:
                self._pika_connection.process_data_events(0)
                # We won't attempt retry if this fail
            except pika.exceptions.AMQPConnectionError as e:
                self.app.logger.warning(f"Connection error during process loop: {e} (attempting reconnect)")
                self._reconnect_ampq()

        total_time = time.time() - lock_start
        if total_time > 1:
            self.app.logger.warning(f"Loop: Total loop took {total_time:5f} sec")

    def _publish(self, exchange: str, routing_key: str, payload=None):
        """
        Publish a simple json serialized message to the given exchange.

        If the connection is broken, this call blocks while it tries to
        reconnect and re-send. Every failed publish and every failed reconnect
        counts as one try; after more than ``_MAX_TRIES`` tries the last
        connection error is re-raised.

        :param exchange: name of the exchange to publish to
        :param routing_key: routing key of the message
        :param payload: JSON serializable object. When it is not ``None`` the
            tracing context is injected into it, so it is expected to be a
            dict and it is modified in place.
        :raises pika.exceptions.AMQPConnectionError: when the retry budget is
            exhausted
        :raises TypeError: when ``payload`` is not JSON serializable
        """
        span_tags = {tags.SPAN_KIND: tags.SPAN_KIND_PRODUCER}
        with opentracing.tracer.start_active_span('magic_amqp.publish', tags=span_tags) as scope:
            # A TEXT_MAP carrier has to be a mutable mapping; there is nothing
            # to inject into when no payload was given.
            if payload is not None:
                opentracing.tracer.inject(scope.span.context, Format.TEXT_MAP, payload)

            # Serialize once, before taking the lock: the body is identical for
            # every retry and a serialization error should not hold the lock.
            body = json.dumps(payload).encode('UTF-8')

            lock_start = time.time()
            with self._lock:
                scope.span.log_kv({'event': 'lockAcquired'})
                lock_acquire_time = time.time() - lock_start
                if lock_acquire_time >= 0.2:
                    self.app.logger.warning(f"Publish: Lock acquire took {lock_acquire_time:5f} sec")

                tries = 0
                while True:
                    try:
                        self._pika_channel.basic_publish(
                            exchange=exchange,
                            routing_key=routing_key,
                            body=body
                        )
                        self.app.logger.debug(f"Published: {payload}")
                        break  # message sent successfully
                    except pika.exceptions.AMQPConnectionError as publish_error:
                        scope.span.log_kv({'event': 'connectionError', 'error': str(publish_error)})
                        self.app.logger.warning(
                            f"Connection error during publish: {publish_error} (attempting reconnect)")

                        # Count failed publishes too; otherwise a broker that
                        # accepts connections but keeps dropping them on
                        # publish would make this loop spin forever.
                        tries += 1
                        if tries > self._MAX_TRIES:
                            raise  # just give up

                        while True:
                            try:
                                self._reconnect_ampq()
                                break
                            except pika.exceptions.AMQPConnectionError as reconnect_error:
                                self.app.logger.warning(
                                    f"Connection error during reconnection: {reconnect_error} (attempting reconnect)")
                                tries += 1

                                if tries > self._MAX_TRIES:
                                    raise  # just give up

                                if tries > self._BACKOFF_AFTER_TRIES:
                                    time.sleep(self._BACKOFF_SECONDS)

            total_time = time.time() - lock_start
            if total_time > 0.4:
                self.app.logger.warning(f"Publish: Total publish took {total_time:5f} sec")

    def publish_cache(self, payload=None):
        """Publish ``payload`` to the cache exchange with routing key ``cache``."""
        return self._publish(self.app.config['EXCHANGE_NAME_CACHE'], "cache", payload)

    def publish_meta(self, payload=None):
        """Publish ``payload`` to the meta exchange with routing key ``meta``."""
        return self._publish(self.app.config['EXCHANGE_NAME_META'], "meta", payload)

    def is_healthy(self) -> bool:
        """
        Tell whether both the channel and the connection are currently open.

        Returns ``False`` when no connection was ever established (e.g.
        :meth:`init_app` has not been called yet).
        """
        with self._lock:
            if self._pika_channel is None or self._pika_connection is None:
                return False
            return bool(self._pika_channel.is_open and self._pika_connection.is_open)


# instance to be used in the flask app
magic_amqp = MagicAMQP()