# coding=UTF-8 from __future__ import absolute_import, division, print_function import unittest from threading import Lock import time import collections import os import monotonic from coolamqp import Cluster, ClusterNode, ConnectionUp, ConnectionDown, ConnectionUp, ConsumerCancelled def getamqp(): amqp = Cluster([ClusterNode('127.0.0.1', 'guest', 'guest')]) amqp.start() return amqp class CoolAMQPTestCase(unittest.TestCase): """ Base class for all CoolAMQP tests. Creates na AMQP connection, provides methods for easy interfacing, and other utils. """ INIT_AMQP = True # override on child classes def new_amqp_connection(self, consume_connectionup=True): obj = self class CM(object): """Context manager. Get new AMQP uplink. Consume ConnectionUp if consume_connectionup Use like: with self.new_amqp_connection() as amqp2: amqp2.consume(...) """ def __enter__(self): self.amqp = getamqp() if consume_connectionup: obj.assertIsInstance(self.amqp.drain(3), ConnectionUp) return self.amqp def __exit__(self, exc_type, exc_val, exc_tb): self.amqp.shutdown() return False return CM() def restart_rmq(self): # forcibly reset the connection class FailbowlSocket(object): def __getattr__(self, name): import socket raise socket.error() self.amqp.thread.backend.channel.connection.transport.sock = FailbowlSocket() self.drainTo([ConnectionDown, ConnectionUp], [5, 10]) def setUp(self): if self.INIT_AMQP: os.system('sudo service rabbitmq-server start') # if someone killed it self.__newam = self.new_amqp_connection() self.amqp = self.__newam.__enter__() def tearDown(self): if self.INIT_AMQP: self.__newam.__exit__(None, None, None) def drainToNone(self, timeout=4): self.assertIsNone(self.amqp.drain(4)) def drainToAny(self, types, timeout, forbidden=[]): """Assert that messages with types, in any order, are found within timeout. Fail if any type from forbidden is found""" start = monotonic.monotonic() types = set(types) while monotonic.monotonic() - start < timeout: q = self.amqp.drain(1) if type(q) in forbidden: self.fail('%s found', type(q)) if type(q) in types: types.remove(type(q)) if len(types) > 0: self.fail('Not found %s' % (''.join(map(str, types)), )) def drainTo(self, type_, timeout, forbidden=[ConsumerCancelled]): """ Return next event of type_. It has to occur within timeout, or fail. If you pass iterable (len(type_) == len(timeout), last result will be returned and I will drainTo() in order. """ if isinstance(type_, collections.Iterable): self.assertIsInstance(timeout, collections.Iterable) for tp, ti in zip(type_, timeout): p = self.drainTo(tp, ti) if type(p) in forbidden: self.fail('Found %s but forbidden', type(p)) return p start = monotonic.monotonic() while monotonic.monotonic() - start < timeout: q = self.amqp.drain(1) if isinstance(q, type_): return q self.fail('Did not find %s' % (type_, )) def takes_less_than(self, max_time): """ Tests that your code executes in less time than specified value. Use like: with self.takes_less_than(0.9): my_operation() :param max_time: in seconds """ test = self class CM(object): def __enter__(self): self.started_at = time.time() def __exit__(self, tp, v, tb): test.assertLess(time.time() - self.started_at, max_time) return False return CM()