Optimize RelationClassifier by adding the option to filter long sentences and truncate context #3593

Status: Open. Wants to merge 20 commits into master from filter_relations.

Changes shown from 9 of 20 commits.

Commits
fc786b3  Optimize RelationClassifier by filtering long sentences (alanakbik, Jan 2, 2025)
594d858  Optimize RelationClassifier by filtering long sentences (alanakbik, Jan 2, 2025)
8fc8a58  Fix serialization (alanakbik, Jan 2, 2025)
1fd1851  Change context window calculation (alanakbik, Jan 2, 2025)
7f89bb0  Change context window calculation (alanakbik, Jan 2, 2025)
70148da  Add sanity check to ensure entities are not contained in one another (alanakbik, Jan 2, 2025)
f50c3b3  Fix slicing such that left and right context are of equal length (alanakbik, Jan 2, 2025)
142703b  Make mypy happy (alanakbik, Jan 2, 2025)
3ad499b  Remove unnecessary if statement (alanakbik, Jan 2, 2025)
f798a3c  Merge branch 'master' into filter_relations (alanakbik, Jan 11, 2025)
0cca6a8  Ensure presence of head and tail entity in the original sentence (dobbersc, Jan 15, 2025)
4ed2e49  Refactor `_slice_encoded_sentence_to_max_allowed_length` to use min/m… (dobbersc, Jan 21, 2025)
306412f  Allow to disable the `max_allowed_tokens_between_entities` and `max_s… (dobbersc, Jan 21, 2025)
4fc4878  Rearrange parameters and make sentence filters public (dobbersc, Jan 21, 2025)
de8b7f4  Add test cases for `max_allowed_tokens_between_entities` and `max_sur… (dobbersc, Jan 22, 2025)
6789a6a  Merge branch 'master' into filter_relations (dobbersc, Jan 22, 2025)
a06fa30  Fix tests due to additional training data point in `train.conllup` (dobbersc, Jan 22, 2025)
3ce22f1  Merge remote-tracking branch 'origin/filter_relations' into filter_re… (dobbersc, Jan 22, 2025)
6f8d168  Fix serialization issue in ModelTrainer (alanakbik, Jan 27, 2025)
cbd8be3  Merge branch 'master' into filter_relations (alanakbik, Feb 1, 2025)
57 changes: 52 additions & 5 deletions flair/models/relation_classifier_model.py
@@ -256,6 +256,8 @@ def __init__(
encoding_strategy: EncodingStrategy = TypedEntityMarker(),
zero_tag_value: str = "O",
allow_unk_tag: bool = True,
max_allowed_tokens_between_entities: int = 20,
max_surrounding_context_length: int = 10,
Collaborator: Maybe it would make sense to also allow `None` values to disable the measures. This could then also be used as the default value in `_init_model_with_state_dict` so as not to change the behaviour of existing models.

Author: Good idea!
**classifierargs,
) -> None:
"""Initializes a `RelationClassifier`.
@@ -271,13 +273,18 @@ def __init__(
encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol
zero_tag_value: The label to use for out-of-class relations
allow_unk_tag: If `False`, removes `<unk>` from the passed label dictionary, otherwise do nothing.
max_allowed_tokens_between_entities: The maximum number of tokens allowed between a pair of entities. All other entity pairs are filtered from consideration.
max_surrounding_context_length: The maximum length of context around entity pairs that will be considered.
classifierargs: The remaining parameters passed to the underlying :class:`flair.models.DefaultClassifier`
"""
# Set label type and prepare label dictionary
self._label_type = label_type
self._zero_tag_value = zero_tag_value
self._allow_unk_tag = allow_unk_tag

self._max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
self._max_surrounding_context_length = max_surrounding_context_length

modified_label_dictionary: Dictionary = Dictionary(add_unk=self._allow_unk_tag)
modified_label_dictionary.add_item(self._zero_tag_value)
for label in label_dictionary.get_items():
@@ -398,7 +405,7 @@ def _encode_sentence(
head: _Entity,
tail: _Entity,
gold_label: Optional[str] = None,
) -> EncodedSentence:
) -> Optional[EncodedSentence]:
"""Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.

If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
@@ -414,6 +421,12 @@
original_sentence: Sentence = head.span.sentence
assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence."

# Sanity check: Do not create a labeled span if one entity contains the other
if head.span[0].idx <= tail.span[0].idx and head.span[-1].idx >= tail.span[-1].idx:
return None
if head.span[0].idx >= tail.span[0].idx and head.span[-1].idx <= tail.span[-1].idx:
return None

# Pre-compute non-leading head and tail tokens for entity masking
non_leading_head_tokens: list[Token] = head.span.tokens[1:]
non_leading_tail_tokens: list[Token] = tail.span.tokens[1:]
@@ -422,11 +435,15 @@
# since there may be multiple occurrences of the same entity mentioned in the sentence.
# Therefore, we use the span's position in the sentence.
encoded_sentence_tokens: list[str] = []
head_idx = -10000
tail_idx = 10000
Collaborator: For safety and as a sanity check, we could also initialize these values as `None` and have an assertion after the `for token in original_sentence` loop that these variables are not `None`.

Author: I had that before, but that gave me mypy errors. I agree that this is better, though.
for token in original_sentence:
if token is head.span[0]:
head_idx = len(encoded_sentence_tokens)
encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))

elif token is tail.span[0]:
tail_idx = len(encoded_sentence_tokens)
encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label))

elif all(
@@ -435,6 +452,15 @@
):
encoded_sentence_tokens.append(token.text)

# filter cases in which the distance between the two entities is too large
if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities:
return None

# remove excess tokens left and right of entity pair to make encoded sentence shorter
encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length(
encoded_sentence_tokens, head_idx, tail_idx
)

# Create masked sentence
encoded_sentence: EncodedSentence = EncodedSentence(
" ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer()
@@ -448,6 +474,20 @@
encoded_sentence.copy_context_from_sentence(original_sentence)
return encoded_sentence

def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, head_idx, tail_idx):
begin_slice = head_idx if head_idx < tail_idx else tail_idx
end_slice = tail_idx if head_idx < tail_idx else head_idx
padding_amount = self._max_surrounding_context_length
begin_slice = begin_slice - padding_amount if begin_slice - padding_amount > 0 else 0
end_slice = (
end_slice + padding_amount + 1
if end_slice + padding_amount + 1 < len(encoded_sentence_tokens)
else len(encoded_sentence_tokens)
)

encoded_sentence_tokens = encoded_sentence_tokens[begin_slice:end_slice]
return encoded_sentence_tokens
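Editor's note: commit 4ed2e49 later refactors this helper to use min/max. A sketch of equivalent slicing logic under that assumption; `slice_to_max_allowed_length` is a hypothetical standalone version of the method above.

```python
def slice_to_max_allowed_length(tokens: list[str], head_idx: int, tail_idx: int, padding: int) -> list[str]:
    """Keep the entity pair plus at most `padding` context tokens on each side."""
    begin_slice = max(min(head_idx, tail_idx) - padding, 0)
    end_slice = min(max(head_idx, tail_idx) + padding + 1, len(tokens))
    return tokens[begin_slice:end_slice]


# e.g. slice_to_max_allowed_length(list("abcdefghij"), head_idx=3, tail_idx=5, padding=1)
# -> ['c', 'd', 'e', 'f', 'g']  (one context token on each side of the pair)
```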

def _encode_sentence_for_inference(
self,
sentence: Sentence,
@@ -469,13 +509,15 @@
Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
"""
for head, tail, gold_label in self._entity_pair_permutations(sentence):
masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label if gold_label is not None else self.zero_tag_value,
)
original_relation: Relation = Relation(first=head.span, second=tail.span)
yield masked_sentence, original_relation

if masked_sentence is not None:
yield masked_sentence, original_relation

def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
"""Create Encoded Sentences and Relation pairs for Training.
@@ -492,13 +534,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
else:
continue # Skip generated data points that do not express an originally annotated relation

masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label,
)

yield masked_sentence
if masked_sentence is not None:
yield masked_sentence
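Editor's note: a usage sketch of the new options when constructing a `RelationClassifier`, assuming the keyword names from this diff; the corpus, embedding model, and label-type choices below are illustrative assumptions, not taken from this PR.

```python
from flair.datasets import RE_ENGLISH_CONLL04
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import RelationClassifier

# Hypothetical setup using the CoNLL-04 relation extraction corpus.
corpus = RE_ENGLISH_CONLL04()
label_dictionary = corpus.make_label_dictionary("relation")

classifier = RelationClassifier(
    embeddings=TransformerDocumentEmbeddings(model="distilbert-base-uncased"),
    label_dictionary=label_dictionary,
    label_type="relation",
    entity_label_types="ner",
    max_allowed_tokens_between_entities=20,  # skip entity pairs further apart than this
    max_surrounding_context_length=10,       # keep at most this much context per side
)
```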

def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
"""Transforms sentences into encoded sentences specific to the `RelationClassifier`.
@@ -706,6 +749,8 @@ def _get_state_dict(self) -> dict[str, Any]:
"encoding_strategy": self.encoding_strategy,
"zero_tag_value": self.zero_tag_value,
"allow_unk_tag": self.allow_unk_tag,
"max_allowed_tokens_between_entities": self._max_allowed_tokens_between_entities,
"max_surrounding_context_length": self._max_surrounding_context_length,
}
return model_state

@@ -723,6 +768,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
encoding_strategy=state["encoding_strategy"],
zero_tag_value=state["zero_tag_value"],
allow_unk_tag=state["allow_unk_tag"],
max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities", 25),
max_surrounding_context_length=state.get("max_surrounding_context_length", 50),
Collaborator: The default parameters for backwards compatibility are different from the ones in the `__init__` method. Are these a better fit?

Author: Backwards compatibility is tricky, since older models will have no limitations on max allowed tokens or surrounding context. It's probably best to set really high numbers here (e.g., even higher than these).
**kwargs,
)
