Optimize RelationClassifier by adding the option to filter long sentences and truncate context #3593
base: master
@@ -256,6 +256,8 @@ def __init__(
        encoding_strategy: EncodingStrategy = TypedEntityMarker(),
        zero_tag_value: str = "O",
        allow_unk_tag: bool = True,
+        max_allowed_tokens_between_entities: int = 20,
+        max_surrounding_context_length: int = 10,
        **classifierargs,
    ) -> None:
        """Initializes a `RelationClassifier`.
@@ -271,13 +273,18 @@ def __init__(
            encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol
            zero_tag_value: The label to use for out-of-class relations
            allow_unk_tag: If `False`, removes `<unk>` from the passed label dictionary, otherwise do nothing.
+            max_allowed_tokens_between_entities: The maximum allowed number of tokens between entities. All other entity pairs are filtered from consideration.
+            max_surrounding_context_length: The maximum length of context around entity pairs that will be considered.
            classifierargs: The remaining parameters passed to the underlying :class:`flair.models.DefaultClassifier`
        """
        # Set label type and prepare label dictionary
        self._label_type = label_type
        self._zero_tag_value = zero_tag_value
        self._allow_unk_tag = allow_unk_tag

+        self._max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
+        self._max_surrounding_context_length = max_surrounding_context_length
+
        modified_label_dictionary: Dictionary = Dictionary(add_unk=self._allow_unk_tag)
        modified_label_dictionary.add_item(self._zero_tag_value)
        for label in label_dictionary.get_items():
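For context, a minimal sketch of how the two new parameters might be used when constructing a `RelationClassifier`, following the usual Flair relation-extraction setup; the corpus, embeddings, and entity label type below are illustrative choices, not part of this diff:

from flair.datasets import RE_ENGLISH_CONLL04
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import RelationClassifier

# Illustrative corpus and label setup; any relation corpus with NER annotations would work.
corpus = RE_ENGLISH_CONLL04()
label_dictionary = corpus.make_label_dictionary("relation")

model = RelationClassifier(
    embeddings=TransformerDocumentEmbeddings(model="distilbert-base-uncased", fine_tune=True),
    label_dictionary=label_dictionary,
    label_type="relation",
    entity_label_types="ner",
    # New in this PR: drop entity pairs that are far apart and truncate context around the pair
    max_allowed_tokens_between_entities=20,
    max_surrounding_context_length=10,
)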
@@ -398,7 +405,7 @@ def _encode_sentence(
        head: _Entity,
        tail: _Entity,
        gold_label: Optional[str] = None,
-    ) -> EncodedSentence:
+    ) -> Optional[EncodedSentence]:
        """Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.

        If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
@@ -414,6 +421,12 @@ def _encode_sentence(
        original_sentence: Sentence = head.span.sentence
        assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence."

+        # Sanity check: Do not create a labeled span if one entity contains the other
+        if head.span[0].idx <= tail.span[0].idx and head.span[-1].idx >= tail.span[-1].idx:
+            return None
+        if head.span[0].idx >= tail.span[0].idx and head.span[-1].idx <= tail.span[-1].idx:
+            return None
+
        # Pre-compute non-leading head and tail tokens for entity masking
        non_leading_head_tokens: list[Token] = head.span.tokens[1:]
        non_leading_tail_tokens: list[Token] = tail.span.tokens[1:]
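To make the nested-entity check concrete, here is a tiny self-contained sketch of the same condition on plain (start, end) token indices; the helper name and the example spans are made up for illustration:

# Toy illustration of the nested-entity check using plain token indices.
# (start, end) pairs stand in for head.span[0].idx / head.span[-1].idx.
def is_nested(head: tuple[int, int], tail: tuple[int, int]) -> bool:
    head_start, head_end = head
    tail_start, tail_end = tail
    # head contains tail, or tail contains head
    return (head_start <= tail_start and head_end >= tail_end) or (
        head_start >= tail_start and head_end <= tail_end
    )

# "New York City" (tokens 3-5) fully contains "New York" (tokens 3-4): the pair is skipped.
assert is_nested((3, 5), (3, 4)) is True
# Disjoint entities (tokens 1-2 vs. tokens 6-7) are kept.
assert is_nested((1, 2), (6, 7)) is False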
@@ -422,11 +435,15 @@ def _encode_sentence(
        # since there may be multiple occurrences of the same entity mentioned in the sentence.
        # Therefore, we use the span's position in the sentence.
        encoded_sentence_tokens: list[str] = []
+        head_idx = -10000
+        tail_idx = 10000
Review comment: For safety and as a sanity check, we could also initialize these values as `None`.

Reply: I had that before, but that gave me mypy errors. But I agree that this is better.
        for token in original_sentence:
            if token is head.span[0]:
+                head_idx = len(encoded_sentence_tokens)
                encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))

            elif token is tail.span[0]:
+                tail_idx = len(encoded_sentence_tokens)
                encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label))

            elif all(
@@ -435,6 +452,15 @@ def _encode_sentence(
            ):
                encoded_sentence_tokens.append(token.text)

+        # Filter out cases in which the distance between the two entities is too large
+        if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities:
+            return None
+
+        # Remove excess tokens to the left and right of the entity pair to shorten the encoded sentence
+        encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length(
+            encoded_sentence_tokens, head_idx, tail_idx
+        )
+
        # Create masked sentence
        encoded_sentence: EncodedSentence = EncodedSentence(
            " ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer()
@@ -448,6 +474,20 @@ def _encode_sentence(
        encoded_sentence.copy_context_from_sentence(original_sentence)
        return encoded_sentence

+    def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens: list[str], head_idx: int, tail_idx: int) -> list[str]:
+        begin_slice = head_idx if head_idx < tail_idx else tail_idx
+        end_slice = tail_idx if head_idx < tail_idx else head_idx
+        padding_amount = self._max_surrounding_context_length
+        begin_slice = begin_slice - padding_amount if begin_slice - padding_amount > 0 else 0
+        end_slice = (
+            end_slice + padding_amount + 1
+            if end_slice + padding_amount + 1 < len(encoded_sentence_tokens)
+            else len(encoded_sentence_tokens)
+        )
+
+        encoded_sentence_tokens = encoded_sentence_tokens[begin_slice:end_slice]
+        return encoded_sentence_tokens
+
    def _encode_sentence_for_inference(
        self,
        sentence: Sentence,
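As a rough standalone illustration of the slicing arithmetic, the helper below mirrors the method with an explicit padding argument instead of the instance attribute; the function name and data are made up:

# Standalone mirror of _slice_encoded_sentence_to_max_allowed_length, for illustration only.
def slice_to_context_window(tokens: list[str], head_idx: int, tail_idx: int, padding: int) -> list[str]:
    begin_slice = max(min(head_idx, tail_idx) - padding, 0)
    end_slice = min(max(head_idx, tail_idx) + padding + 1, len(tokens))
    return tokens[begin_slice:end_slice]

tokens = [f"tok{i}" for i in range(30)]
# Entities at positions 10 and 14 with 2 tokens of surrounding context:
# keeps tokens 8 through 16 inclusive, i.e. indices [8:17].
window = slice_to_context_window(tokens, head_idx=10, tail_idx=14, padding=2)
assert window == tokens[8:17]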
@@ -469,13 +509,15 @@ def _encode_sentence_for_inference(
        Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
        """
        for head, tail, gold_label in self._entity_pair_permutations(sentence):
-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence = self._encode_sentence(
                head=head,
                tail=tail,
                gold_label=gold_label if gold_label is not None else self.zero_tag_value,
            )
            original_relation: Relation = Relation(first=head.span, second=tail.span)
-            yield masked_sentence, original_relation
+
+            if masked_sentence is not None:
+                yield masked_sentence, original_relation

    def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
        """Create Encoded Sentences and Relation pairs for Training.
@@ -492,13 +534,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
            else:
                continue  # Skip generated data points that do not express an originally annotated relation

-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence = self._encode_sentence(
                head=head,
                tail=tail,
                gold_label=gold_label,
            )

-            yield masked_sentence
+            if masked_sentence is not None:
+                yield masked_sentence

    def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
        """Transforms sentences into encoded sentences specific to the `RelationClassifier`.
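Since filtered pairs now simply produce no encoded sentence, the effect is easiest to observe through `transform_sentence`. A hedged example, assuming `model` is a `RelationClassifier` configured with `entity_label_types="ner"` as sketched earlier; the sentence and labels are made up:

from flair.data import Sentence

# Build a sentence that carries NER annotations of the configured entity label type.
sentence = Sentence("Albert Einstein was born in Ulm .")
sentence[0:2].add_label("ner", "PER")  # "Albert Einstein"
sentence[5:6].add_label("ner", "LOC")  # "Ulm"

encoded = model.transform_sentence(sentence)
# Entity pairs that are nested, too far apart, or otherwise filtered produce no
# EncodedSentence, so `encoded` may contain fewer items than there are entity pairs.
print(len(encoded), [s.to_tokenized_string() for s in encoded])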
@@ -706,6 +749,8 @@ def _get_state_dict(self) -> dict[str, Any]:
            "encoding_strategy": self.encoding_strategy,
            "zero_tag_value": self.zero_tag_value,
            "allow_unk_tag": self.allow_unk_tag,
+            "max_allowed_tokens_between_entities": self._max_allowed_tokens_between_entities,
+            "max_surrounding_context_length": self._max_surrounding_context_length,
        }
        return model_state
@@ -723,6 +768,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
            encoding_strategy=state["encoding_strategy"],
            zero_tag_value=state["zero_tag_value"],
            allow_unk_tag=state["allow_unk_tag"],
+            max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities", 25),
+            max_surrounding_context_length=state.get("max_surrounding_context_length", 50),
Review comment: The default parameters for backwards compatibility are different from the ones in the `__init__` method. Are these a better fit?

Reply: Backwards compatibility is tricky, since older models will have no limitations on the maximum allowed tokens or surrounding context. It's probably best to set really high numbers here (e.g., even higher).
            **kwargs,
        )
Review comment: Maybe it would make sense to also allow `None` values to disable these measures. `None` could then also be used as the default value in `_init_model_with_state_dict` to not change the behaviour of existing models.

Reply: Good idea!
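To illustrate that suggestion (this is not part of the diff), a self-contained sketch of how a `None` limit could disable the distance filter while preserving the old behaviour for models that never set it:

from typing import Optional

# Hypothetical standalone illustration of the reviewer's suggestion:
# a limit of None disables the corresponding filter entirely.
def passes_distance_filter(head_idx: int, tail_idx: int, max_tokens_between: Optional[int]) -> bool:
    if max_tokens_between is None:
        return True  # no limit configured: keep every pair (old behaviour)
    return abs(head_idx - tail_idx) <= max_tokens_between

assert passes_distance_filter(3, 40, None) is True  # filter disabled
assert passes_distance_filter(3, 40, 20) is False   # entities too far apart, pair dropped
assert passes_distance_filter(3, 10, 20) is True

With this variant, `_init_model_with_state_dict` could pass `state.get(...)` without a numeric fallback, so models saved before this change would load with both filters disabled.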