Skip to content

Commit

Permalink
Make mypy happy
Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Jan 2, 2025
1 parent f50c3b3 commit 142703b
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ def _encode_sentence(
# since there may be multiple occurrences of the same entity mentioned in the sentence.
# Therefore, we use the span's position in the sentence.
encoded_sentence_tokens: list[str] = []
head_idx = None
tail_idx = None
head_idx = -10000
tail_idx = 10000
for token in original_sentence:
if token is head.span[0]:
head_idx = len(encoded_sentence_tokens)
Expand All @@ -456,8 +456,6 @@ def _encode_sentence(
if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities:
return None

print(head_idx, tail_idx)

# remove excess tokens left and right of entity pair to make encoded sentence shorter
encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length(
encoded_sentence_tokens, head_idx, tail_idx
Expand Down Expand Up @@ -511,13 +509,15 @@ def _encode_sentence_for_inference(
Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
"""
for head, tail, gold_label in self._entity_pair_permutations(sentence):
masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label if gold_label is not None else self.zero_tag_value,
)
original_relation: Relation = Relation(first=head.span, second=tail.span)
yield masked_sentence, original_relation

if masked_sentence is not None:
yield masked_sentence, original_relation

def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
"""Create Encoded Sentences and Relation pairs for Training.
Expand All @@ -534,13 +534,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS
else:
continue # Skip generated data points that do not express an originally annotated relation

masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label,
)

yield masked_sentence
if masked_sentence is not None:
yield masked_sentence

def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
"""Transforms sentences into encoded sentences specific to the `RelationClassifier`.
Expand All @@ -562,7 +563,6 @@ def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list
encoded_sentence
for sentence in sentences
for encoded_sentence in self._encode_sentence_for_training(sentence)
if encoded_sentence is not None
]

def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset[EncodedSentence]:
Expand Down Expand Up @@ -691,10 +691,6 @@ def predict(
)
)

sentences_with_relation_reference = [
item for item in sentences_with_relation_reference if item[0] is not None
]

encoded_sentences = [x[0] for x in sentences_with_relation_reference]
loss = super().predict(
encoded_sentences,
Expand Down

0 comments on commit 142703b

Please sign in to comment.