diff --git a/fimfarchive/signals.py b/fimfarchive/signals.py index da4b512..9488233 100644 --- a/fimfarchive/signals.py +++ b/fimfarchive/signals.py @@ -30,6 +30,10 @@ __all__ = ( 'SignalBinder', 'SignalSender', 'SignalReceiver', + 'find_related', + 'find_sources', + 'find_targets', + 'find_matches', ) @@ -125,25 +129,22 @@ class SignalBinder: class SignalSender: """ - Automatically binds signals on init. + Automatically binds unbound signals on init. """ def __init__(self): """ Constructor. """ - sources = { - k: v for k, v in vars(type(self)).items() - if k.startswith('on_') and isinstance(v, Signal) - } - - for k, v in sources.items(): - setattr(self, k, SignalBinder(v, self)) + for key, source in find_sources(self): + if not isinstance(source, SignalBinder): + binding = SignalBinder(source, self) + setattr(self, key, binding) class SignalReceiver: """ - Automatically connects signals on init. + Automatically connects signals on enter. """ def __init__(self, sender): @@ -153,19 +154,55 @@ class SignalReceiver: Args: sender: Object to connect to. """ - sources = { - k for k, v in vars(type(sender)).items() - if k.startswith('on_') and isinstance(v, Signal) - } + self.sender = sender - targets = { - k for k, v in vars(type(self)).items() - if k.startswith('on_') and callable(v) - } + def __enter__(self): + for key, source, target in find_matches(self.sender, self): + source.connect(target, sender=self.sender) - connect = sources.intersection(targets) + return self - for key in connect: - method = getattr(self, key) - signal = getattr(sender, key) - signal.connect(method, sender=sender) + def __exit__(self, *args): + for key, source, target in find_matches(self.sender, self): + source.disconnect(target, sender=self.sender) + + +def find_related(obj): + """ + Yields all source or target candidates. + """ + for key in dir(obj): + if key.startswith('on_'): + yield key, getattr(obj, key) + + +def find_sources(sender): + """ + Yields all source signals in a sender. + """ + for key, value in find_related(sender): + connect = getattr(value, 'connect', None) + disconnect = getattr(value, 'disconnect', None) + + if callable(connect) and callable(disconnect): + yield key, value + + +def find_targets(receiver): + """ + Yields all target methods in a receiver. + """ + for key, value in find_related(receiver): + if callable(value): + yield key, value + + +def find_matches(sender, receiver): + """ + Yields all matching signal connections. + """ + sources = dict(find_sources(sender)) + targets = dict(find_targets(receiver)) + + for key in sources.keys() & targets.keys(): + yield key, sources[key], targets[key] diff --git a/tests/test_signals.py b/tests/test_signals.py index 10c50ff..e9fc647 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -69,12 +69,13 @@ def sender(signal): @pytest.fixture def receiver(sender): """ - Returns a signal receiver instance. + Returns a connected signal receiver intance. """ class Receiver(SignalReceiver): on_signal = Mock('on_signal') - return Receiver(sender) + with Receiver(sender) as receiver: + yield receiver @pytest.fixture @@ -182,7 +183,7 @@ class TestSignalReceiver: def test_send(self, params, sender, receiver): """ - Tests receiver recives emitted singal. + Tests receiver receives emitted signal. """ sender.on_signal(*params.values()) receiver.on_signal.assert_called_once_with(sender, **params)