diff --git a/fimfarchive/utils.py b/fimfarchive/utils.py index 9aac2e1..8f5d8e7 100644 --- a/fimfarchive/utils.py +++ b/fimfarchive/utils.py @@ -22,8 +22,15 @@ Various utilities. # +import json +import os +import shutil +from collections import UserDict + + __all__ = ( 'Empty', + 'PersistedDict', ) @@ -43,3 +50,57 @@ class Empty(metaclass=EmptyMeta): def __bool__(self): return False + + +class PersistedDict(UserDict): + """ + Dictionary for simple persistance. + """ + + def __init__(self, path, default=dict()): + """ + Constructor. + + Args: + path: Location of the persistence file. + default: Initial values for entries. + """ + super().__init__() + self.path = path + self.temp = path + '~' + self.default = default + self.load() + + def load(self): + """ + Loads data from file as JSON. + """ + if os.path.exists(self.path): + with open(self.path, 'rt') as fobj: + self.data = json.load(fobj) + else: + self.data = dict() + + for k, v in self.default.items(): + if not k in self.data: + self.data[k] = v + + def save(self): + """ + Saves data to file as JSON. + """ + content = json.dumps( + self.data, + indent=4, + ensure_ascii=False, + sort_keys=True, + ) + + if os.path.exists(self.path): + shutil.copy(self.path, self.temp) + + with open(self.path, 'wt') as fobj: + fobj.write(content) + + if os.path.exists(self.temp): + os.remove(self.temp) diff --git a/tests/test_utils.py b/tests/test_utils.py index 00609d5..58bf327 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -22,7 +22,12 @@ Utility tests. # -from fimfarchive.utils import Empty +import json +import os + +import pytest + +from fimfarchive.utils import Empty, PersistedDict class TestEmpty: @@ -69,3 +74,118 @@ class TestEmpty: assert empty != Empty() assert empty is empty assert empty == empty + + +class TestPersistedDict: + """ + PersistedDict tests. + """ + + @pytest.fixture + def sample(self): + """ + Returns a sample dictionary. + """ + return {'key': 'value'} + + @pytest.fixture + def tmppath(self, tmpdir): + """ + Returns a temporary file path to nothing. + """ + return str(tmpdir.join('sample.json')) + + @pytest.fixture + def tmpfile(self, tmppath, sample): + """ + Returns a temporary file path to sample data. + """ + with open(tmppath, 'wt') as fobj: + json.dump(sample, fobj) + + return tmppath + + def test_saves_data(self, tmppath, sample): + """ + Tests data is saved to file. + """ + data = PersistedDict(tmppath) + data.update(sample) + + assert not os.path.exists(tmppath) + + data.save() + + with open(tmppath, 'rt') as fobj: + saved = json.load(fobj) + + assert dict(data) == saved + + def test_loads_values(self, tmpfile, sample): + """ + Tests data is loaded from file. + """ + data = PersistedDict(tmpfile) + + assert dict(data) == sample + + def test_load_replaces_data(self, tmpfile, sample): + """ + Tests data is replaced on load. + """ + extra = {object(): object()} + data = PersistedDict(tmpfile) + data.update(extra) + data.load() + + assert dict(data) == sample + + def test_load_empty_replaces_data(self, tmppath, sample): + """ + Tests data is replaced on load if file does not exist. + """ + data = PersistedDict(tmppath) + data.update(sample) + data.load() + + assert dict(data) == dict() + + def test_load_restores_defaults(self, tmpfile, sample): + """ + Tests defaults are restored on load. + """ + extra = {object(): object()} + data = PersistedDict(tmpfile, default=extra) + data.clear() + + assert dict(data) == dict() + + data.load() + + assert dict(data) == {**sample, **extra} + + def test_default_in_empty(self, tmppath, sample): + """ + Tests defaults are inserted when data is empty. + """ + data = PersistedDict(tmppath, default=sample) + + assert dict(data) == sample + + def test_default_in_mixed(self, tmpfile, sample): + """ + Tests defaults are inserted alongside loaded data. + """ + extra = {object(): object()} + data = PersistedDict(tmpfile, default=extra) + + assert dict(data) == {**sample, **extra} + + def test_default_does_not_override(self, tmpfile, sample): + """ + Tests defaults do not override loaded data. + """ + extra = {k: object() for k in sample.keys()} + data = PersistedDict(tmpfile, default=extra) + + assert dict(data) == sample