Skip to content

Commit

Permalink
JobDefinition env parameter to set pip index url and host
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianZaccaria committed Dec 7, 2023
1 parent 745772e commit 7fc2b7a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/codeflare_sdk/job/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from torchx.schedulers.ray_scheduler import RayScheduler
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo

from ..utils.generate_yaml import update_pip_requirements


if TYPE_CHECKING:
from ..cluster.cluster import Cluster
Expand Down Expand Up @@ -90,6 +92,12 @@ def __init__(
)
self.image = image
self.workspace = workspace
if "PIP_TRUSTED_HOST" in self.env or "PIP_INDEX_URL" in self.env:
update_pip_requirements(self)
else:
self.env.setdefault("PIP_TRUSTED_HOST", "pypi.org")
self.env.setdefault("PIP_INDEX_URL", "https://pypi.org/simple")
update_pip_requirements(self)

def _dry_run(self, cluster: "Cluster"):
j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus
Expand Down
36 changes: 36 additions & 0 deletions src/codeflare_sdk/utils/generate_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sys
import os
import argparse
from pathlib import Path
import uuid
from kubernetes import client, config
from .kube_api_helpers import _kube_api_error_handling
Expand Down Expand Up @@ -689,3 +690,38 @@ def generate_appwrapper(
else:
write_user_appwrapper(user_yaml, outfile)
return outfile


def update_pip_requirements(self):
pip_trusted_host = self.env.get("PIP_TRUSTED_HOST")
pip_index_url = self.env.get("PIP_INDEX_URL")
requirements_path = Path("requirements.txt")

if requirements_path.exists():
with requirements_path.open("r") as file:
requirements = file.readlines()

# Check and replace or add --trusted-host and --index-url
trusted_host = f"--trusted-host {pip_trusted_host}\n"
index_url = f"--index-url {pip_index_url}\n"
modified_requirements = []

for line in requirements:
if line.startswith("--trusted-host"):
modified_requirements.append(trusted_host)
trusted_host = None
elif line.startswith("--index-url"):
modified_requirements.append(index_url)
index_url = None
else:
modified_requirements.append(line)

# Append the lines if they were not replaced
if trusted_host:
modified_requirements.insert(0, trusted_host)
if index_url:
modified_requirements.insert(0, index_url)

# Write back the modified requirements
with requirements_path.open("w") as file:
file.writelines(modified_requirements)
6 changes: 5 additions & 1 deletion tests/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2091,7 +2091,11 @@ def test_DDPJobDefinition_creation():
assert ddp.memMB == 1024
assert ddp.h == None
assert ddp.j == "2x1"
assert ddp.env == {"test": "test"}
assert ddp.env == {
"PIP_TRUSTED_HOST": "pypi.org",
"PIP_INDEX_URL": "https://pypi.org/simple",
"test": "test",
}
assert ddp.max_retries == 0
assert ddp.mounts == []
assert ddp.rdzv_port == 29500
Expand Down

0 comments on commit 7fc2b7a

Please sign in to comment.