From 142703b0fd6ad5d95d25d30f8054ae063c2d8caf Mon Sep 17 00:00:00 2001
From: alanakbik
Date: Thu, 2 Jan 2025 21:44:28 +0100
Subject: [PATCH] Make mypy happy

---
 flair/models/relation_classifier_model.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py
index dadc17c053..fe52791479 100644
--- a/flair/models/relation_classifier_model.py
+++ b/flair/models/relation_classifier_model.py
@@ -435,8 +435,8 @@ def _encode_sentence(
         # since there may be multiple occurrences of the same entity mentioned in the sentence.
         # Therefore, we use the span's position in the sentence.
         encoded_sentence_tokens: list[str] = []
-        head_idx = None
-        tail_idx = None
+        head_idx = -10000
+        tail_idx = 10000
         for token in original_sentence:
             if token is head.span[0]:
                 head_idx = len(encoded_sentence_tokens)
@@ -456,8 +456,6 @@ def _encode_sentence(
         if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities:
             return None
 
-        print(head_idx, tail_idx)
-
         # remove excess tokens left and right of entity pair to make encoded sentence shorter
         encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length(
             encoded_sentence_tokens, head_idx, tail_idx
@@ -511,13 +509,15 @@ def _encode_sentence_for_inference(
         Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
         """
         for head, tail, gold_label in self._entity_pair_permutations(sentence):
-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence = self._encode_sentence(
                 head=head,
                 tail=tail,
                 gold_label=gold_label if gold_label is not None else self.zero_tag_value,
             )
             original_relation: Relation = Relation(first=head.span, second=tail.span)
-            yield masked_sentence, original_relation
+
+            if masked_sentence is not None:
+                yield masked_sentence, original_relation
 
     def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
         """Create Encoded Sentences and Relation pairs for Training.
@@ -534,13 +534,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS
             else:
                 continue  # Skip generated data points that do not express an originally annotated relation
 
-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence = self._encode_sentence(
                 head=head,
                 tail=tail,
                 gold_label=gold_label,
             )
 
-            yield masked_sentence
+            if masked_sentence is not None:
+                yield masked_sentence
 
     def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
         """Transforms sentences into encoded sentences specific to the `RelationClassifier`.
@@ -562,7 +563,6 @@ def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list
             encoded_sentence
             for sentence in sentences
             for encoded_sentence in self._encode_sentence_for_training(sentence)
-            if encoded_sentence is not None
         ]
 
     def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset[EncodedSentence]:
@@ -691,10 +691,6 @@ def predict(
             )
         )
 
-        sentences_with_relation_reference = [
-            item for item in sentences_with_relation_reference if item[0] is not None
-        ]
-
         encoded_sentences = [x[0] for x in sentences_with_relation_reference]
         loss = super().predict(
             encoded_sentences,
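
The pattern behind this patch: `_encode_sentence` returns `Optional[EncodedSentence]` (it bails out with `return None` when the entities are too far apart), so the explicit `masked_sentence: EncodedSentence` annotations contradicted the inferred Optional type. The fix drops those annotations, filters None inside the two generators, and deletes the now-redundant downstream filters in `transform_sentence` and `predict`; the `-10000`/`10000` sentinels likewise keep `head_idx`/`tail_idx` typed as plain `int` so that `abs(head_idx - tail_idx)` type-checks. Below is a minimal sketch of the generator-side filtering, using hypothetical stand-in names (`Encoded`, `encode`, `encode_all`) rather than flair's actual API:

    from collections.abc import Iterator
    from typing import Optional


    class Encoded:
        """Stand-in for an encoded sentence."""

        def __init__(self, text: str) -> None:
            self.text = text


    def encode(text: str) -> Optional[Encoded]:
        # Returns None when a validity check fails, mirroring how
        # _encode_sentence returns None for over-long entity distances.
        return Encoded(text) if len(text.split()) <= 128 else None


    def encode_all(texts: list[str]) -> Iterator[Encoded]:
        for text in texts:
            # No annotation: mypy infers Optional[Encoded] from encode().
            encoded = encode(text)
            # Filter None at the yield site so the declared Iterator[Encoded]
            # return type holds and callers need no is-not-None checks.
            if encoded is not None:
                yield encoded

Because the generators enforce their `Iterator[EncodedSentence]` return types at the yield site, every downstream consumer can rely on non-None values, which is why the filtering passes removed in `transform_sentence` and `predict` were dead code.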