diff --git a/codearchiver/core.py b/codearchiver/core.py index 945b991..bb64fed 100644 --- a/codearchiver/core.py +++ b/codearchiver/core.py @@ -84,31 +84,45 @@ class IndexField: class Index(list[tuple[str, str]]): '''An index (key-value mapping, possibly with repeated keys) of a file produced by a module''' - fields: list[IndexField] = [] + fields: tuple[IndexField] = () '''The fields for this index''' + _allFieldsCache: typing.Optional[tuple[IndexField]] = None + def append(self, *args): if len(args) == 1: args = args[0] return super().append(args) + # This should be a @classmethod, too, but that's deprecated since Python 3.11. + @property + def _allFields(self): + '''All fields known by this index, own ones and all from superclasses''' + + if type(self)._allFieldsCache is None: + fields = [] + for cls in reversed(type(self).mro()): + fields.extend(getattr(cls, 'fields', [])) + type(self)._allFieldsCache = tuple(fields) + return type(self)._allFieldsCache + def validate(self): '''Check that all keys and values in the index conform to the specification''' keyCounts = collections.Counter(key for key, _ in self) keys = set(keyCounts) - permittedKeys = set(field.key for field in type(self).fields) + permittedKeys = set(field.key for field in self._allFields) unrecognisedKeys = keys - permittedKeys if unrecognisedKeys: raise IndexValidationError(f'Unrecognised key(s): {", ".join(sorted(unrecognisedKeys))}') - requiredKeys = set(field.key for field in type(self).fields if field.required) + requiredKeys = set(field.key for field in self._allFields if field.required) missingRequiredKeys = requiredKeys - keys if missingRequiredKeys: raise IndexValidationError(f'Missing required key(s): {", ".join(sorted(missingRequiredKeys))}') - repeatableKeys = set(field.key for field in type(self).fields if field.repeatable) + repeatableKeys = set(field.key for field in self._allFields if field.repeatable) repeatedKeys = set(key for key, count in keyCounts.items() if count > 1) repeatedUnrepeatableKeys = repeatedKeys - repeatableKeys if repeatedUnrepeatableKeys: diff --git a/codearchiver/modules/git.py b/codearchiver/modules/git.py index fda28e5..a3e5e98 100644 --- a/codearchiver/modules/git.py +++ b/codearchiver/modules/git.py @@ -13,12 +13,12 @@ _logger = logging.getLogger(__name__) class GitIndex(codearchiver.core.Index): - fields = [ + fields = ( codearchiver.core.IndexField(key = 'Based on bundle', required = False, repeatable = True), codearchiver.core.IndexField(key = 'Ref', required = True, repeatable = True), codearchiver.core.IndexField(key = 'Root commit', required = True, repeatable = True), codearchiver.core.IndexField(key = 'Commit', required = False, repeatable = True), - ] + ) class Git(codearchiver.core.Module):