# coding=UTF-8
from __future__ import absolute_import, division, print_function

import collections
import os
import socket
import time
import unittest

import monotonic
from coolamqp.backends.base import AMQPBackend, ConnectionFailedError

from coolamqp import Cluster, ClusterNode, ConnectionDown, \
    ConnectionUp, ConsumerCancelled


def getamqp():
    """
    Build a single-node Cluster, start it and return it.

    The node address is read from the AMQP_HOST environment variable and falls
    back to 127.0.0.1. 'guest' is passed for both remaining ClusterNode
    arguments (presumably user and password -- confirm against ClusterNode).

    :return: the Cluster, after start() has been called on it
    """
    host = os.environ.get('AMQP_HOST', '127.0.0.1')
    node = ClusterNode(host, 'guest', 'guest')
    properties = [
        (b'mode', (b'Testing', 'S')),
    ]
    cluster = Cluster([node], extra_properties=properties)
    cluster.start()
    return cluster


class CoolAMQPTestCase(unittest.TestCase):
    """
    Base class for all CoolAMQP tests. Creates an AMQP connection, provides methods
    for easy interfacing, and other utils.
    """
    INIT_AMQP = True  # override on child classes

    def setUp(self):
        """Open the per-test connection as self.amqp, unless INIT_AMQP is falsy."""
        if self.INIT_AMQP:
            self.__newam = self.new_amqp_connection()
            self.amqp = self.__newam.__enter__()

    def tearDown(self):
        """
        Check that a simulated broker failure was switched off again, then
        close the connection opened by setUp().
        """
        try:
            # fail_amqp() creates old_backend and unfail_amqp() deletes it, so a
            # leftover attribute means the test never unfailed AMQP
            self.assertFalse(hasattr(self, 'old_backend'),
                             'fail_amqp() was called without a matching unfail_amqp()')
        finally:
            # close the connection even if the check above failed, so that
            # a broken test does not leave its connection behind
            if self.INIT_AMQP:
                self.__newam.__exit__(None, None, None)

    def drainToNone(self, timeout=4):
        """
        Assert that self.amqp.drain(timeout) returns None.

        :param timeout: value handed to drain()
        """
        self.assertIsNone(self.amqp.drain(timeout))

    def drainToAny(self, types, timeout, forbidden=()):
        """
        Assert that messages with types, in any order, are found within timeout.
        Fail if any type from forbidden is found.

        Matching is done on the exact type of each drained event.

        :param types: iterable of event types that all have to show up
        :param timeout: in seconds, measured on the monotonic clock
        :param forbidden: container of event types that must not show up
        """
        start = monotonic.monotonic()
        types = set(types)
        # keep draining for the whole timeout, even after everything was found,
        # so that a forbidden event arriving later is still caught
        while monotonic.monotonic() - start < timeout:
            q = self.amqp.drain(1)
            if type(q) in forbidden:
                self.fail('%s found' % (type(q),))
            if type(q) in types:
                types.remove(type(q))
        if types:
            self.fail('Not found %s' % (', '.join(map(str, types)),))

    def drainTo(self, type_, timeout, forbidden=(ConsumerCancelled,)):
        """
        Return next event of type_. It has to occur within timeout, or fail.

        If you pass iterables as type_ and timeout, they are walked in lockstep
        (zip(), so the shorter one decides the length), drainTo() is called for
        each pair in order and the last result is returned. None is returned
        if there were no pairs at all.

        :param type_: event type, or iterable of event types
        :param timeout: seconds, or iterable of seconds when type_ is iterable
        :param forbidden: container of event types; a result of the iterable
            form whose exact type is in here fails the test
        :return: the event found
        """
        # collections.Iterable is gone since Python 3.10; collections.abc does
        # not exist on Python 2
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        if isinstance(type_, Iterable):
            self.assertIsInstance(timeout, Iterable)
            p = None
            for tp, ti in zip(type_, timeout):
                p = self.drainTo(tp, ti)
                if type(p) in forbidden:
                    self.fail('Found %s but forbidden' % (type(p),))
            return p

        # NOTE(review): forbidden is not consulted on this single-type path, so
        # other events (forbidden or not) are silently skipped -- confirm
        # whether that is intended
        start = monotonic.monotonic()
        while monotonic.monotonic() - start < timeout:
            q = self.amqp.drain(1)
            if isinstance(q, type_):
                return q
        self.fail('Did not find %s' % (type_,))

    def takes_less_than(self, max_time):
        """
        Tests that your code executes in less time than specified value.
        Use like:

            with self.takes_less_than(0.9):
                my_operation()

        :param max_time: in seconds
        """
        return TakesLessThanCM(self, max_time)

    # ======failures
    def single_fail_amqp(self):  # insert single failure
        """
        Swap the live transport socket for a FailbowlSocket, drop the channel's
        connection reference and close the real socket.
        """
        sock = self.amqp.thread.backend.channel.connection.transport.sock
        self.amqp.thread.backend.channel.connection.transport.sock = FailbowlSocket()
        self.amqp.thread.backend.channel.connection = None  # 'connection already closed' or sth like that

        sock.close()

    def fail_amqp(self):  # BROKER DEAD: SWITCH ON
        """Replace self.amqp.backend with FailbowlBackend, remembering the old one."""
        self.old_backend = self.amqp.backend
        self.amqp.backend = FailbowlBackend

    def unfail_amqp(self):  # BROKER DEAD: SWITCH OFF
        """Put back the backend saved by fail_amqp() and forget the saved copy."""
        self.amqp.backend = self.old_backend
        del self.old_backend

    def restart_rmq(self):  # simulate a broker restart
        """
        Fail the backend and the socket, wait 3 seconds, unfail, then expect
        ConnectionDown within 5 seconds followed by ConnectionUp within 20.
        """
        self.fail_amqp()
        self.single_fail_amqp()
        time.sleep(3)
        self.unfail_amqp()

        self.drainTo([ConnectionDown, ConnectionUp], [5, 20])

    def new_amqp_connection(self, consume_connectionup=True):
        """
        Return a context manager that yields a fresh AMQP connection.

        :param consume_connectionup: passed on to AMQPConnectionCM
        """
        return AMQPConnectionCM(self, consume_connectionup=consume_connectionup)


class TakesLessThanCM(object):
    """
    Context manager asserting that its body finishes in under max_time seconds.

    Entering yields a zero-argument callable ("is_late") that returns True once
    more than max_time seconds have passed since entry. On a clean exit,
    testCase.assertLess(elapsed, max_time) is called.
    """

    # Elapsed time is measured on a monotonic clock when the interpreter has
    # one, so wall-clock adjustments cannot distort it; Python 2 has no
    # time.monotonic and falls back to time.time.
    _clock = staticmethod(getattr(time, 'monotonic', time.time))

    def __init__(self, testCase, max_time):
        """
        :param testCase: object providing assertLess(), normally the TestCase
        :param max_time: in seconds
        """
        self.test = testCase
        self.max_time = max_time
        self.started_at = None

    def __enter__(self):
        """Start timing. The with statement calls this without arguments."""
        self.started_at = self._clock()

        def is_late():
            return self._clock() - self.started_at > self.max_time

        return is_late

    def __exit__(self, tp, v, tb):
        """Assert the time limit was kept; never suppresses exceptions."""
        elapsed = self._clock() - self.started_at
        # when the body already raised, let that exception through untouched
        # instead of risking replacing it with a timing failure
        if tp is None:
            self.test.assertLess(elapsed, self.max_time)
        return False


class AMQPConnectionCM(object):
    """Context manager. Get new AMQP uplink. Consume ConnectionUp if consume_connectionup

    Use like:

        with self.new_amqp_connection() as amqp2:
            amqp2.consume(...)

    """

    def __init__(self, testCase, consume_connectionup):
        """
        :param testCase: object providing assertIsInstance(), normally the TestCase
        :param consume_connectionup: if true, __enter__ drains one event and
            asserts that it is a ConnectionUp
        """
        self.test = testCase
        self.consume_connectionup = consume_connectionup

    def __enter__(self):
        """
        Start a new connection via getamqp() and return it.

        :raises AssertionError: through testCase, if consume_connectionup is
            set and drain(3) did not return a ConnectionUp
        """
        self.amqp = getamqp()
        if self.consume_connectionup:
            try:
                self.test.assertIsInstance(self.amqp.drain(3), ConnectionUp)
            except Exception:
                # __exit__ is not invoked when __enter__ raises, so the
                # already started connection has to be shut down right here
                self.amqp.shutdown()
                raise
        return self.amqp

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut the connection down; never suppresses exceptions."""
        self.amqp.shutdown()
        return False


class FailbowlBackend(AMQPBackend):
    """Backend whose construction always ends in ConnectionFailedError."""

    def __init__(self, node, thread):
        """
        Run the base initializer, then fail.

        :raises ConnectionFailedError: always
        """
        AMQPBackend.__init__(self, node, thread)
        failure = ConnectionFailedError('Failbowl')
        raise failure


class FailbowlSocket(object):
    """
    Socket stand-in on which every attribute is a callable that stalls.

    Any attribute not otherwise present resolves to a function accepting
    arbitrary arguments. It sleeps for one second; for 'close' and 'shutdown'
    it then simply returns, for every other name it raises socket.error.
    """

    def __getattr__(self, item):
        # only the two teardown calls are allowed to come back quietly
        quiet = item in ('close', 'shutdown')

        def stalled_call(*args, **kwargs):
            time.sleep(1)  # hang first ...
            if not quiet:
                raise socket.error  # ... then fail

        return stalled_call