Skip to content

Commit

Permalink
Add a new filter to remove/replace a whole category
Browse files Browse the repository at this point in the history
  • Loading branch information
benmwebb committed Jan 15, 2025
1 parent 61d81f0 commit f45e4cb
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/format.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,7 @@ The :mod:`ihm.format` Python module

.. autoclass:: ChangeKeywordFilter

.. autoclass:: ReplaceCategoryFilter

.. autoexception:: CifParserError
:members:
108 changes: 102 additions & 6 deletions ihm/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
import textwrap
import operator
import ihm
if sys.version_info[0] >= 3:
from io import StringIO
else:
from io import BytesIO as StringIO
# getargspec is deprecated in Python 3, but getfullargspec has a very
# similar interface
try:
Expand Down Expand Up @@ -666,6 +670,13 @@ def __init__(self, target):
self.category = '_' + ts[0]
self.keyword = ts[-1]

def _set_category_from_target(self, target):
    """Set this filter to act on a whole category, not a single keyword.

       :param str target: the mmCIF category name, with or without
              the leading underscore (e.g. ``entity`` or ``_entity``).
    """
    # Category names are always stored with a leading underscore
    prefix = '' if target.startswith('_') else '_'
    self.category = prefix + target
    # No keyword: the filter applies to the entire category
    self.keyword = None

def match_token_category(self, tok):
    """Return true iff the given token matches the target's category.

       A filter whose category is None matches every token.
    """
    wanted = self.category
    if wanted is None:
        return True
    return tok.category == wanted
Expand All @@ -684,6 +695,16 @@ def filter_category(self, tok):
"""
raise NotImplementedError

def filter_loop_header(self, tok):
    """Filter the given loop header token.

       The default implementation leaves the loop untouched; subclasses
       can override this to replace or remove an entire loop.

       :return: the original token (which must not have been modified),
                a replacement token, or None if the token should be
                deleted. If the header token is replaced or deleted,
                all of the original loop rows will also be deleted.
    """
    return tok

def get_loop_filter(self, tok):
"""Given a loop header token, potentially return a handler for each
loop row token. This function is also permitted to alter the
Expand Down Expand Up @@ -791,6 +812,67 @@ def get_loop_filter(self, tok):
tok.keywords[keyword_index].token.keyword = self.new


class ReplaceCategoryFilter(Filter):
    """Replace any token from the file that sets the given category.

       The first token (or loop) found for the category is swapped for
       the replacement text; any later tokens that set the same category
       are deleted. If no replacement is given, the category is removed
       from the file entirely.

       ``raw_cif``, if given, takes precedence over ``dumper``, and a
       ``dumper`` is only used if ``system`` is also given.

       :param str target: the mmCIF category name this filter should act on,
              such as ``_entity``.
       :param str raw_cif: if given, text in mmCIF format which should replace
              the first instance of the category.
       :param dumper: if given, a dumper object that should generate mmCIF
              output to replace the first instance of the category.
       :type dumper: :class:`ihm.dumper.Dumper`
       :param system: the System that the given dumper will work on.
       :type system: :class:`ihm.System`
    """

    class _RawCifToken(_Token):
        """A token that emits a fixed piece of mmCIF text verbatim"""
        __slots__ = ['txt']
        # Belongs to no category or keyword, so category/keyword
        # comparisons against this token do not match
        category = keyword = None

        def __init__(self, txt):
            self.txt = txt

        def as_mmcif(self):
            return self.txt

    def __init__(self, target, raw_cif=None, dumper=None, system=None):
        self._set_category_from_target(target)
        self.raw_cif = raw_cif
        self.dumper = dumper
        self.system = system
        #: The number of times the category was found in the mmCIF file
        self.num_matches = 0

    def _get_replacement_token(self):
        """Return the token to emit in place of the current match,
           or None if the matched token should simply be deleted."""
        # Only the first match is replaced; later ones are dropped
        if self.num_matches > 1:
            return None
        if self.raw_cif:
            return self._RawCifToken(self.raw_cif)
        if not (self.dumper and self.system):
            return None
        # Capture the dumper's output in memory so that it can be
        # emitted verbatim in place of the original category
        buf = StringIO()
        writer = CifWriter(buf)
        self.dumper.finalize(self.system)
        self.dumper.dump(self.system, writer)
        return self._RawCifToken(buf.getvalue())

    def filter_category(self, tok):
        """Replace or delete `tok` if it sets the target category;
           otherwise return it unchanged."""
        if not self.match_token_category(tok):
            return tok
        self.num_matches += 1
        return self._get_replacement_token()

    def filter_loop_header(self, tok):
        """Handle a loop header exactly like a single category token"""
        return self.filter_category(tok)

    def get_loop_filter(self, tok):
        """No per-row handler is provided; loops are handled by
           :meth:`filter_loop_header` instead."""
        return None


class CifTokenReader(_PreservingCifTokenizer):
"""Read an mmCIF file and break it into tokens.
Expand Down Expand Up @@ -836,13 +918,19 @@ def _read_file_with_filters(self, filters):
if isinstance(tok, _CategoryTokenGroup):
tok = self._filter_category(tok, filters)
elif isinstance(tok, ihm.format._LoopHeaderTokenGroup):
remove_all_loop_rows = False
loop_filters = [f.get_loop_filter(tok) for f in filters]
loop_filters = [f for f in loop_filters if f is not None]
# Did filters remove all keywords from the loop?
if all(isinstance(k.token, _NullToken) for k in tok.keywords):
tok = None
new_tok = self._filter_loop_header(tok, filters)
if new_tok is not tok:
tok = new_tok
remove_all_loop_rows = True
else:
remove_all_loop_rows = False
loop_filters = [f.get_loop_filter(tok) for f in filters]
loop_filters = [f for f in loop_filters if f is not None]
# Did filters remove all keywords from the loop?
if all(isinstance(k.token, _NullToken)
for k in tok.keywords):
tok = None
remove_all_loop_rows = True
elif isinstance(tok, ihm.format._LoopRowTokenGroup):
if remove_all_loop_rows:
tok = None
Expand All @@ -858,6 +946,14 @@ def _filter_category(self, tok, filters):
return
return tok

def _filter_loop_header(self, tok, filters):
    """Pass a loop header token through each filter in turn.

       Stops at the first filter that replaces or deletes the token.

       :return: the original token if no filter changed it; otherwise
                the first replacement (which may be None).
    """
    for flt in filters:
        result = flt.filter_loop_header(tok)
        # Identity, not equality: any different object (including None)
        # means this filter has taken over the loop
        if result is not tok:
            return result
    return tok

def _filter_loop(self, tok, filters):
for f in filters:
tok = f(tok)
Expand Down
41 changes: 41 additions & 0 deletions test/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TOPDIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
utils.set_search_paths(TOPDIR)
import ihm.format
import ihm.dumper

try:
from ihm import _format
Expand Down Expand Up @@ -956,6 +957,46 @@ def test_cif_token_reader_filter(self):
_foo.baz
newa b c d
x y
""")

def test_cif_token_reader_replace_category_filter(self):
    """Test CifTokenReader class with ReplaceCategoryFilter"""
    cif = """
data_foo_bar
#
_cat1.bar old
#
loop_
_cat2.bar
_cat2.baz
a b c d
x y
#
_cat3.x 1
_cat3.y 2
#
_cat4.z 1
"""
    system = ihm.System()
    system.comments.extend(['comment1', 'comment2'])
    comment_dumper = ihm.dumper._CommentDumper()
    # cat1 is removed outright, cat2 (a loop) is replaced with raw text,
    # and cat3 is replaced with dumper output; cat4 has no filter and
    # so should pass through unchanged
    remove_cat1 = ihm.format.ReplaceCategoryFilter("cat1")
    raw_cat2 = ihm.format.ReplaceCategoryFilter("cat2", raw_cif='FOO')
    dump_cat3 = ihm.format.ReplaceCategoryFilter(
        "cat3", dumper=comment_dumper, system=system)
    reader = ihm.format.CifTokenReader(StringIO(cif))
    new_cif = "".join(
        tok.as_mmcif()
        for tok in reader.read_file([remove_cat1, raw_cat2, dump_cat3]))
    self.assertEqual(new_cif, """
data_foo_bar
#
#
FOO
#
# comment1
# comment2
#
_cat4.z 1
""")

def test_category_token_group(self):
Expand Down

0 comments on commit f45e4cb

Please sign in to comment.