Skip to content

Commit

Permalink
Simplified interface, improved logging pass through to dependencies s…
Browse files Browse the repository at this point in the history
…o log level can be set in one place and have consistent downstream effects, added triton removal patch, added github workflow
  • Loading branch information
beveradb committed Jul 7, 2023
1 parent f5bb463 commit 943659f
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 168 deletions.
41 changes: 41 additions & 0 deletions .github/removetriton.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
1041d1040
< triton = "2.0.0"
1713,1739d1711
<
< [[package]]
< name = "triton"
< version = "2.0.0"
< description = "A language and compiler for custom Deep Learning operations"
< optional = false
< python-versions = "*"
< files = [
< {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"},
< {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"},
< {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"},
< {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"},
< {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"},
< {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"},
< {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"},
< {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"},
< {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"},
< {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"},
< {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"},
< {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"},
< {file = "triton-2.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fedce6a381901b1547e0e7e1f2546e4f65dca6d91e2d8a7305a2d1f5551895be"},
< {file = "triton-2.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75834f27926eab6c7f00ce73aaf1ab5bfb9bec6eb57ab7c0bfc0a23fac803b4c"},
< {file = "triton-2.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0117722f8c2b579cd429e0bee80f7731ae05f63fe8e9414acd9a679885fcbf42"},
< {file = "triton-2.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcd9be5d0c2e45d2b7e6ddc6da20112b6862d69741576f9c3dbaf941d745ecae"},
< {file = "triton-2.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42a0d2c3fc2eab4ba71384f2e785fbfd47aa41ae05fa58bf12cb31dcbd0aeceb"},
< {file = "triton-2.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52c47b72c72693198163ece9d90a721299e4fb3b8e24fd13141e384ad952724f"},
< ]
1741,1750d1712
< [package.dependencies]
< cmake = "*"
< filelock = "*"
< lit = "*"
< torch = "*"
<
< [package.extras]
< tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"]
< tutorials = ["matplotlib", "pandas", "tabulate"]
<
19 changes: 19 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Auto-publish

on: [push, workflow_dispatch]

jobs:
# Auto-publish when version is increased
publish-job:
# Only publish on `main` branch
if: github.ref == 'refs/heads/main'
runs-on: ubuntu-latest
permissions: # Don't forget permissions
contents: write

steps:
- uses: etils-actions/pypi-auto-publish@v1
with:
pypi-token: ${{ secrets.PYPI_API_TOKEN }}
gh-token: ${{ secrets.GITHUB_TOKEN }}
parse-changelog: false
75 changes: 43 additions & 32 deletions karaoke_generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,60 +16,67 @@
class KaraokeGenerator:
def __init__(
self,
youtube_url=None,
audio_file=None,
song_artist=None,
song_title=None,
log_level=logging.DEBUG,
log_formatter=None,
input_path=None,
artist=None,
title=None,
genius_api_token=None,
spotify_cookie=None,
model_name="UVR_MDXNET_KARA_2",
model_file_dir="/tmp/audio-separator-models",
cache_dir="/tmp/karaoke-generator-cache",
output_dir=None,
log_level=logging.DEBUG,
log_format="%(asctime)s - %(levelname)s - %(module)s - %(message)s",
):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(log_level)
self.log_level = log_level
self.log_formatter = log_formatter

self.log_handler = logging.StreamHandler()

if self.log_formatter is None:
self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")

log_handler = logging.StreamHandler()
log_formatter = logging.Formatter(log_format)
log_handler.setFormatter(log_formatter)
self.logger.addHandler(log_handler)
self.log_handler.setFormatter(self.log_formatter)
self.logger.addHandler(self.log_handler)

self.logger.debug("KaraokeGenerator initializing")

self.model_name = model_name
self.model_file_dir = model_file_dir
self.cache_dir = cache_dir
self.output_dir = output_dir
self.input_path = input_path
self.artist = artist
self.title = title

self.genius_api_token = os.getenv("GENIUS_API_TOKEN", default=genius_api_token)
self.spotify_cookie = os.getenv("SPOTIFY_COOKIE_SP_DC", default=spotify_cookie)

self.song_artist = song_artist
self.song_title = song_title
self.model_name = model_name
self.model_file_dir = model_file_dir
self.cache_dir = cache_dir
self.output_dir = output_dir

if audio_file is None and youtube_url is None:
raise Exception("Either audio_file or youtube_url must be specified as the input source")
if audio_file is not None and youtube_url is not None:
raise Exception("Only one of audio_file or youtube_url may be specified as the input source")
self.audio_file = None
self.youtube_url = None

if audio_file is not None:
self.input_source_slug = slugify.slugify(os.path.basename(audio_file), lowercase=False)
if youtube_url is not None:
parsed_url = urllib.parse.urlparse(youtube_url)
parsed_url = urllib.parse.urlparse(self.input_path)
if parsed_url.scheme and parsed_url.netloc:
self.youtube_url = self.input_path
self.input_source_slug = slugify.slugify(parsed_url.hostname + "-" + parsed_url.query, lowercase=False)
self.logger.debug(f"Input path was valid URL, set youtube_url and input_source_slug: {self.input_source_slug}")
elif os.path.exists(self.input_path):
self.audio_file = self.input_path
self.input_source_slug = slugify.slugify(os.path.basename(self.audio_file), lowercase=False)
self.logger.debug(f"Input path was valid file path, set audio_file and input_source_slug: {self.input_source_slug}")
else:
raise Exception("Input path must be either a valid file path or URL")

if self.output_dir is None:
self.output_dir = os.path.join(os.getcwd(), "karaoke-generator-output-" + self.input_source_slug)

self.output_filename_slug = None
self.youtube_url = youtube_url
self.youtube_video_file = None
self.youtube_video_image_file = None

self.audio_file = audio_file
self.primary_stem_path = None
self.secondary_stem_path = None

Expand All @@ -95,10 +102,12 @@ def transcribe_lyrics(self):
self.audio_file,
genius_api_token=self.genius_api_token,
spotify_cookie=self.spotify_cookie,
song_artist=self.song_artist,
song_title=self.song_title,
artist=self.artist,
title=self.title,
output_dir=self.output_dir,
cache_dir=self.cache_dir,
log_formatter=self.log_formatter,
log_level=self.log_level,
)

transcription_metadata = transcriber.generate()
Expand Down Expand Up @@ -141,6 +150,8 @@ def separate_audio(self):
model_name=self.model_name,
model_file_dir=self.model_file_dir,
output_dir=self.output_dir,
log_formatter=self.log_formatter,
log_level=self.log_level,
)
self.primary_stem_path, self.secondary_stem_path = separator.separate()

Expand Down Expand Up @@ -204,17 +215,17 @@ def download_youtube_video(self):
self.youtube_video_file = youtube_info["download_filepath"]
self.logger.debug(f"successfully downloaded youtube video to path: {self.youtube_video_file}")

if self.song_title is None:
if self.title is None:
self.logger.debug(f"Song title not specified, attempting to split from YouTube title: {youtube_info['title']}")
# Define the hyphen variations pattern
hyphen_pattern = regex.compile(r" [^[:ascii:]-_\p{Dash}] ")
# Split the string using the hyphen variations pattern
title_parts = hyphen_pattern.split(youtube_info["title"])

self.song_artist = title_parts[0]
self.song_title = title_parts[1]
self.artist = title_parts[0]
self.title = title_parts[1]

print(f"Guessed metadata from title: Artist: {self.song_artist}, Title: {self.song_title}")
print(f"Guessed metadata from title: Artist: {self.artist}, Title: {self.title}")

# Extract audio to WAV file using ffmpeg
self.audio_file = os.path.join(self.cache_dir, self.output_filename_slug + ".wav")
Expand Down
61 changes: 33 additions & 28 deletions karaoke_generator/utils/cli.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,50 @@
#!/usr/bin/env python
import argparse
import logging
import pkg_resources
from karaoke_generator import KaraokeGenerator


def main():
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

log_handler = logging.StreamHandler()
log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
log_formatter = logging.Formatter(fmt="%(asctime)s.%(msecs)03d - %(levelname)s - %(module)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
log_handler.setFormatter(log_formatter)
logger.addHandler(log_handler)

logger.debug("Parsing CLI args")

parser = argparse.ArgumentParser(description="Generate karaoke music video for either a local audio file or YouTube URL")

input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument(
"--youtube_url",
default=None,
help="Optional: YouTube URL to make karaoke version of.",
parser = argparse.ArgumentParser(
description="Generate karaoke music video for either a local audio file or YouTube URL",
formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=40),
)
input_group.add_argument(
"--audio_file",
default=None,
help="Optional: audio file path to make karaoke version of.",

parser.add_argument(
"input_path", nargs="?", help="The audio file path or YouTube URL to make karaoke version of.", default=argparse.SUPPRESS
)

package_version = pkg_resources.get_distribution("karaoke-video-generator").version
parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {package_version}")
parser.add_argument("--log_level", default="INFO", help="Optional: Logging level, e.g. info, debug, warning. Default: INFO")

parser.add_argument(
"--song_artist",
"--artist",
default=None,
help="Optional: specify song artist for lyrics lookup and auto-correction",
help="Optional: song artist for lyrics lookup and auto-correction",
)
parser.add_argument(
"--song_title",
"--title",
default=None,
help="Optional: specify song title for lyrics lookup and auto-correction",
help="Optional: song title for lyrics lookup and auto-correction",
)

parser.add_argument(
"--genius_api_token",
default=None,
help="Optional: specify Genius API token for lyrics lookup and auto-correction",
help="Optional: Genius API token for lyrics fetching. Can also be set with GENIUS_API_TOKEN env var.",
)
parser.add_argument(
"--spotify_cookie",
default=None,
help="Optional: specify Spotify SP_DC cookie value for lyrics lookup and auto-correction",
help="Optional: Spotify sp_dc cookie value for lyrics fetching. Can also be set with SPOTIFY_COOKIE_SP_DC env var.",
)

parser.add_argument(
Expand All @@ -74,13 +70,21 @@ def main():

args = parser.parse_args()

logger.info(f"Karaoke generator beginning with audio_file: {args.audio_file} / youtube_url: {args.youtube_url}")
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

if not hasattr(args, "input_path"):
parser.print_help()
exit(1)

logger.info(f"Karaoke generator beginning with input_path: {args.input_path}")

generator = KaraokeGenerator(
audio_file=args.audio_file,
youtube_url=args.youtube_url,
song_artist=args.song_artist,
song_title=args.song_title,
log_formatter=log_formatter,
log_level=log_level,
input_path=args.input_path,
artist=args.artist,
title=args.title,
genius_api_token=args.genius_api_token,
spotify_cookie=args.spotify_cookie,
model_name=args.model_name,
Expand All @@ -93,6 +97,7 @@ def main():
logger.info(f"Karaoke generation complete!")

logger.debug(f"Output folder: {outputs['output_dir']}")



if __name__ == "__main__":
main()
Loading

0 comments on commit 943659f

Please sign in to comment.