From 14337e4479c281d699bcad08460fcc8ad65f99cf Mon Sep 17 00:00:00 2001
From: Andrew Beveridge
Date: Mon, 20 Jan 2025 01:30:28 -0500
Subject: [PATCH] Made syllable-count-based corrector actually usable; passed
 unformatted reference words through to GapSequence objects for use in
 correctors

---
 lyrics_transcriber/core/controller.py       | 134 ++++++--------
 .../correction/anchor_sequence.py           | 161 ++++++++--------
 lyrics_transcriber/correction/corrector.py  | 156 ++++++++--------
 .../correction/handlers/levenshtein.py      |  10 +-
 .../handlers/no_space_punct_match.py        |  31 ++--
 .../handlers/relaxed_word_count_match.py    |   6 +-
 .../correction/handlers/sound_alike.py      |  11 +-
 .../correction/handlers/syllables_match.py  | 175 ++++++++++++++++++
 .../correction/handlers/word_count_match.py |   9 +-
 lyrics_transcriber/types.py                 |   6 +
 poetry.lock                                 | 125 ++++++++++++-
 pyproject.toml                              |   2 +
 12 files changed, 557 insertions(+), 269 deletions(-)
 create mode 100644 lyrics_transcriber/correction/handlers/syllables_match.py

diff --git a/lyrics_transcriber/core/controller.py b/lyrics_transcriber/core/controller.py
index 34d461c..4ea93b2 100644
--- a/lyrics_transcriber/core/controller.py
+++ b/lyrics_transcriber/core/controller.py
@@ -184,49 +184,39 @@ def process(self) -> LyricsControllerResult:
         Raises:
             Exception: If a critical error occurs during processing.
         """
-        try:
-            # Step 1: Fetch lyrics if artist and title are provided
-            if self.artist and self.title:
-                self.fetch_lyrics()
+        # Step 1: Fetch lyrics if artist and title are provided
+        if self.artist and self.title:
+            self.fetch_lyrics()

-            # Step 2: Run transcription
-            self.transcribe()
+        # Step 2: Run transcription
+        self.transcribe()

-            # Step 3: Process and correct lyrics
-            self.correct_lyrics()
+        # Step 3: Process and correct lyrics
+        self.correct_lyrics()

-            # Step 4: Generate outputs
-            self.generate_outputs()
+        # Step 4: Generate outputs
+        self.generate_outputs()

-            self.logger.info("Processing completed successfully")
-            return self.results
-
-        except Exception as e:
-            self.logger.error(f"Error during processing: {str(e)}")
-            raise
+        self.logger.info("Processing completed successfully")
+        return self.results

     def fetch_lyrics(self) -> None:
         """Fetch lyrics from available providers."""
         self.logger.info(f"Fetching lyrics for {self.artist} - {self.title}")

-        try:
-            for name, provider in self.lyrics_providers.items():
-                try:
-                    result = provider.fetch_lyrics(self.artist, self.title)
-                    if result:
-                        self.results.lyrics_results.append(result)
-                        self.logger.info(f"Successfully fetched lyrics from {name}")
-
-                except Exception as e:
-                    self.logger.error(f"Failed to fetch lyrics from {name}: {str(e)}")
-                    continue
+        for name, provider in self.lyrics_providers.items():
+            try:
+                result = provider.fetch_lyrics(self.artist, self.title)
+                if result:
+                    self.results.lyrics_results.append(result)
+                    self.logger.info(f"Successfully fetched lyrics from {name}")

-            if not self.results.lyrics_results:
-                self.logger.warning("No lyrics found from any source")
+            except Exception as e:
+                self.logger.error(f"Failed to fetch lyrics from {name}: {str(e)}")
+                continue

-        except Exception as e:
-            self.logger.error(f"Failed to fetch lyrics: {str(e)}")
-            # Don't raise - we can continue without lyrics
+        if not self.results.lyrics_results:
+            self.logger.warning("No lyrics found from any source")

     def transcribe(self) -> None:
         """Run transcription using all available transcribers."""
@@ -234,18 +224,13 @@ def transcribe(self) -> None:
         for name, transcriber_info in self.transcribers.items():
self.logger.info(f"Running transcription with {name}") - try: - result = transcriber_info["instance"].transcribe(self.audio_filepath) - if result: - # Add the transcriber name and priority to the result - self.results.transcription_results.append( - TranscriptionResult(name=name, priority=transcriber_info["priority"], result=result) - ) - self.logger.debug(f"Transcription completed for {name}") - - except Exception as e: - self.logger.error(f"Transcription failed for {name}: {str(e)}", exc_info=True) - continue + result = transcriber_info["instance"].transcribe(self.audio_filepath) + if result: + # Add the transcriber name and priority to the result + self.results.transcription_results.append( + TranscriptionResult(name=name, priority=transcriber_info["priority"], result=result) + ) + self.logger.debug(f"Transcription completed for {name}") if not self.results.transcription_results: self.logger.warning("No successful transcriptions from any provider") @@ -254,44 +239,35 @@ def correct_lyrics(self) -> None: """Run lyrics correction using transcription and internet lyrics.""" self.logger.info("Starting lyrics correction process") - try: - # Run correction - corrected_data = self.corrector.run( - transcription_results=self.results.transcription_results, lyrics_results=self.results.lyrics_results - ) - - # Store corrected results - self.results.transcription_corrected = corrected_data - self.logger.info("Lyrics correction completed") + # Run correction + corrected_data = self.corrector.run( + transcription_results=self.results.transcription_results, lyrics_results=self.results.lyrics_results + ) - except Exception as e: - self.logger.error(f"Failed to correct lyrics: {str(e)}", exc_info=True) + # Store corrected results + self.results.transcription_corrected = corrected_data + self.logger.info("Lyrics correction completed") def generate_outputs(self) -> None: """Generate output files.""" self.logger.info("Generating output files") - try: - output_files = self.output_generator.generate_outputs( - transcription_corrected=self.results.transcription_corrected, - lyrics_results=self.results.lyrics_results, - output_prefix=self.output_prefix, - audio_filepath=self.audio_filepath, - artist=self.artist, - title=self.title, - ) - - # Store all output paths in results - self.results.lrc_filepath = output_files.lrc - self.results.ass_filepath = output_files.ass - self.results.video_filepath = output_files.video - self.results.original_txt = output_files.original_txt - self.results.corrected_txt = output_files.corrected_txt - self.results.corrections_json = output_files.corrections_json - self.results.cdg_filepath = output_files.cdg - self.results.mp3_filepath = output_files.mp3 - self.results.cdg_zip_filepath = output_files.cdg_zip - - except Exception as e: - self.logger.error(f"Failed to generate outputs: {str(e)}") - raise + output_files = self.output_generator.generate_outputs( + transcription_corrected=self.results.transcription_corrected, + lyrics_results=self.results.lyrics_results, + output_prefix=self.output_prefix, + audio_filepath=self.audio_filepath, + artist=self.artist, + title=self.title, + ) + + # Store all output paths in results + self.results.lrc_filepath = output_files.lrc + self.results.ass_filepath = output_files.ass + self.results.video_filepath = output_files.video + self.results.original_txt = output_files.original_txt + self.results.corrected_txt = output_files.corrected_txt + self.results.corrections_json = output_files.corrections_json + self.results.cdg_filepath = 
output_files.cdg + self.results.mp3_filepath = output_files.mp3 + self.results.cdg_zip_filepath = output_files.cdg_zip diff --git a/lyrics_transcriber/correction/anchor_sequence.py b/lyrics_transcriber/correction/anchor_sequence.py index e785915..cff5f72 100644 --- a/lyrics_transcriber/correction/anchor_sequence.py +++ b/lyrics_transcriber/correction/anchor_sequence.py @@ -347,51 +347,88 @@ def _get_reference_words(self, source: str, ref_words: List[str], start_pos: Opt end_pos = len(ref_words) return ref_words[start_pos:end_pos] - def _create_initial_gap( - self, words: List[str], first_anchor: Optional[ScoredAnchor], ref_texts_clean: Dict[str, List[str]] - ) -> Optional[GapSequence]: - """Create gap sequence before the first anchor. + def find_gaps(self, transcribed: str, anchors: List[ScoredAnchor], references: Dict[str, str]) -> List[GapSequence]: + """Find gaps between anchor sequences in the transcribed text.""" + cache_key = self._get_cache_key(transcribed, references) + cache_path = self.cache_dir / f"gaps_{cache_key}.json" - Args: - words: Transcribed words - first_anchor: First anchor sequence (or None if no anchors) - ref_texts_clean: Cleaned reference texts + # Try to load from cache + if cached_data := self._load_from_cache(cache_path): + self.logger.info("Loading gaps from cache") + return [GapSequence.from_dict(gap) for gap in cached_data] - Returns: - GapSequence if there are words before first anchor, None otherwise - """ + # If not in cache, perform the computation + self.logger.info("Cache miss - computing gaps") + words = self._clean_text(transcribed).split() + ref_texts_clean = {source: self._clean_text(text).split() for source, text in references.items()} + # Store original reference texts split into words + ref_texts_original = {source: text.split() for source, text in references.items()} + + gaps = [] + sorted_anchors = sorted(anchors, key=lambda x: x.anchor.transcription_position) + + # Handle initial gap + if initial_gap := self._create_initial_gap( + words, sorted_anchors[0] if sorted_anchors else None, ref_texts_clean, ref_texts_original + ): + gaps.append(initial_gap) + + # Handle gaps between anchors + for i in range(len(sorted_anchors) - 1): + if between_gap := self._create_between_gap( + words, sorted_anchors[i], sorted_anchors[i + 1], ref_texts_clean, ref_texts_original + ): + gaps.append(between_gap) + + # Handle final gap + if sorted_anchors and (final_gap := self._create_final_gap(words, sorted_anchors[-1], ref_texts_clean, ref_texts_original)): + gaps.append(final_gap) + + # Save to cache + self._save_to_cache(cache_path, [gap.to_dict() for gap in gaps]) + return gaps + + def _create_initial_gap( + self, + words: List[str], + first_anchor: Optional[ScoredAnchor], + ref_texts_clean: Dict[str, List[str]], + ref_texts_original: Dict[str, List[str]], + ) -> Optional[GapSequence]: + """Create gap sequence before the first anchor.""" if not first_anchor: ref_words = {source: words for source, words in ref_texts_clean.items()} - return GapSequence(words, 0, None, None, ref_words) + ref_words_original = {source: words for source, words in ref_texts_original.items()} + return GapSequence(words, 0, None, None, ref_words, ref_words_original) if first_anchor.anchor.transcription_position > 0: ref_words = {} - for source, ref_words_list in ref_texts_clean.items(): + ref_words_original = {} + for source in ref_texts_clean: end_pos = first_anchor.anchor.reference_positions.get(source) - ref_words[source] = self._get_reference_words(source, ref_words_list, None, 
end_pos) + ref_words[source] = self._get_reference_words(source, ref_texts_clean[source], None, end_pos) + ref_words_original[source] = self._get_reference_words(source, ref_texts_original[source], None, end_pos) - return GapSequence(words[: first_anchor.anchor.transcription_position], 0, None, first_anchor.anchor, ref_words) + return GapSequence( + words[: first_anchor.anchor.transcription_position], 0, None, first_anchor.anchor, ref_words, ref_words_original + ) return None def _create_between_gap( - self, words: List[str], current_anchor: ScoredAnchor, next_anchor: ScoredAnchor, ref_texts_clean: Dict[str, List[str]] + self, + words: List[str], + current_anchor: ScoredAnchor, + next_anchor: ScoredAnchor, + ref_texts_clean: Dict[str, List[str]], + ref_texts_original: Dict[str, List[str]], ) -> Optional[GapSequence]: - """Create gap sequence between two anchors. - - Args: - words: Transcribed words - current_anchor: Preceding anchor - next_anchor: Following anchor - ref_texts_clean: Cleaned reference texts - - Returns: - GapSequence if there are words between anchors, None otherwise - """ + """Create gap sequence between two anchors.""" gap_start = current_anchor.anchor.transcription_position + current_anchor.anchor.length gap_end = next_anchor.anchor.transcription_position if gap_end > gap_start: ref_words = {} + ref_words_original = {} shared_sources = set(current_anchor.anchor.reference_positions.keys()) & set(next_anchor.anchor.reference_positions.keys()) # Check for large position differences in next_anchor @@ -399,80 +436,36 @@ def _create_between_gap( positions = list(next_anchor.anchor.reference_positions.values()) max_diff = max(positions) - min(positions) if max_diff > 20: - # Find source with earliest position earliest_source = min(next_anchor.anchor.reference_positions.items(), key=lambda x: x[1])[0] self.logger.warning( - f"Large position difference ({max_diff} words) in next anchor '{' '.join(next_anchor.anchor.words)}'. " - f"Using only earliest source: {earliest_source} at position {next_anchor.anchor.reference_positions[earliest_source]}" + f"Large position difference ({max_diff} words) in next anchor. Using only earliest source: {earliest_source}" ) - # Only consider the earliest source for the gap shared_sources &= {earliest_source} for source in shared_sources: start_pos = current_anchor.anchor.reference_positions[source] + current_anchor.anchor.length end_pos = next_anchor.anchor.reference_positions[source] - words_list = self._get_reference_words(source, ref_texts_clean[source], start_pos, end_pos) - if words_list: # Only add source if it has words - ref_words[source] = words_list + ref_words[source] = self._get_reference_words(source, ref_texts_clean[source], start_pos, end_pos) + ref_words_original[source] = self._get_reference_words(source, ref_texts_original[source], start_pos, end_pos) - return GapSequence(words[gap_start:gap_end], gap_start, current_anchor.anchor, next_anchor.anchor, ref_words) + return GapSequence( + words[gap_start:gap_end], gap_start, current_anchor.anchor, next_anchor.anchor, ref_words, ref_words_original + ) return None def _create_final_gap( - self, words: List[str], last_anchor: ScoredAnchor, ref_texts_clean: Dict[str, List[str]] + self, words: List[str], last_anchor: ScoredAnchor, ref_texts_clean: Dict[str, List[str]], ref_texts_original: Dict[str, List[str]] ) -> Optional[GapSequence]: - """Create gap sequence after the last anchor. 
- - Args: - words: Transcribed words - last_anchor: Last anchor sequence - ref_texts_clean: Cleaned reference texts - - Returns: - GapSequence if there are words after last anchor, None otherwise - """ + """Create gap sequence after the last anchor.""" last_pos = last_anchor.anchor.transcription_position + last_anchor.anchor.length if last_pos < len(words): ref_words = {} - for source, ref_words_list in ref_texts_clean.items(): + ref_words_original = {} + for source in ref_texts_clean: if source in last_anchor.anchor.reference_positions: start_pos = last_anchor.anchor.reference_positions[source] + last_anchor.anchor.length - ref_words[source] = self._get_reference_words(source, ref_words_list, start_pos, None) + ref_words[source] = self._get_reference_words(source, ref_texts_clean[source], start_pos, None) + ref_words_original[source] = self._get_reference_words(source, ref_texts_original[source], start_pos, None) - return GapSequence(words[last_pos:], last_pos, last_anchor.anchor, None, ref_words) + return GapSequence(words[last_pos:], last_pos, last_anchor.anchor, None, ref_words, ref_words_original) return None - - def find_gaps(self, transcribed: str, anchors: List[ScoredAnchor], references: Dict[str, str]) -> List[GapSequence]: - """Find gaps between anchor sequences in the transcribed text.""" - cache_key = self._get_cache_key(transcribed, references) - cache_path = self.cache_dir / f"gaps_{cache_key}.json" - - # Try to load from cache - if cached_data := self._load_from_cache(cache_path): - self.logger.info("Loading gaps from cache") - return [GapSequence.from_dict(gap) for gap in cached_data] - - # If not in cache, perform the computation - self.logger.info("Cache miss - computing gaps") - words = self._clean_text(transcribed).split() - ref_texts_clean = {source: self._clean_text(text).split() for source, text in references.items()} - - gaps = [] - sorted_anchors = sorted(anchors, key=lambda x: x.anchor.transcription_position) - - # Handle initial gap - if initial_gap := self._create_initial_gap(words, sorted_anchors[0] if sorted_anchors else None, ref_texts_clean): - gaps.append(initial_gap) - - # Handle gaps between anchors - for i in range(len(sorted_anchors) - 1): - if between_gap := self._create_between_gap(words, sorted_anchors[i], sorted_anchors[i + 1], ref_texts_clean): - gaps.append(between_gap) - - # Handle final gap - if sorted_anchors and (final_gap := self._create_final_gap(words, sorted_anchors[-1], ref_texts_clean)): - gaps.append(final_gap) - - # Save to cache - self._save_to_cache(cache_path, [gap.to_dict() for gap in gaps]) - return gaps diff --git a/lyrics_transcriber/correction/corrector.py b/lyrics_transcriber/correction/corrector.py index 4b2f793..bff197b 100644 --- a/lyrics_transcriber/correction/corrector.py +++ b/lyrics_transcriber/correction/corrector.py @@ -4,6 +4,7 @@ from lyrics_transcriber.correction.handlers.no_space_punct_match import NoSpacePunctuationMatchHandler from lyrics_transcriber.correction.handlers.relaxed_word_count_match import RelaxedWordCountMatchHandler +from lyrics_transcriber.correction.handlers.syllables_match import SyllablesMatchHandler from lyrics_transcriber.types import GapSequence, LyricsData, TranscriptionResult, CorrectionResult, LyricsSegment, WordCorrection, Word from lyrics_transcriber.correction.anchor_sequence import AnchorSequenceFinder from lyrics_transcriber.correction.handlers.base import GapCorrectionHandler @@ -34,6 +35,7 @@ def __init__( WordCountMatchHandler(), RelaxedWordCountMatchHandler(), 
NoSpacePunctuationMatchHandler(), + SyllablesMatchHandler(), ExtraWordsHandler(), RepeatCorrectionHandler(), SoundAlikeHandler(), @@ -47,49 +49,43 @@ def run(self, transcription_results: List[TranscriptionResult], lyrics_results: self.logger.error("No transcription results available") raise ValueError("No primary transcription data available") - try: - # Get primary transcription - primary_transcription = sorted(transcription_results, key=lambda x: x.priority)[0].result - transcribed_text = " ".join(" ".join(w.text for w in segment.words) for segment in primary_transcription.segments) - reference_texts = {lyrics.source: lyrics.lyrics for lyrics in lyrics_results} - - # Find anchor sequences and gaps - self.logger.debug("Finding anchor sequences and gaps") - anchor_sequences = self.anchor_finder.find_anchors(transcribed_text, reference_texts) - gap_sequences = self.anchor_finder.find_gaps(transcribed_text, anchor_sequences, reference_texts) - - # Process corrections - corrections, corrected_segments = self._process_corrections(primary_transcription.segments, gap_sequences) - - # Calculate correction ratio - total_words = sum(len(segment.words) for segment in corrected_segments) - corrections_made = len(corrections) - correction_ratio = 1 - (corrections_made / total_words if total_words > 0 else 0) - - return CorrectionResult( - original_segments=primary_transcription.segments, - corrected_segments=corrected_segments, - corrected_text="\n".join(segment.text for segment in corrected_segments) + "\n", - corrections=corrections, - corrections_made=corrections_made, - confidence=correction_ratio, - transcribed_text=transcribed_text, - reference_texts=reference_texts, - anchor_sequences=anchor_sequences, - resized_segments=[], - gap_sequences=gap_sequences, - metadata={ - "anchor_sequences_count": len(anchor_sequences), - "gap_sequences_count": len(gap_sequences), - "total_words": total_words, - "correction_ratio": correction_ratio, - }, - ) - - except Exception as e: - self.logger.error(f"Correction failed: {str(e)}", exc_info=True) - # Return uncorrected transcription as fallback - return self._create_fallback_result(primary_transcription) + # Get primary transcription + primary_transcription = sorted(transcription_results, key=lambda x: x.priority)[0].result + transcribed_text = " ".join(" ".join(w.text for w in segment.words) for segment in primary_transcription.segments) + reference_texts = {lyrics.source: lyrics.lyrics for lyrics in lyrics_results} + + # Find anchor sequences and gaps + self.logger.debug("Finding anchor sequences and gaps") + anchor_sequences = self.anchor_finder.find_anchors(transcribed_text, reference_texts) + gap_sequences = self.anchor_finder.find_gaps(transcribed_text, anchor_sequences, reference_texts) + + # Process corrections + corrections, corrected_segments = self._process_corrections(primary_transcription.segments, gap_sequences) + + # Calculate correction ratio + total_words = sum(len(segment.words) for segment in corrected_segments) + corrections_made = len(corrections) + correction_ratio = 1 - (corrections_made / total_words if total_words > 0 else 0) + + return CorrectionResult( + original_segments=primary_transcription.segments, + corrected_segments=corrected_segments, + corrected_text="\n".join(segment.text for segment in corrected_segments) + "\n", + corrections=corrections, + corrections_made=corrections_made, + confidence=correction_ratio, + transcribed_text=transcribed_text, + reference_texts=reference_texts, + anchor_sequences=anchor_sequences, + 
resized_segments=[], + gap_sequences=gap_sequences, + metadata={ + "anchor_sequences_count": len(anchor_sequences), + "gap_sequences_count": len(gap_sequences), + "total_words": total_words, + "correction_ratio": correction_ratio, + }, + ) def _preserve_formatting(self, original: str, new_word: str) -> str: """Preserve original word's formatting when applying correction.""" @@ -172,29 +168,56 @@ def _process_gaps(self, gap_sequences: List[GapSequence]) -> List[WordCorrection def _apply_corrections_to_segments(self, segments: List[LyricsSegment], corrections: List[WordCorrection]) -> List[LyricsSegment]: """Apply corrections to create new segments.""" - correction_map = {c.word_index: c for c in corrections} - corrected_segments = [] + correction_map = {} + # Group corrections by word_index to handle splits + for c in corrections: + if c.word_index not in correction_map: + correction_map[c.word_index] = [] + correction_map[c.word_index].append(c) + corrected_segments = [] current_word_idx = 0 + for segment_idx, segment in enumerate(segments): corrected_words = [] for word in segment.words: if current_word_idx in correction_map: - correction = correction_map[current_word_idx] - if not correction.is_deletion: - corrected_words.append( - Word( - text=self._preserve_formatting(correction.original_word, correction.corrected_word), - start_time=word.start_time, - end_time=word.end_time, - confidence=correction.confidence, + word_corrections = sorted(correction_map[current_word_idx], key=lambda x: x.split_index or 0) + + if any(c.split_total for c in word_corrections): + # Handle word split + total_splits = word_corrections[0].split_total + split_duration = (word.end_time - word.start_time) / total_splits + + for i, correction in enumerate(word_corrections): + start_time = word.start_time + (i * split_duration) + end_time = start_time + split_duration + + corrected_words.append( + Word( + text=self._preserve_formatting(correction.original_word, correction.corrected_word), + start_time=start_time, + end_time=end_time, + confidence=correction.confidence, + ) + ) + else: + # Handle single word replacement + correction = word_corrections[0] + if not correction.is_deletion: + corrected_words.append( + Word( + text=self._preserve_formatting(correction.original_word, correction.corrected_word), + start_time=word.start_time, + end_time=word.end_time, + confidence=correction.confidence, + ) ) - ) else: corrected_words.append(word) current_word_idx += 1 - if corrected_words: # Only create segment if it has words + if corrected_words: corrected_segments.append( LyricsSegment( text=" ".join(w.text for w in corrected_words), @@ -205,26 +228,3 @@ def _apply_corrections_to_segments(self, segments: List[LyricsSegment], correcti ) return corrected_segments - - def _create_fallback_result(self, transcription): - """Create a fallback result when correction fails.""" - return CorrectionResult( - original_segments=transcription.segments, - corrected_segments=transcription.segments, - corrected_text="\n".join(segment.text for segment in transcription.segments) + "\n", - corrections=[], - corrections_made=0, - confidence=1.0, - transcribed_text="\n".join(segment.text for segment in transcription.segments), - reference_texts={}, - anchor_sequences=[], - gap_sequences=[], - resized_segments=[], - metadata={ - "error": "Correction failed, using original transcription", - "anchor_sequences_count": 0, - "gap_sequences_count": 0, - "total_words": sum(len(segment.words) for segment in transcription.segments), - 
"correction_ratio": 1.0, - }, - ) diff --git a/lyrics_transcriber/correction/handlers/levenshtein.py b/lyrics_transcriber/correction/handlers/levenshtein.py index 1209ab4..2cb7d10 100644 --- a/lyrics_transcriber/correction/handlers/levenshtein.py +++ b/lyrics_transcriber/correction/handlers/levenshtein.py @@ -77,17 +77,19 @@ def handle(self, gap: GapSequence) -> List[WordCorrection]: # Find matching reference words at this position matches = {} # word -> (sources, similarity) for source, ref_words in gap.reference_words.items(): + ref_words_original = gap.reference_words_original[source] # Get original formatted words if i >= len(ref_words): continue ref_word = ref_words[i] + ref_word_original = ref_words_original[i] # Get original formatted word similarity = self._get_string_similarity(word, ref_word) if similarity >= self.similarity_threshold: self.logger.debug(f"Found match: '{word}' -> '{ref_word}' ({similarity:.2f})") - if ref_word not in matches: - matches[ref_word] = ([], similarity) - matches[ref_word][0].append(source) + if ref_word_original not in matches: # Use original formatted word as key + matches[ref_word_original] = ([], similarity) + matches[ref_word_original][0].append(source) # Create correction for best match if any found if matches: @@ -102,7 +104,7 @@ def handle(self, gap: GapSequence) -> List[WordCorrection]: corrections.append( WordCorrection( original_word=word, - corrected_word=best_match, + corrected_word=best_match, # Using original formatted word segment_index=0, word_index=gap.transcription_position + i, confidence=final_confidence, diff --git a/lyrics_transcriber/correction/handlers/no_space_punct_match.py b/lyrics_transcriber/correction/handlers/no_space_punct_match.py index 7452792..e079345 100644 --- a/lyrics_transcriber/correction/handlers/no_space_punct_match.py +++ b/lyrics_transcriber/correction/handlers/no_space_punct_match.py @@ -36,25 +36,30 @@ def handle(self, gap: GapSequence) -> List[WordCorrection]: # Find the matching source (we know there is at least one from can_handle) gap_text = self._remove_spaces_and_punct(gap.words) matching_source = None + reference_words = None + reference_words_original = None for source, words in gap.reference_words.items(): if self._remove_spaces_and_punct(words) == gap_text: matching_source = source + reference_words = words + reference_words_original = gap.reference_words_original[source] break # Since the texts match when spaces and punctuation are removed, - # we'll mark all words as correct but preserve their original form - for i, word in enumerate(gap.words): - corrections.append( - WordCorrection( - original_word=word, - corrected_word=word, # Keep the original word - segment_index=0, # This will be updated when applying corrections - word_index=gap.transcription_position + i, - confidence=1.0, - source=matching_source, - reason=f"NoSpacePunctuationMatchHandler: Source '{matching_source}' matched when spaces and punctuation removed", - alternatives={}, + # we'll replace with the properly formatted reference words + for i, (orig_word, ref_word, ref_word_original) in enumerate(zip(gap.words, reference_words, reference_words_original)): + if orig_word.lower() != ref_word.lower(): + corrections.append( + WordCorrection( + original_word=orig_word, + corrected_word=ref_word_original, # Use the original formatted word + segment_index=0, # This will be updated when applying corrections + word_index=gap.transcription_position + i, + confidence=1.0, + source=matching_source, + reason=f"NoSpacePunctuationMatchHandler: 
Source '{matching_source}' matched when spaces and punctuation removed", + alternatives={}, + ) ) - ) return corrections diff --git a/lyrics_transcriber/correction/handlers/relaxed_word_count_match.py b/lyrics_transcriber/correction/handlers/relaxed_word_count_match.py index 5bc9bc9..d268db0 100644 --- a/lyrics_transcriber/correction/handlers/relaxed_word_count_match.py +++ b/lyrics_transcriber/correction/handlers/relaxed_word_count_match.py @@ -25,19 +25,21 @@ def handle(self, gap: GapSequence) -> List[WordCorrection]: # Find the first source that has matching word count matching_source = None reference_words = None + reference_words_original = None for source, words in gap.reference_words.items(): if len(words) == gap.length: matching_source = source reference_words = words + reference_words_original = gap.reference_words_original[source] break # Since we found a source with matching word count, we can correct using that source - for i, (orig_word, ref_word) in enumerate(zip(gap.words, reference_words)): + for i, (orig_word, ref_word, ref_word_original) in enumerate(zip(gap.words, reference_words, reference_words_original)): if orig_word.lower() != ref_word.lower(): corrections.append( WordCorrection( original_word=orig_word, - corrected_word=ref_word, + corrected_word=ref_word_original, # Use the original formatted word segment_index=0, # This will be updated when applying corrections word_index=gap.transcription_position + i, confidence=1.0, diff --git a/lyrics_transcriber/correction/handlers/sound_alike.py b/lyrics_transcriber/correction/handlers/sound_alike.py index c16725f..60f0807 100644 --- a/lyrics_transcriber/correction/handlers/sound_alike.py +++ b/lyrics_transcriber/correction/handlers/sound_alike.py @@ -77,7 +77,8 @@ def handle(self, gap: GapSequence) -> List[WordCorrection]: matches: Dict[str, Tuple[List[str], float]] = {} for source, ref_words in gap.reference_words.items(): - for j, ref_word in enumerate(ref_words): + ref_words_original = gap.reference_words_original[source] # Get original formatted words + for j, (ref_word, ref_word_original) in enumerate(zip(ref_words, ref_words_original)): ref_codes = doublemetaphone(ref_word) match_confidence = self._get_match_confidence(word_codes, ref_codes) @@ -89,9 +90,9 @@ def handle(self, gap: GapSequence) -> List[WordCorrection]: adjusted_confidence = match_confidence * position_multiplier if adjusted_confidence >= self.similarity_threshold: - if ref_word not in matches: - matches[ref_word] = ([], adjusted_confidence) - matches[ref_word][0].append(source) + if ref_word_original not in matches: # Use original formatted word as key + matches[ref_word_original] = ([], adjusted_confidence) + matches[ref_word_original][0].append(source) # Create correction for best match if any found if matches: @@ -104,7 +105,7 @@ def handle(self, gap: GapSequence) -> List[WordCorrection]: corrections.append( WordCorrection( original_word=word, - corrected_word=best_match, + corrected_word=best_match, # Already using original formatted word segment_index=0, word_index=gap.transcription_position + i, confidence=final_confidence, diff --git a/lyrics_transcriber/correction/handlers/syllables_match.py b/lyrics_transcriber/correction/handlers/syllables_match.py new file mode 100644 index 0000000..1d0fd4e --- /dev/null +++ b/lyrics_transcriber/correction/handlers/syllables_match.py @@ -0,0 +1,175 @@ +from typing import List +import spacy_syllables +import spacy +import logging +import pyphen +import nltk +from nltk.corpus import cmudict +import syllables 
+ +from lyrics_transcriber.types import GapSequence, WordCorrection +from lyrics_transcriber.correction.handlers.base import GapCorrectionHandler + + +class SyllablesMatchHandler(GapCorrectionHandler): + """Handles gaps where number of syllables in reference text matches number of syllables in transcription.""" + + def __init__(self): + # Load spacy model with syllables pipeline + self.nlp = spacy.load("en_core_web_sm") + # Add syllables component to pipeline if not already present + if "syllables" not in self.nlp.pipe_names: + self.nlp.add_pipe("syllables") + # Initialize Pyphen for English + self.dic = pyphen.Pyphen(lang="en_US") + # Initialize NLTK's CMU dictionary + try: + self.cmudict = cmudict.dict() + except LookupError: + nltk.download("cmudict") + self.cmudict = cmudict.dict() + self.logger = logging.getLogger(__name__) + + def _count_syllables_spacy(self, words: List[str]) -> int: + """Count syllables using spacy_syllables.""" + text = " ".join(words) + doc = self.nlp(text) + total_syllables = sum(token._.syllables_count or 1 for token in doc) + self.logger.debug(f"Spacy syllable count for '{text}': {total_syllables}") + for token in doc: + self.logger.debug(f" Word '{token.text}': {token._.syllables_count or 1} syllables") + return total_syllables + + def _count_syllables_pyphen(self, words: List[str]) -> int: + """Count syllables using pyphen.""" + total_syllables = 0 + for word in words: + # Count hyphens in hyphenated word + 1 to get syllable count + hyphenated = self.dic.inserted(word) + syllables_count = len(hyphenated.split("-")) if hyphenated else 1 + total_syllables += syllables_count + self.logger.debug(f" Pyphen word '{word}': {syllables_count} syllables (hyphenated: {hyphenated})") + self.logger.debug(f"Pyphen syllable count for '{' '.join(words)}': {total_syllables}") + return total_syllables + + def _count_syllables_nltk(self, words: List[str]) -> int: + """Count syllables using NLTK's CMU dictionary.""" + total_syllables = 0 + for word in words: + word = word.lower() + # Try to get pronunciation from CMU dict + if word in self.cmudict: + # Count number of stress markers in first pronunciation + syllables_count = len([ph for ph in self.cmudict[word][0] if ph[-1].isdigit()]) + total_syllables += syllables_count + self.logger.debug(f" NLTK word '{word}': {syllables_count} syllables") + else: + # Fallback to 1 syllable if word not in dictionary + total_syllables += 1 + self.logger.debug(f" NLTK word '{word}': 1 syllable (not in dictionary)") + self.logger.debug(f"NLTK syllable count for '{' '.join(words)}': {total_syllables}") + return total_syllables + + def _count_syllables_lib(self, words: List[str]) -> int: + """Count syllables using the syllables library.""" + total_syllables = 0 + for word in words: + syllables_count = syllables.estimate(word) + total_syllables += syllables_count + self.logger.debug(f" Syllables lib word '{word}': {syllables_count} syllables") + self.logger.debug(f"Syllables lib count for '{' '.join(words)}': {total_syllables}") + return total_syllables + + def _count_syllables(self, words: List[str]) -> List[int]: + """Count syllables using multiple methods.""" + spacy_count = self._count_syllables_spacy(words) + pyphen_count = self._count_syllables_pyphen(words) + nltk_count = self._count_syllables_nltk(words) + syllables_count = self._count_syllables_lib(words) + return [spacy_count, pyphen_count, nltk_count, syllables_count] + + def can_handle(self, gap: GapSequence) -> bool: + # Must have reference words + if not gap.reference_words: + 
self.logger.debug("No reference words available") + return False + + # Get syllable counts for gap text using different methods + gap_syllables = self._count_syllables(gap.words) + self.logger.debug(f"Gap '{' '.join(gap.words)}' has syllable counts: {gap_syllables}") + + # Check if any reference source has matching syllable count with any method + for source, words in gap.reference_words.items(): + ref_syllables = self._count_syllables(words) + self.logger.debug(f"Reference source '{source}' has syllable counts: {ref_syllables}") + + # If any counting method matches between gap and reference, we can handle it + if any(gap_count == ref_count for gap_count in gap_syllables for ref_count in ref_syllables): + self.logger.debug(f"Found matching syllable count in source '{source}'") + return True + + self.logger.debug("No reference source had matching syllable count") + return False + + def handle(self, gap: GapSequence) -> List[WordCorrection]: + corrections = [] + + # Find the matching source + gap_syllables = self._count_syllables(gap.words) + matching_source = None + reference_words = None + reference_words_original = None + + for source, words in gap.reference_words.items(): + ref_syllables = self._count_syllables(words) + if any(gap_count == ref_count for gap_count in gap_syllables for ref_count in ref_syllables): + matching_source = source + reference_words = words + reference_words_original = gap.reference_words_original[source] # Get original formatted words + break + + # Handle word splits (one transcribed word -> multiple reference words) + if len(gap.words) < len(reference_words): + # Simple case: distribute reference words evenly across gap words + words_per_gap = len(reference_words) / len(gap.words) + + for i, orig_word in enumerate(gap.words): + start_idx = int(i * words_per_gap) + end_idx = int((i + 1) * words_per_gap) + ref_words_for_orig = reference_words[start_idx:end_idx] + ref_words_original_for_orig = reference_words_original[start_idx:end_idx] # Get original formatted words + + # Create a correction for each reference word + for split_idx, (ref_word, ref_word_original) in enumerate(zip(ref_words_for_orig, ref_words_original_for_orig)): + corrections.append( + WordCorrection( + original_word=orig_word, + corrected_word=ref_word_original, # Use original formatted word + segment_index=0, + word_index=gap.transcription_position + i, + confidence=0.8, + source=matching_source, + reason=f"SyllablesMatchHandler: Split word based on syllable match", + alternatives={}, + split_index=split_idx, + split_total=len(ref_words_for_orig), + ) + ) + else: + # One-to-one replacement + for i, (orig_word, ref_word, ref_word_original) in enumerate(zip(gap.words, reference_words, reference_words_original)): + if orig_word.lower() != ref_word.lower(): + corrections.append( + WordCorrection( + original_word=orig_word, + corrected_word=ref_word_original, # Use original formatted word + segment_index=0, + word_index=gap.transcription_position + i, + confidence=0.8, + source=matching_source, + reason=f"SyllablesMatchHandler: Source '{matching_source}' had matching syllable count", + alternatives={}, + ) + ) + + return corrections diff --git a/lyrics_transcriber/correction/handlers/word_count_match.py b/lyrics_transcriber/correction/handlers/word_count_match.py index af7b38f..6f1fc61 100644 --- a/lyrics_transcriber/correction/handlers/word_count_match.py +++ b/lyrics_transcriber/correction/handlers/word_count_match.py @@ -26,16 +26,19 @@ def can_handle(self, gap: GapSequence) -> bool: def 
handle(self, gap: GapSequence) -> List[WordCorrection]: corrections = [] - reference_words = list(gap.reference_words.values())[0] + # Get both clean and original reference words from first source + source = list(gap.reference_words.keys())[0] + reference_words = gap.reference_words[source] + reference_words_original = gap.reference_words_original[source] sources = ", ".join(gap.reference_words.keys()) # Since we know all reference sources agree, we can correct all words in the gap - for i, (orig_word, ref_word) in enumerate(zip(gap.words, reference_words)): + for i, (orig_word, ref_word, ref_word_original) in enumerate(zip(gap.words, reference_words, reference_words_original)): if orig_word.lower() != ref_word.lower(): corrections.append( WordCorrection( original_word=orig_word, - corrected_word=ref_word, + corrected_word=ref_word_original, # Use the original formatted word segment_index=0, # This will be updated when applying corrections word_index=gap.transcription_position + i, confidence=1.0, diff --git a/lyrics_transcriber/types.py b/lyrics_transcriber/types.py index 93cfbda..62cd113 100644 --- a/lyrics_transcriber/types.py +++ b/lyrics_transcriber/types.py @@ -99,6 +99,9 @@ class WordCorrection: reason: str # e.g., "matched_in_3_sources", "high_confidence_match" alternatives: Dict[str, int] # Other possible corrections and their occurrence counts is_deletion: bool = False # New field to explicitly mark deletions + # New fields for handling word splits + split_index: Optional[int] = None # Position in the split sequence (0-based) + split_total: Optional[int] = None # Total number of words in split def to_dict(self) -> Dict[str, Any]: return asdict(self) @@ -261,6 +264,7 @@ class GapSequence: preceding_anchor: Optional[AnchorSequence] following_anchor: Optional[AnchorSequence] reference_words: Dict[str, List[str]] + reference_words_original: Dict[str, List[str]] # New field for formatted reference words corrections: List[WordCorrection] = field(default_factory=list) _corrected_positions: Set[int] = field(default_factory=set, repr=False) @@ -323,6 +327,7 @@ def to_dict(self) -> Dict[str, Any]: "preceding_anchor": self.preceding_anchor.to_dict() if self.preceding_anchor else None, "following_anchor": self.following_anchor.to_dict() if self.following_anchor else None, "reference_words": self.reference_words, + "reference_words_original": self.reference_words_original, "corrections": [c.to_dict() for c in self.corrections], } @@ -335,6 +340,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GapSequence": preceding_anchor=AnchorSequence.from_dict(data["preceding_anchor"]) if data["preceding_anchor"] else None, following_anchor=AnchorSequence.from_dict(data["following_anchor"]) if data["following_anchor"] else None, reference_words=data["reference_words"], + reference_words_original=data.get("reference_words_original", {}), ) # Add any corrections from the data if "corrections" in data: diff --git a/poetry.lock b/poetry.lock index 0097e85..fbcc889 100644 --- a/poetry.lock +++ b/poetry.lock @@ -294,6 +294,21 @@ azure = ["azure-storage-blob (>=12)", "azure-storage-file-datalake (>=12)"] gs = ["google-cloud-storage"] s3 = ["boto3 (>=1.34.0)"] +[[package]] +name = "cmudict" +version = "1.0.32" +description = "A versioned python wrapper package for The CMU Pronouncing Dictionary data files." 
+optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "cmudict-1.0.32-py3-none-any.whl", hash = "sha256:b9323664d49d128193c480ec97a3270ab2162469289bb26e950d13b2ef661c41"}, + {file = "cmudict-1.0.32.tar.gz", hash = "sha256:e84a587bb610b3a837a93f07494e874860cf205ea7f23db652b871093a699f38"}, +] + +[package.dependencies] +importlib-metadata = ">=5" +importlib-resources = ">=5" + [[package]] name = "colorama" version = "0.4.6" @@ -576,6 +591,47 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "importlib-metadata" +version = "6.11.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + +[[package]] +name = "importlib-resources" +version = "6.5.2" +description = "Read resources from Python packages" +optional = false +python-versions = ">=3.9" +files = [ + {file = "importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec"}, + {file = "importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] +type = ["pytest-mypy"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -1739,6 +1795,21 @@ files = [ {file = "pyperclip-1.9.0.tar.gz", hash = "sha256:b7de0142ddc81bfc5c7507eea19da920b92252b548b96186caf94a5e2527d310"}, ] +[[package]] +name = "pyphen" +version = "0.17.0" +description = "Pure Python module to hyphenate text" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyphen-0.17.0-py3-none-any.whl", hash = "sha256:dad0b2e4aa80f6d70bf06711b2da36c47a756b933c1d0c4cbbab40f643a5958c"}, + {file = "pyphen-0.17.0.tar.gz", hash = "sha256:1d13acd1ce37a384d7612954ae6c7801bb4c5316da0e2b937b2127ba702a3da4"}, +] + +[package.extras] +doc = ["sphinx", "sphinx_rtd_theme"] +test = ["pytest", "ruff"] + [[package]] name = "pytest" version = "8.3.4" @@ -2394,6 +2465,24 @@ files = [ {file = "spacy_loggers-1.0.5-py3-none-any.whl", hash = "sha256:196284c9c446cc0cdb944005384270d775fdeaf4f494d8e269466cfa497ef645"}, ] +[[package]] +name = "spacy-syllables" +version = "3.0.2" +description = "spacy pipeline component for syllables" +optional = false +python-versions = ">=3.7" +files = [ + {file = "spacy_syllables-3.0.2-py3-none-any.whl", hash = 
"sha256:0c67cfc086624c643f510bb05c53c93c323de4357761b500ce8d9e48942618ed"}, + {file = "spacy_syllables-3.0.2.tar.gz", hash = "sha256:1f45a8307382daa0c65d32a996d84bd5dd90552f42e675f721342c35ba3d032b"}, +] + +[package.dependencies] +pyphen = ">=0.10.0" +spacy = ">=3.0.3" + +[package.extras] +dev = ["black (>=23.1.0)", "pytest"] + [[package]] name = "spotipy" version = "2.24.0" @@ -2479,6 +2568,21 @@ files = [ {file = "striprtf-0.0.28.tar.gz", hash = "sha256:902806a2e0821faf412130450bdbb84f15e996a729061a51fe7286c620b6fee3"}, ] +[[package]] +name = "syllables" +version = "1.0.9" +description = "A Python package for estimating the number of syllables in a word." +optional = false +python-versions = ">=3.7.2,<4.0.0" +files = [ + {file = "syllables-1.0.9-py3-none-any.whl", hash = "sha256:341d1e5dd396589d385a8c462ea483081d344fa5652e02d29d1047a342c88d9b"}, + {file = "syllables-1.0.9.tar.gz", hash = "sha256:e73be37d7420bd94cae1ec5511dc6392e1305fe0837a89def9bcabad27f91f6f"}, +] + +[package.dependencies] +cmudict = ">=1.0.11,<2.0.0" +importlib-metadata = ">=5.1,<7.0" + [[package]] name = "sympy" version = "1.13.1" @@ -3000,7 +3104,26 @@ files = [ {file = "wrapt-1.17.0.tar.gz", hash = "sha256:16187aa2317c731170a88ef35e8937ae0f533c402872c1ee5e6d079fcf320801"}, ] +[[package]] +name = "zipp" +version = "3.21.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.9" +files = [ + {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, + {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +type = ["pytest-mypy"] + [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "9e01fa5163b169536171e64f849e1ac83162968e6c577aa286cc8e31fb34a10c" +content-hash = "bd9573a136378da98cf43013a212f6ec3debd1ccef04b8777030dfc99a3c1fff" diff --git a/pyproject.toml b/pyproject.toml index d702fb8..278fa59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ transformers = "^4.47.1" torch = "^2.5.1" metaphone = "^0.6" nltk = "^3.9.1" +spacy-syllables = "^3.0.2" +syllables = "^1.0.9" [tool.poetry.group.dev.dependencies] black = ">=23"