diff --git a/fimfarchive/mappers.py b/fimfarchive/mappers.py index 515d05a..341afbc 100644 --- a/fimfarchive/mappers.py +++ b/fimfarchive/mappers.py @@ -23,10 +23,13 @@ Mappers for Fimfarchive. import os +from abc import abstractmethod +from typing import Generic, Optional, TypeVar -import arrow +from arrow import api as arrow, Arrow from fimfarchive.exceptions import InvalidStoryError +from fimfarchive.stories import Story __all__ = ( @@ -37,33 +40,45 @@ __all__ = ( ) -class Mapper: +T = TypeVar('T') + + +class Mapper(Generic[T]): """ - Callable which maps something to something else. + Callable which maps stories to something else. """ - def __call__(self, *args, **kwargs): - raise NotImplementedError() + @abstractmethod + def __call__(self, story: Story) -> T: + """ + Applies the mapper. + + Args: + story: The story to map. + + Returns: + A mapped object. + """ -class StaticMapper(Mapper): +class StaticMapper(Mapper[T]): """ Returns the supplied value for any call. """ - def __init__(self, value=None): + def __init__(self, value: T) -> None: self.value = value - def __call__(self, *args, **kwargs): + def __call__(self, story: Story) -> T: return self.value -class StoryDateMapper(Mapper): +class StoryDateMapper(Mapper[Optional[Arrow]]): """ Returns the latest timestamp in a story, or None. """ - def __call__(self, story): + def __call__(self, story: Story) -> Optional[Arrow]: try: meta = getattr(story, 'meta', None) except InvalidStoryError: @@ -87,15 +102,15 @@ class StoryDateMapper(Mapper): return None -class StoryPathMapper(Mapper): +class StoryPathMapper(Mapper[str]): """ Returns a key-based file path for a story. """ - def __init__(self, directory): + def __init__(self, directory: str) -> None: self.directory = directory - def __call__(self, story): + def __call__(self, story: Story) -> str: directory = str(self.directory) key = str(story.key) diff --git a/tests/test_mappers.py b/tests/test_mappers.py index 47d4974..9c5db5d 100644 --- a/tests/test_mappers.py +++ b/tests/test_mappers.py @@ -48,33 +48,12 @@ class TestStaticMapper: """ return object() - def test_value(self, value): + def test_value(self, story, value): """ Tests returns the supplied value. """ mapper = StaticMapper(value) - assert mapper() is value - - def test_default_value(self): - """ - Tests `None` is returned by default. - """ - mapper = StaticMapper() - assert mapper() is None - - def test_args(self, value): - """ - Tests callable ignores args. - """ - mapper = StaticMapper(value) - assert mapper(1, 2, 3) is value - - def test_kwargs(self, value): - """ - Tests callable ignores kwargs. - """ - mapper = StaticMapper(value) - assert mapper(a=1, b=2) is value + assert mapper(story) is value class TestStoryDateMapper: