Skip to content

Commit

Permalink
Ensure presence of head and tail entity in the original sentence
Browse files Browse the repository at this point in the history
  • Loading branch information
dobbersc committed Jan 15, 2025
1 parent f798a3c commit 0cca6a8
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def __init__(
allow_unk_tag: bool = True,
max_allowed_tokens_between_entities: int = 20,
max_surrounding_context_length: int = 10,
**classifierargs,
**classifierargs: Any,
) -> None:
"""Initializes a `RelationClassifier`.
Expand Down Expand Up @@ -435,8 +435,8 @@ def _encode_sentence(
# since there may be multiple occurrences of the same entity mentioned in the sentence.
# Therefore, we use the span's position in the sentence.
encoded_sentence_tokens: list[str] = []
head_idx = -10000
tail_idx = 10000
head_idx: Optional[int] = None
tail_idx: Optional[int] = None
for token in original_sentence:
if token is head.span[0]:
head_idx = len(encoded_sentence_tokens)
Expand All @@ -452,11 +452,19 @@ def _encode_sentence(
):
encoded_sentence_tokens.append(token.text)

# filter cases in which the distance between the two entities is too large
msg: str
if head_idx is None:
msg = f"The head entity ({head!r}) is not located inside the original sentence ({original_sentence!r})."
raise AssertionError(msg)
if tail_idx is None:
msg = f"The tail entity ({tail!r}) is not located inside the original sentence ({original_sentence!r})."
raise AssertionError(msg)

# Filter cases in which the distance between the two entities is too large
if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities:
return None

# remove excess tokens left and right of entity pair to make encoded sentence shorter
# Remove excess tokens left and right of entity pair to make encoded sentence shorter
encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length(
encoded_sentence_tokens, head_idx, tail_idx
)
Expand Down

0 comments on commit 0cca6a8

Please sign in to comment.