Improved karaoke video generation with line splitting, max 4 lines per screen and centered output. Added initial stub for audioshake transcription
beveradb committed Jul 13, 2024
1 parent a15b68d commit f47c593
Showing 5 changed files with 340 additions and 168 deletions.
35 changes: 35 additions & 0 deletions lyrics_transcriber/audioshake_transcriber.py
@@ -0,0 +1,35 @@
import logging
import requests


class AudioShakeTranscriber:
def __init__(self, api_token, log_level=logging.DEBUG):
self.api_token = api_token
self.logger = logging.getLogger(__name__)
self.logger.setLevel(log_level)

def transcribe(self, audio_filepath):
# This is a placeholder for the actual AudioShake API implementation
self.logger.info(f"Transcribing {audio_filepath} using AudioShake API")

self.logger.debug(f"AudioShake API token: {self.api_token}")
# TODO: Implement the actual API call to AudioShake
# For now, we'll return a dummy result
return {
"transcription_data_dict": {
"segments": [
{
"start": 0,
"end": 5,
"text": "This is a dummy transcription",
"words": [
{"text": "This", "start": 0, "end": 1},
{"text": "is", "start": 1, "end": 2},
{"text": "a", "start": 2, "end": 3},
{"text": "dummy", "start": 3, "end": 4},
{"text": "transcription", "start": 4, "end": 5},
],
}
]
}
}
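For reference, a minimal usage sketch of the new stub as committed (the token and file path here are hypothetical; the class currently just logs and returns the dummy payload above rather than calling the real API):

import logging

from lyrics_transcriber.audioshake_transcriber import AudioShakeTranscriber

# Hypothetical token and path, for illustration only
audioshake = AudioShakeTranscriber("as-token-example", log_level=logging.INFO)
result = audioshake.transcribe("/path/to/song.wav")
print(result["transcription_data_dict"]["segments"][0]["text"])
# -> "This is a dummy transcription"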
171 changes: 125 additions & 46 deletions lyrics_transcriber/transcriber.py
@@ -22,6 +22,7 @@ def __init__(
audio_filepath,
artist=None,
title=None,
audioshake_api_token=None,
genius_api_token=None,
spotify_cookie=None,
output_dir=None,
@@ -61,6 +62,7 @@ def __init__(

self.genius_api_token = os.getenv("GENIUS_API_TOKEN", default=genius_api_token)
self.spotify_cookie = os.getenv("SPOTIFY_COOKIE_SP_DC", default=spotify_cookie)
self.audioshake_api_token = os.getenv("AUDIOSHAKE_TOKEN", default=audioshake_api_token)

self.transcription_model = transcription_model
self.llm_model = llm_model
@@ -583,51 +585,57 @@ def write_midico_lrc_file(self):
f.write(line)

def create_screens(self):
self.logger.debug(f"create_screens beginning generation of screens from whisper results")
self.logger.debug("create_screens beginning generation of screens from whisper results")
screens: List[subtitles.LyricsScreen] = []
line: Optional[subtitles.LyricsLine] = None
screen: Optional[subtitles.LyricsScreen] = None

lines_in_current_screen = 0
max_lines_per_screen = 4
max_line_length = 36 # Maximum characters per line
self.logger.debug(f"Max lines per screen: {max_lines_per_screen}, Max line length: {max_line_length}")

for segment in self.outputs["corrected_lyrics_data_dict"]["segments"]:
self.logger.debug(f"lines_in_current_screen: {lines_in_current_screen} segment: {segment['text']}")
if screen is None:
self.logger.debug(f"screen is none, creating new LyricsScreen")
screen = subtitles.LyricsScreen()
screen.video_size = self.video_resolution_num
screen.line_height = self.line_height
if line is None:
self.logger.debug(f"line is none, creating new LyricsLine")
line = subtitles.LyricsLine()

num_words_in_segment = len(segment["words"])
for word_index, word in enumerate(segment["words"]):
segment = subtitles.LyricSegment(
self.logger.debug(f"Processing segment: {segment['text']}")
if screen is None or len(screen.lines) >= max_lines_per_screen:
screen = subtitles.LyricsScreen(video_size=self.video_resolution_num, line_height=self.line_height, logger=self.logger)
screens.append(screen)
self.logger.debug(f"Created new screen. Total screens: {len(screens)}")

words = segment["words"]
current_line = subtitles.LyricsLine()
current_line_text = ""
self.logger.debug(f"Processing {len(words)} words in segment")

for word in words:
self.logger.debug(f"Processing word: '{word['text']}'")
if len(current_line_text) + len(word["text"]) + 1 > max_line_length:
self.logger.debug(f"Current line would exceed max length. Line: '{current_line_text}'")
if current_line.segments:
screen.lines.append(current_line)
self.logger.debug(f"Added line to screen. Lines on current screen: {len(screen.lines)}")
if len(screen.lines) >= max_lines_per_screen:
screen = subtitles.LyricsScreen(
video_size=self.video_resolution_num,
line_height=self.line_height,
logger=self.logger, # Pass the logger here
)
screens.append(screen)
self.logger.debug(f"Screen full, created new screen. Total screens: {len(screens)}")
current_line = subtitles.LyricsLine()
current_line_text = ""
self.logger.debug("Reset current line")

current_line_text += (" " if current_line_text else "") + word["text"]
lyric_segment = subtitles.LyricSegment(
text=word["text"], ts=timedelta(seconds=word["start"]), end_ts=timedelta(seconds=word["end"])
)
line.segments.append(segment)

# If word is last in the line, add line to screen and start new line
# Before looping to the next word
if word_index == num_words_in_segment - 1:
self.logger.debug(f"word_index is last in segment, adding line to screen and starting new line")
screen.lines.append(line)
lines_in_current_screen += 1
line = None

# If current screen has 2 lines already, add screen to list and start new screen
# Before looping to the next line
if lines_in_current_screen == 2:
self.logger.debug(f"lines_in_current_screen is 2, adding screen to list and starting new screen")
screens.append(screen)
screen = None
lines_in_current_screen = 0
current_line.segments.append(lyric_segment)
self.logger.debug(f"Added word to current line. Current line: '{current_line_text}'")

if line is not None:
screen.lines.append(line) # type: ignore[union-attr]
if screen is not None and len(screen.lines) > 0:
screens.append(screen) # type: ignore[arg-type]
if current_line.segments:
screen.lines.append(current_line)
self.logger.debug(f"Added final line of segment to screen. Lines on current screen: {len(screen.lines)}")

self.logger.debug(f"Finished creating screens. Total screens created: {len(screens)}")
return screens
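The wrapping policy above is greedy: words accumulate on a line until adding the next word (plus its joining space) would exceed 36 characters, and a screen is flushed once it holds four lines. A simplified standalone sketch of that policy on plain strings (the function name is illustrative, and segment boundaries, on which the real method also flushes lines, are ignored):

def wrap_words(words, max_line_length=36, max_lines_per_screen=4):
    """Greedy word-wrap mirroring create_screens, without timing data."""
    screens, lines, line = [], [], ""
    for word in words:
        # The +1 accounts for the joining space, as in create_screens
        if line and len(line) + len(word) + 1 > max_line_length:
            lines.append(line)
            line = ""
            if len(lines) == max_lines_per_screen:
                screens.append(lines)
                lines = []
        line = (line + " " + word) if line else word
    if line:
        lines.append(line)
    if lines:
        screens.append(lines)
    return screens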

def write_ass_file(self):
@@ -760,7 +768,10 @@ def format_time_lrc(self, duration):

def write_transcribed_lyrics_plain_text(self):
if self.outputs["transcription_data_dict"]:
transcribed_lyrics_text_filepath = os.path.join(self.cache_dir, "lyrics-" + self.get_song_slug() + "-transcribed.txt")
transcription_cache_suffix = "-audioshake-transcribed.txt" if self.audioshake_api_token else "-whisper-transcribed.txt"
self.logger.debug(f"transcription_cache_suffix: {transcription_cache_suffix}")

transcribed_lyrics_text_filepath = os.path.join(self.cache_dir, "lyrics-" + self.get_song_slug() + transcription_cache_suffix)
self.outputs["transcribed_lyrics_text_filepath"] = transcribed_lyrics_text_filepath

self.outputs["transcribed_lyrics_text"] = ""
@@ -773,8 +784,68 @@ def write_transcribed_lyrics_plain_text(self):
else:
raise Exception("Cannot write transcribed lyrics plain text as transcription_data_dict is not set")

def split_long_segments(self, segments, max_length, use_space=True):
new_segments = []
for segment in segments:
text = segment["text"]
if len(text) <= max_length:
new_segments.append(segment)
else:
meta_words = segment["words"]
# Note: we do this in case punctuation was removed from the words
if use_space:
# Split text around spaces and punctuation (keeping punctuation)
words = text.split()
else:
words = [w["text"] for w in meta_words]
if len(words) != len(meta_words):
new_words = [w["text"] for w in meta_words]
print(f"WARNING: {' '.join(words)} != {' '.join(new_words)}")
words = new_words
current_text = ""
current_start = segment["start"]
current_best_idx = None
current_best_end = None
current_best_next_start = None
for i, (word, meta) in enumerate(zip(words, meta_words)):
current_text_before = current_text
if current_text and use_space:
current_text += " "
current_text += word

if len(current_text) > max_length and len(current_text_before):
start = current_start
if current_best_idx is not None:
text = current_text[:current_best_idx]
end = current_best_end
current_text = current_text[current_best_idx + 1 :]
current_start = current_best_next_start
else:
text = current_text_before
end = meta_words[i - 1]["end"]
current_text = word
current_start = meta["start"]

current_best_idx = None
current_best_end = None
current_best_next_start = None

new_segments.append({"text": text, "start": start, "end": end})

# Try to cut after punctuation
if current_text and current_text[-1] in _punctuation:
current_best_idx = len(current_text)
current_best_end = meta["end"]
current_best_next_start = meta_words[i + 1]["start"] if i + 1 < len(meta_words) else None

if len(current_text):
new_segments.append({"text": current_text, "start": current_start, "end": segment["end"]})

return new_segments
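To make the splitting behaviour concrete, here is a worked example of what the method produces (timings are illustrative, and _punctuation is assumed to be a string of punctuation characters defined elsewhere in the module; note the returned segments carry only text/start/end keys, not the original words list):

# Assume `transcriber` is a LyricsTranscriber instance
segment = {
    "text": "hello world this is a very long line indeed",
    "start": 0.0,
    "end": 4.5,
    "words": [
        {"text": "hello", "start": 0.0, "end": 0.5},
        {"text": "world", "start": 0.5, "end": 1.0},
        {"text": "this", "start": 1.0, "end": 1.5},
        {"text": "is", "start": 1.5, "end": 2.0},
        {"text": "a", "start": 2.0, "end": 2.5},
        {"text": "very", "start": 2.5, "end": 3.0},
        {"text": "long", "start": 3.0, "end": 3.5},
        {"text": "line", "start": 3.5, "end": 4.0},
        {"text": "indeed", "start": 4.0, "end": 4.5},
    ],
}
result = transcriber.split_long_segments([segment], max_length=20)
# The 44-character text is cut at word boundaries into pieces of at
# most 20 characters:
#   {"text": "hello world this is", "start": 0.0, "end": 2.0}
#   {"text": "a very long line", "start": 2.0, "end": 4.0}
#   {"text": "indeed", "start": 4.0, "end": 4.5}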

def transcribe(self):
self.outputs["transcription_data_filepath"] = self.get_cache_filepath(".json")
transcription_cache_suffix = "-audioshake" if self.audioshake_api_token else "-whisper"
self.outputs["transcription_data_filepath"] = self.get_cache_filepath(f"{transcription_cache_suffix}.json")

whisper_cache_filepath = self.outputs["transcription_data_filepath"]
if os.path.isfile(whisper_cache_filepath):
@@ -783,15 +854,23 @@ def transcribe(self):
self.outputs["transcription_data_dict"] = json.load(cache_file)
return

self.logger.debug(f"no cached transcription file found, running whisper transcribe with model: {self.transcription_model}")
audio = whisper.load_audio(self.audio_filepath)
model = whisper.load_model(self.transcription_model, device="cpu")
result = whisper.transcribe(model, audio, language="en", vad="auditok", beam_size=5, temperature=0.2, best_of=5)
if self.audioshake_api_token:
self.logger.debug(f"Using AudioShake API for transcription")
from .audioshake_transcriber import AudioShakeTranscriber

audioshake = AudioShakeTranscriber(self.audioshake_api_token, log_level=self.log_level)
result = audioshake.transcribe(self.audio_filepath)
else:
self.logger.debug(f"Using Whisper for transcription with model: {self.transcription_model}")
audio = whisper.load_audio(self.audio_filepath)
model = whisper.load_model(self.transcription_model, device="cpu")
result = whisper.transcribe(model, audio, language="en", vad="auditok", beam_size=5, temperature=0.2, best_of=5)

self.logger.debug(f"transcription complete, performing post-processing cleanup")
# Remove segments with no words, only music
result["segments"] = [segment for segment in result["segments"] if segment["text"].strip() != "Music"]

# Remove segments with no words, only music
result["segments"] = [segment for segment in result["segments"] if segment["text"].strip() != "Music"]
# Split long segments
result["segments"] = self.split_long_segments(result["segments"], max_length=36)

self.logger.debug(f"writing transcription data JSON to cache file: {whisper_cache_filepath}")
with open(whisper_cache_filepath, "w") as cache_file:
8 changes: 7 additions & 1 deletion lyrics_transcriber/utils/cli.py
@@ -34,6 +34,11 @@ def main():
default=None,
help="Optional: song title for lyrics lookup and auto-correction",
)
parser.add_argument(
"--audioshake_api_token",
default=None,
help="Optional: AudioShake API token for lyrics transcription and alignment. Can also be set with AUDIOSHAKE_API_TOKEN env var.",
)
parser.add_argument(
"--genius_api_token",
default=None,
@@ -77,7 +82,7 @@ def main():

parser.add_argument(
"--video_resolution",
default="4k",
default="360p",
help="Optional: resolution of the karaoke video to render. Must be one of: 4k, 1080p, 720p, 360p. Default: 360p",
)

@@ -114,6 +119,7 @@

transcriber = LyricsTranscriber(
args.audio_filepath,
audioshake_api_token=args.audioshake_api_token,
genius_api_token=args.genius_api_token,
spotify_cookie=args.spotify_cookie,
artist=args.artist,
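The same token can also be passed programmatically rather than via the new --audioshake_api_token flag; a minimal sketch with hypothetical values (the import path and filename are assumptions):

from lyrics_transcriber.transcriber import LyricsTranscriber  # import path assumed

transcriber = LyricsTranscriber(
    "song.flac",  # audio_filepath, illustrative
    artist="Some Artist",
    title="Some Song",
    audioshake_api_token="as-token-example",  # falls back to the AUDIOSHAKE_TOKEN env var
)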
46 changes: 35 additions & 11 deletions lyrics_transcriber/utils/subtitles.py
@@ -5,6 +5,7 @@
import itertools
from pathlib import Path
from enum import IntEnum
import logging

from . import ass

@@ -85,21 +86,19 @@ def end_ts(self, value):
def __str__(self):
return "".join([f"{{{s.text}}}" for s in self.segments])

def as_ass_event(
self,
screen_start: timedelta,
screen_end: timedelta,
style: ass.ASS.Style,
top_margin: int,
):
def as_ass_event(self, screen_start: timedelta, screen_end: timedelta, style: ass.ASS.Style, y_position: int):
e = ass.ASS.Event()
e.type = "Dialogue"
e.Layer = 0
e.Style = style
e.Start = screen_start.total_seconds()
e.End = screen_end.total_seconds()
e.MarginV = top_margin
e.MarginV = y_position
e.Text = self.decorate_ass_line(self.segments, screen_start)

# Set alignment to top-center
e.Text = "{\\an8}" + e.Text

return e

def decorate_ass_line(self, segments, screen_start_ts: timedelta):
@@ -137,6 +136,7 @@ class LyricsScreen:
start_ts: Optional[timedelta] = None
video_size: Tuple[int, int] = None
line_height: int = None
logger: logging.Logger = None

@property
def end_ts(self) -> timedelta:
@@ -145,10 +145,34 @@ def end_ts(self) -> timedelta:
def get_line_y(self, line_num: int) -> int:
_, h = self.video_size
line_count = len(self.lines)
return (h / 2) - (line_count * self.line_height / 2) + (line_num * self.line_height)
total_height = line_count * self.line_height

# Calculate the top margin to center the lyrics block
top_margin = (h - total_height) / 2

# Calculate the y-position for this specific line
line_y = top_margin + (line_num * self.line_height)

if self.logger:
self.logger.debug(f"Line {line_num + 1} positioning:")
self.logger.debug(f" Video height: {h}")
self.logger.debug(f" Total lines: {line_count}")
self.logger.debug(f" Line height: {self.line_height}")
self.logger.debug(f" Total lyrics height: {total_height}")
self.logger.debug(f" Top margin: {top_margin}")
self.logger.debug(f" Line y: {line_y}")

return int(line_y)
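As a sanity check of the centering math, a standalone restatement of the formula with hypothetical values (a 360px-tall video, line_height of 50, and a full four-line screen):

def line_y(h, line_count, line_height, line_num):
    # top_margin centers the lyrics block: (360 - 4 * 50) / 2 = 80
    top_margin = (h - line_count * line_height) / 2
    return int(top_margin + line_num * line_height)

assert [line_y(360, 4, 50, i) for i in range(4)] == [80, 130, 180, 230]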

def as_ass_events(self, style: ass.ASS.Style) -> List[ass.ASS.Event]:
return [line.as_ass_event(self.start_ts, self.end_ts, style, self.get_line_y(i)) for i, line in enumerate(self.lines)]
events = []
for i, line in enumerate(self.lines):
y_position = self.get_line_y(i)
if self.logger:
self.logger.debug(f"Creating ASS event for line {i + 1} at y-position: {y_position}")
event = line.as_ass_event(self.start_ts, self.end_ts, style, y_position)
events.append(event)
return events

def __str__(self):
lines = [f"{self.start_ts} - {self.end_ts}:"]
@@ -264,7 +288,7 @@ def create_styled_subtitles(
style.BorderStyle = 1
style.Outline = 1
style.Shadow = 0
style.Alignment = ass.ASS.ALIGN_MIDDLE_CENTER
style.Alignment = ass.ASS.ALIGN_TOP_CENTER
style.MarginL = 0
style.MarginR = 0
style.MarginV = 0