diff --git a/docs/format.rst b/docs/format.rst index 90129fb..984a7e9 100644 --- a/docs/format.rst +++ b/docs/format.rst @@ -25,5 +25,7 @@ The :mod:`ihm.format` Python module .. autoclass:: ChangeKeywordFilter +.. autoclass:: ReplaceCategoryFilter + .. autoexception:: CifParserError :members: diff --git a/ihm/format.py b/ihm/format.py index 172b361..fadacc2 100644 --- a/ihm/format.py +++ b/ihm/format.py @@ -14,6 +14,10 @@ import textwrap import operator import ihm +if sys.version_info[0] >= 3: + from io import StringIO +else: + from io import BytesIO as StringIO # getargspec is deprecated in Python 3, but getfullargspec has a very # similar interface try: @@ -666,6 +670,13 @@ def __init__(self, target): self.category = '_' + ts[0] self.keyword = ts[-1] + def _set_category_from_target(self, target): + if target.startswith('_'): + self.category = target + else: + self.category = '_' + target + self.keyword = None + def match_token_category(self, tok): """Return true iff the given token matches the target's category""" return self.category is None or tok.category == self.category @@ -684,6 +695,16 @@ def filter_category(self, tok): """ raise NotImplementedError + def filter_loop_header(self, tok): + """Filter the given loop header token. + + :return: the original token (which must not have been modified), + a replacement token, or None if the token should be + deleted. If the header token is replaced or deleted, + all of the original loop rows will also be deleted. + """ + return tok + def get_loop_filter(self, tok): """Given a loop header token, potentially return a handler for each loop row token. This function is also permitted to alter the @@ -791,6 +812,67 @@ def get_loop_filter(self, tok): tok.keywords[keyword_index].token.keyword = self.new +class ReplaceCategoryFilter(Filter): + """Replace any token from the file that sets the given category. + + This can also be used to completely remove a category if no + replacement is given. + + :param str target: the mmCIF category name this filter should act on, + such as ``_entity``. + :param str raw_cif: if given, text in mmCIF format which should replace + the first instance of the category. + :param dumper: if given, a dumper object that should generate mmCIF + output to replace the first instance of the category. + :type dumper: :class:`ihm.dumper.Dumper` + :param system: the System that the given dumper will work on. + :type system: :class:`ihm.System` + """ + + class _RawCifToken(_Token): + __slots__ = ['txt'] + category = keyword = None + + def __init__(self, txt): + self.txt = txt + + def as_mmcif(self): + return self.txt + + def __init__(self, target, raw_cif=None, dumper=None, system=None): + self._set_category_from_target(target) + self.raw_cif = raw_cif + self.dumper = dumper + self.system = system + #: The number of times the category was found in the mmCIF file + self.num_matches = 0 + + def _get_replacement_token(self): + if self.num_matches > 1: + return None + if self.raw_cif: + return self._RawCifToken(self.raw_cif) + elif self.dumper and self.system: + fh = StringIO() + writer = CifWriter(fh) + self.dumper.finalize(self.system) + self.dumper.dump(self.system, writer) + return self._RawCifToken(fh.getvalue()) + + def filter_category(self, tok): + if self.match_token_category(tok): + self.num_matches += 1 + return self._get_replacement_token() + else: + return tok + + def filter_loop_header(self, tok): + return self.filter_category(tok) + + def get_loop_filter(self, tok): + return None + + class CifTokenReader(_PreservingCifTokenizer): """Read an mmCIF file and break it into tokens. @@ -836,13 +918,19 @@ def _read_file_with_filters(self, filters): if isinstance(tok, _CategoryTokenGroup): tok = self._filter_category(tok, filters) elif isinstance(tok, ihm.format._LoopHeaderTokenGroup): - remove_all_loop_rows = False - loop_filters = [f.get_loop_filter(tok) for f in filters] - loop_filters = [f for f in loop_filters if f is not None] - # Did filters remove all keywords from the loop? - if all(isinstance(k.token, _NullToken) for k in tok.keywords): - tok = None + new_tok = self._filter_loop_header(tok, filters) + if new_tok is not tok: + tok = new_tok remove_all_loop_rows = True + else: + remove_all_loop_rows = False + loop_filters = [f.get_loop_filter(tok) for f in filters] + loop_filters = [f for f in loop_filters if f is not None] + # Did filters remove all keywords from the loop? + if all(isinstance(k.token, _NullToken) + for k in tok.keywords): + tok = None + remove_all_loop_rows = True elif isinstance(tok, ihm.format._LoopRowTokenGroup): if remove_all_loop_rows: tok = None @@ -858,6 +946,14 @@ def _filter_category(self, tok, filters): return return tok + def _filter_loop_header(self, tok, filters): + orig_tok = tok + for f in filters: + tok = f.filter_loop_header(tok) + if tok is not orig_tok: + break + return tok + def _filter_loop(self, tok, filters): for f in filters: tok = f(tok) diff --git a/test/test_format.py b/test/test_format.py index 4425da1..7faf961 100644 --- a/test/test_format.py +++ b/test/test_format.py @@ -15,6 +15,7 @@ TOPDIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) utils.set_search_paths(TOPDIR) import ihm.format +import ihm.dumper try: from ihm import _format @@ -956,6 +957,46 @@ def test_cif_token_reader_filter(self): _foo.baz newa b c d x y +""") + + def test_cif_token_reader_replace_category_filter(self): + """Test CifTokenReader class with ReplaceCategoryFilter""" + cif = """ +data_foo_bar +# +_cat1.bar old +# +loop_ +_cat2.bar +_cat2.baz +a b c d +x y +# +_cat3.x 1 +_cat3.y 2 +# +_cat4.z 1 +""" + d = ihm.dumper._CommentDumper() + s = ihm.System() + s.comments.extend(['comment1', 'comment2']) + r = ihm.format.CifTokenReader(StringIO(cif)) + filters = [ihm.format.ReplaceCategoryFilter("cat1"), + ihm.format.ReplaceCategoryFilter("cat2", raw_cif='FOO'), + ihm.format.ReplaceCategoryFilter("cat3", dumper=d, + system=s)] + tokens = list(r.read_file(filters)) + new_cif = "".join(x.as_mmcif() for x in tokens) + self.assertEqual(new_cif, """ +data_foo_bar +# +# +FOO +# +# comment1 +# comment2 +# +_cat4.z 1 """) def test_category_token_group(self):