diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index ba79ad033c8..adc922635b2 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -4,11 +4,13 @@ import re import shlex import sys +from typing import Tuple, List from collections import defaultdict from hashlib import sha1 from math import inf from metaflow import JSONType, current +from metaflow.graph import DAGNode from metaflow.decorators import flow_decorators from metaflow.exception import MetaflowException from metaflow.includefile import FilePathClass @@ -47,6 +49,7 @@ SERVICE_INTERNAL_URL, UI_URL, ) +from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK from metaflow.metaflow_config_funcs import config_values from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars from metaflow.parameters import deploy_time_eval @@ -54,6 +57,7 @@ parse_kube_keyvalue_list, validate_kube_labels, ) +from metaflow.graph import FlowGraph from metaflow.util import ( compress_list, dict_to_cli_options, @@ -61,6 +65,9 @@ to_camelcase, to_unicode, ) +from metaflow.plugins.kubernetes.kubernetes_jobsets import ( + KubernetesArgoJobSet, +) from .argo_client import ArgoClient @@ -82,14 +89,14 @@ class ArgoWorkflowsSchedulingException(MetaflowException): # 5. Add Metaflow tags to labels/annotations. # 6. Support Multi-cluster scheduling - https://github.com/argoproj/argo-workflows/issues/3523#issuecomment-792307297 # 7. Support R lang. -# 8. Ping @savin at slack.outerbounds.co for any feature request. +# 8. Ping @savin at slack.outerbounds.co for any feature request class ArgoWorkflows(object): def __init__( self, name, - graph, + graph: FlowGraph, flow, code_package_sha, code_package_url, @@ -852,13 +859,13 @@ def _compile_workflow_template(self): # Visit every node and yield the uber DAGTemplate(s). def _dag_templates(self): def _visit( - node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None - ): - if node.parallel_foreach: - raise ArgoWorkflowsException( - "Deploying flows with @parallel decorator(s) " - "as Argo Workflows is not supported currently." - ) + node, + exit_node=None, + templates=None, + dag_tasks=None, + parent_foreach=None, + ): # Returns Tuple[List[Template], List[DAGTask]] + """ """ # Every for-each node results in a separate subDAG and an equivalent # DAGTemplate rooted at the child of the for-each node. Each DAGTemplate # has a unique name - the top-level DAGTemplate is named as the name of @@ -872,7 +879,6 @@ def _visit( templates = [] if exit_node is not None and exit_node is node.name: return templates, dag_tasks - if node.name == "start": # Start node has no dependencies. dag_task = DAGTask(self._sanitize(node.name)).template( @@ -881,13 +887,86 @@ def _visit( elif ( node.is_inside_foreach and self.graph[node.in_funcs[0]].type == "foreach" + and not self.graph[node.in_funcs[0]].parallel_foreach + # We need to distinguish what is a "regular" foreach (i.e something that doesn't care about to gang semantics) + # vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.) + # A `regular` foreach is basically any arbitrary kind of foreach. ): # Child of a foreach node needs input-paths as well as split-index # This child is the first node of the sub workflow and has no dependency + parameters = [ Parameter("input-paths").value("{{inputs.parameters.input-paths}}"), Parameter("split-index").value("{{inputs.parameters.split-index}}"), ] + dag_task = ( + DAGTask(self._sanitize(node.name)) + .template(self._sanitize(node.name)) + .arguments(Arguments().parameters(parameters)) + ) + elif node.parallel_step: + # This is the step where the @parallel decorator is defined. + # Since this DAGTask will call the for the `resource` [based templates] + # (https://argo-workflows.readthedocs.io/en/stable/walk-through/kubernetes-resources/) + # we have certain constraints on the way we can pass information inside the Jobset manifest + # [All templates will have access](https://argo-workflows.readthedocs.io/en/stable/variables/#all-templates) + # to the `inputs.parameters` so we will pass down ANY/ALL information using the + # input parameters. + # We define the usual parameters like input-paths/split-index etc. but we will also + # define the following: + # - `workerCount`: parameter which will be used to determine the number of + # parallel worker jobs + # - `jobset-name`: parameter which will be used to determine the name of the jobset. + # This parameter needs to be dynamic so that when we have retries we don't + # end up using the name of the jobset again (if we do, it will crash since k8s wont allow duplicated job names) + # - `retryCount`: parameter which will be used to determine the number of retries + # This parameter will *only* be available within the container templates like we + # have it for all other DAGTasks and NOT for custom kubernetes resource templates. + # So as a work-around, we will set it as the `retryCount` parameter instead of + # setting it as a {{ retries }} in the CLI code. Once set as a input parameter, + # we can use it in the Jobset Manifest templates as `{{inputs.parameters.retryCount}}` + # - `task-id-entropy`: This is a parameter which will help derive task-ids and jobset names. This parameter + # contains the relevant amount of entropy to ensure that task-ids and jobset names + # are uniquish. We will also use this in the join task to construct the task-ids of + # all parallel tasks since the task-ids for parallel task are minted formulaically. + parameters = [ + Parameter("input-paths").value("{{inputs.parameters.input-paths}}"), + Parameter("num-parallel").value( + "{{inputs.parameters.num-parallel}}" + ), + Parameter("split-index").value("{{inputs.parameters.split-index}}"), + Parameter("task-id-entropy").value( + "{{inputs.parameters.task-id-entropy}}" + ), + # we cant just use hyphens with sprig. + # https://github.com/argoproj/argo-workflows/issues/10567#issuecomment-1452410948 + Parameter("workerCount").value( + "{{=sprig.int(sprig.sub(sprig.int(inputs.parameters['num-parallel']),1))}}" + ), + ] + if any(d.name == "retry" for d in node.decorators): + parameters.extend( + [ + Parameter("retryCount").value("{{retries}}"), + # The job-setname needs to be unique for each retry + # and we cannot use the `generateName` field in the + # Jobset Manifest since we need to construct the subdomain + # and control pod domain name pre-hand. So we will use + # the retry count to ensure that the jobset name is unique + Parameter("jobset-name").value( + "js-{{inputs.parameters.task-id-entropy}}{{retries}}", + ), + ] + ) + else: + parameters.extend( + [ + Parameter("jobset-name").value( + "js-{{inputs.parameters.task-id-entropy}}", + ) + ] + ) + dag_task = ( DAGTask(self._sanitize(node.name)) .template(self._sanitize(node.name)) @@ -947,8 +1026,8 @@ def _visit( .template(self._sanitize(node.name)) .arguments(Arguments().parameters(parameters)) ) - dag_tasks.append(dag_task) + dag_tasks.append(dag_task) # End the workflow if we have reached the end of the flow if node.type == "end": return [ @@ -974,14 +1053,30 @@ def _visit( parent_foreach, ) # For foreach nodes generate a new sub DAGTemplate + # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`) elif node.type == "foreach": foreach_template_name = self._sanitize( "%s-foreach-%s" % ( node.name, - node.foreach_param, + "parallel" if node.parallel_foreach else node.foreach_param + # Since foreach's are derived based on `self.next(self.a, foreach="")` + # vs @parallel foreach are done based on `self.next(self.a, num_parallel="")`, + # we need to ensure that `foreach_template_name` suffix is appropriately set based on the kind + # of foreach. ) ) + + # There are two separate "DAGTask"s created for the foreach node. + # - The first one is a "jump-off" DAGTask where we propagate the + # input-paths and split-index. This thing doesn't create + # any actual containers and it responsible for only propagating + # the parameters. + # - The DAGTask that follows first DAGTask is the one + # that uses the ContainerTemplate. This DAGTask is named the same + # thing as the foreach node. We will leverage a similar pattern for the + # @parallel tasks. + # foreach_task = ( DAGTask(foreach_template_name) .dependencies([self._sanitize(node.name)]) @@ -1005,9 +1100,26 @@ def _visit( if parent_foreach else [] ) + + ( + # Disabiguate parameters for a regular `foreach` vs a `@parallel` foreach + [ + Parameter("num-parallel").value( + "{{tasks.%s.outputs.parameters.num-parallel}}" + % self._sanitize(node.name) + ), + Parameter("task-id-entropy").value( + "{{tasks.%s.outputs.parameters.task-id-entropy}}" + % self._sanitize(node.name) + ), + ] + if node.parallel_foreach + else [] + ) ) ) .with_param( + # For @parallel workloads `num-splits` will be explicitly set to one so that + # we can piggyback on the current mechanism with which we leverage argo. "{{tasks.%s.outputs.parameters.num-splits}}" % self._sanitize(node.name) ) @@ -1020,17 +1132,34 @@ def _visit( [], node.name, ) + + # How do foreach's work on Argo: + # Lets say you have the following dag: (start[sets `foreach="x"`]) --> (task-a [actual foreach]) --> (join) --> (end) + # With argo we will : + # (start [sets num-splits]) --> (task-a-foreach-(0,0) [dummy task]) --> (task-a) --> (join) --> (end) + # The (task-a-foreach-(0,0) [dummy task]) propagates the values of the `split-index` and the input paths. + # to the actual foreach task. templates.append( Template(foreach_template_name) .inputs( Inputs().parameters( [Parameter("input-paths"), Parameter("split-index")] + ([Parameter("root-input-path")] if parent_foreach else []) + + ( + [ + Parameter("num-parallel"), + Parameter("task-id-entropy"), + # Parameter("workerCount") + ] + if node.parallel_foreach + else [] + ) ) ) .outputs( Outputs().parameters( [ + # non @parallel tasks set task-ids as outputs Parameter("task-id").valueFrom( { "parameter": "{{tasks.%s.outputs.parameters.task-id}}" @@ -1040,29 +1169,67 @@ def _visit( } ) ] + if not node.parallel_foreach + else [ + # @parallel tasks set `task-id-entropy` and `num-parallel` + # as outputs so task-ids can be derived in the join step. + # Both of these values should be propagated from the + # jobset labels. + Parameter("num-parallel").valueFrom( + { + "parameter": "{{tasks.%s.outputs.parameters.num-parallel}}" + % self._sanitize( + self.graph[node.matching_join].in_funcs[0] + ) + } + ), + Parameter("task-id-entropy").valueFrom( + { + "parameter": "{{tasks.%s.outputs.parameters.task-id-entropy}}" + % self._sanitize( + self.graph[node.matching_join].in_funcs[0] + ) + } + ), + ] ) ) .dag(DAGTemplate().fail_fast().tasks(dag_tasks_1)) ) + join_foreach_task = ( DAGTask(self._sanitize(self.graph[node.matching_join].name)) .template(self._sanitize(self.graph[node.matching_join].name)) .dependencies([foreach_template_name]) .arguments( Arguments().parameters( - [ - Parameter("input-paths").value( - "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}" - % (node.name, self._sanitize(node.name)) - ), - Parameter("split-cardinality").value( - "{{tasks.%s.outputs.parameters.split-cardinality}}" - % self._sanitize(node.name) - ), - ] + ( + [ + Parameter("input-paths").value( + "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}" + % (node.name, self._sanitize(node.name)) + ), + Parameter("split-cardinality").value( + "{{tasks.%s.outputs.parameters.split-cardinality}}" + % self._sanitize(node.name) + ), + ] + if not node.parallel_foreach + else [ + Parameter("num-parallel").value( + "{{tasks.%s.outputs.parameters.num-parallel}}" + % self._sanitize(node.name) + ), + Parameter("task-id-entropy").value( + "{{tasks.%s.outputs.parameters.task-id-entropy}}" + % self._sanitize(node.name) + ), + ] + ) + ( [ Parameter("split-index").value( + # TODO : Pass down these parameters to the jobset stuff. "{{inputs.parameters.split-index}}" ), Parameter("root-input-path").value( @@ -1140,7 +1307,17 @@ def _container_templates(self): # export input_paths as it is used multiple times in the container script # and we do not want to repeat the values. input_paths_expr = "export INPUT_PATHS=''" - if node.name != "start": + # If node is not a start step or a @parallel join then we will set the input paths. + # To set the input-paths as a parameter, we need to ensure that the node + # is not (a start node or a parallel join node). Start nodes will have no + # input paths and parallel join will derive input paths based on a + # formulaic approach using `num-parallel` and `task-id-entropy`. + if not ( + node.name == "start" + or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step) + ): + # For parallel joins we don't pass the INPUT_PATHS but are dynamically constructed. + # So we don't need to set the input paths. input_paths_expr = ( "export INPUT_PATHS={{inputs.parameters.input-paths}}" ) @@ -1169,13 +1346,23 @@ def _container_templates(self): task_idx, ] ) + if node.parallel_step: + task_str = "-".join( + [ + "$TASK_ID_PREFIX", + "{{inputs.parameters.task-id-entropy}}", # id_base is addition entropy to based on node-name of the workflow + "$TASK_ID_SUFFIX", + ] + ) + else: + # Generated task_ids need to be non-numeric - see register_task_id in + # service.py. We do so by prefixing `t-` + _task_id_base = ( + "$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9)" % task_str + ) + task_str = "(t-%s)" % _task_id_base - # Generated task_ids need to be non-numeric - see register_task_id in - # service.py. We do so by prefixing `t-` - task_id_expr = ( - "export METAFLOW_TASK_ID=" - "(t-$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9))" % task_str - ) + task_id_expr = "export METAFLOW_TASK_ID=" "%s" % task_str task_id = "$METAFLOW_TASK_ID" # Resolve retry strategy. @@ -1194,9 +1381,20 @@ def _container_templates(self): user_code_retries = max_user_code_retries total_retries = max_user_code_retries + max_error_retries # {{retries}} is only available if retryStrategy is specified + # and they are only available in the container templates NOT for custom + # Kubernetes manifests like Jobsets. + # For custom kubernetes manifests, we will pass the retryCount as a parameter + # and use that in the manifest. retry_count = ( - "{{retries}}" if max_user_code_retries + max_error_retries else 0 + ( + "{{retries}}" + if not node.parallel_step + else "{{inputs.parameters.retryCount}}" + ) + if total_retries + else 0 ) + minutes_between_retries = int(minutes_between_retries) # Configure log capture. @@ -1302,13 +1500,24 @@ def _container_templates(self): foreach_step = next( n for n in node.in_funcs if self.graph[n].is_inside_foreach ) - input_paths = ( - "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" - % ( - foreach_step, - input_paths, + if not self.graph[node.split_parents[-1]].parallel_foreach: + input_paths = ( + "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" + % ( + foreach_step, + input_paths, + ) + ) + else: + # When we run Jobsets with Argo Workflows we need to ensure that `input_paths` are generated using the a formulaic approach + # because our current strategy of using volume mounts for outputs won't work with Jobsets + input_paths = ( + "$(python -m metaflow.plugins.argo.jobset_input_paths %s %s {{inputs.parameters.task-id-entropy}} {{inputs.parameters.num-parallel}})" + % ( + run_id, + foreach_step, + ) ) - ) step = [ "step", node.name, @@ -1318,7 +1527,14 @@ def _container_templates(self): "--max-user-code-retries %d" % user_code_retries, "--input-paths %s" % input_paths, ] - if any(self.graph[n].type == "foreach" for n in node.in_funcs): + if node.parallel_step: + step.append( + "--split-index ${MF_CONTROL_INDEX:-$((MF_WORKER_REPLICA_INDEX + 1))}" + ) + # This is needed for setting the value of the UBF context in the CLI. + step.append("--ubf-context $UBF_CONTEXT") + + elif any(self.graph[n].type == "foreach" for n in node.in_funcs): # Pass split-index to a foreach task step.append("--split-index {{inputs.parameters.split-index}}") if self.tags: @@ -1481,17 +1697,47 @@ def _container_templates(self): # join task deterministically inside the join task without resorting to # passing a rather long list of (albiet compressed) inputs = [] - if node.name != "start": + # To set the input-paths as a parameter, we need to ensure that the node + # is not (a start node or a parallel join node). Start nodes will have no + # input paths and parallel join will derive input paths based on a + # formulaic approach. + if not ( + node.name == "start" + or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step) + ): inputs.append(Parameter("input-paths")) if any(self.graph[n].type == "foreach" for n in node.in_funcs): # Fetch split-index from parent inputs.append(Parameter("split-index")) + if ( node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" ): - # append this only for joins of foreaches, not static splits - inputs.append(Parameter("split-cardinality")) + # @parallel join tasks require `num-parallel` and `task-id-entropy` + # to construct the input paths, so we pass them down as input parameters. + if self.graph[node.split_parents[-1]].parallel_foreach: + inputs.extend( + [Parameter("num-parallel"), Parameter("task-id-entropy")] + ) + else: + # append this only for joins of foreaches, not static splits + inputs.append(Parameter("split-cardinality")) + # We can use an `elif` condition because the first `if` condition validates if its + # a foreach join node, hence we can safely assume that if that condition fails then + # we can check if the node is a @parallel node. + elif node.parallel_step: + inputs.extend( + [ + Parameter("num-parallel"), + Parameter("task-id-entropy"), + Parameter("jobset-name"), + Parameter("workerCount"), + ] + ) + if any(d.name == "retry" for d in node.decorators): + inputs.append(Parameter("retryCount")) + if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join": if any( self.graph[parent].matching_join @@ -1508,7 +1754,9 @@ def _container_templates(self): inputs.append(Parameter("root-input-path")) outputs = [] - if node.name != "end": + # @parallel steps will not have a task-id as an output parameter since task-ids + # are derived at runtime. + if not (node.name == "end" or node.parallel_step): outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})] if node.type == "foreach": # Emit split cardinality from foreach task @@ -1521,6 +1769,19 @@ def _container_templates(self): ) ) + if node.parallel_foreach: + outputs.extend( + [ + Parameter("num-parallel").valueFrom( + {"path": "/mnt/out/num_parallel"} + ), + Parameter("task-id-entropy").valueFrom( + {"path": "/mnt/out/task_id_entropy"} + ), + ] + ) + # Outputs should be defined over here, Not in the _dag_template for the `num_parallel` stuff. + # It makes no sense to set env vars to None (shows up as "None" string) # Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS) env = { @@ -1550,6 +1811,156 @@ def _container_templates(self): # liked to inline this ContainerTemplate and avoid scanning the workflow # twice, but due to issues with variable substitution, we will have to # live with this routine. + if node.parallel_step: + + # Explicitly add the task-id-hint label. This is important because this label + # is returned as an Output parameter of this step and is used subsequently an + # an input in the join step. Even the num_parallel is used as an output parameter + kubernetes_labels = self.kubernetes_labels.copy() + jobset_name = "{{inputs.parameters.jobset-name}}" + kubernetes_labels[ + "task_id_entropy" + ] = "{{inputs.parameters.task-id-entropy}}" + kubernetes_labels["num_parallel"] = "{{inputs.parameters.num-parallel}}" + jobset = KubernetesArgoJobSet( + kubernetes_sdk=kubernetes_sdk, + name=jobset_name, + flow_name=self.flow.name, + run_id=run_id, + step_name=self._sanitize(node.name), + task_id=task_id, + attempt=retry_count, + user=self.username, + subdomain=jobset_name, + command=cmds, + namespace=resources["namespace"], + image=resources["image"], + image_pull_policy=resources["image_pull_policy"], + service_account=resources["service_account"], + secrets=( + [ + k + for k in ( + list( + [] + if not resources.get("secrets") + else [resources.get("secrets")] + if isinstance(resources.get("secrets"), str) + else resources.get("secrets") + ) + + KUBERNETES_SECRETS.split(",") + + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",") + ) + if k + ] + ), + node_selector=resources.get("node_selector"), + cpu=str(resources["cpu"]), + memory=str(resources["memory"]), + disk=str(resources["disk"]), + gpu=resources["gpu"], + gpu_vendor=str(resources["gpu_vendor"]), + tolerations=resources["tolerations"], + use_tmpfs=use_tmpfs, + tmpfs_tempdir=tmpfs_tempdir, + tmpfs_size=tmpfs_size, + tmpfs_path=tmpfs_path, + timeout_in_seconds=run_time_limit, + persistent_volume_claims=resources["persistent_volume_claims"], + shared_memory=shared_memory, + port=port, + ) + + for k, v in env.items(): + jobset.environment_variable(k, v) + + for k, v in kubernetes_labels.items(): + jobset.label(k, v) + + ## -----Jobset specific env vars START here----- + jobset.environment_variable( + "MF_MASTER_ADDR", jobset.jobset_control_addr + ) + jobset.environment_variable("MF_MASTER_PORT", str(port)) + jobset.environment_variable( + "MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}" + ) + # for k, v in .items(): + jobset.environment_variables_from_selectors( + { + "MF_WORKER_REPLICA_INDEX": "metadata.annotations['jobset.sigs.k8s.io/job-index']", + "JOBSET_RESTART_ATTEMPT": "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']", + "METAFLOW_KUBERNETES_JOBSET_NAME": "metadata.annotations['jobset.sigs.k8s.io/jobset-name']", + "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace", + "METAFLOW_KUBERNETES_POD_NAME": "metadata.name", + "METAFLOW_KUBERNETES_POD_ID": "metadata.uid", + "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName", + "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP", + # `TASK_ID_SUFFIX` is needed for the construction of the task-ids + "TASK_ID_SUFFIX": "metadata.annotations['jobset.sigs.k8s.io/job-index']", + } + ) + annotations = { + # setting annotations explicitly as they wont be + # passed down from WorkflowTemplate level + "metaflow/step_name": node.name, + "metaflow/attempt": str(retry_count), + "metaflow/run_id": run_id, + "metaflow/production_token": self.production_token, + "metaflow/owner": self.username, + "metaflow/user": "argo-workflows", + "metaflow/flow_name": self.flow.name, + } + if current.get("project_name"): + annotations.update( + { + "metaflow/project_name": current.project_name, + "metaflow/branch_name": current.branch_name, + "metaflow/project_flow_name": current.project_flow_name, + } + ) + for k, v in annotations.items(): + jobset.annotation(k, v) + ## -----Jobset specific env vars END here----- + ## ---- Jobset control/workers specific vars START here ---- + jobset.control.replicas(1) + jobset.worker.replicas("{{=asInt(inputs.parameters.workerCount)}}") + jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL) + jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK) + jobset.control.environment_variable("MF_CONTROL_INDEX", "0") + # `TASK_ID_PREFIX` needs to explicitly be `control` or `worker` + # because the join task uses a formulaic approach to infer the task-ids + jobset.control.environment_variable("TASK_ID_PREFIX", "control") + jobset.worker.environment_variable("TASK_ID_PREFIX", "worker") + + ## ---- Jobset control/workers specific vars END here ---- + yield ( + Template(ArgoWorkflows._sanitize(node.name)) + .resource( + "create", + jobset.dump(), + "status.terminalState == Completed", + "status.terminalState == Failed", + ) + .inputs(Inputs().parameters(inputs)) + .outputs( + Outputs().parameters( + [ + Parameter("task-id-entropy").valueFrom( + {"jsonPath": "{.metadata.labels.task_id_entropy}"} + ), + Parameter("num-parallel").valueFrom( + {"jsonPath": "{.metadata.labels.num_parallel}"} + ), + ] + ) + ) + .retry_strategy( + times=total_retries, + minutes_between_retries=minutes_between_retries, + ) + ) + continue yield ( Template(self._sanitize(node.name)) # Set @timeout values @@ -1838,7 +2249,7 @@ def _get_slack_blocks(self, message): "fields": [ { "type": "mrkdwn", - "text": "*Project:* %s" % current.project_name + "text": "*Project:* %s" % current.project_name }, { "type": "mrkdwn", @@ -2612,6 +3023,15 @@ def tolerations(self, tolerations): def to_json(self): return self.payload + def resource(self, action, manifest, success_criteria, failure_criteria): + self.payload["resource"] = {} + self.payload["resource"]["action"] = action + self.payload["setOwnerReference"] = True + self.payload["resource"]["successCondition"] = success_criteria + self.payload["resource"]["failureCondition"] = failure_criteria + self.payload["resource"]["manifest"] = manifest + return self + def __str__(self): return json.dumps(self.payload, indent=4) diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py index 58a7be39479..a4ec10114f3 100644 --- a/metaflow/plugins/argo/argo_workflows_decorator.py +++ b/metaflow/plugins/argo/argo_workflows_decorator.py @@ -2,12 +2,14 @@ import os import time + from metaflow import current from metaflow.decorators import StepDecorator from metaflow.events import Trigger from metaflow.metadata import MetaDatum from metaflow.metaflow_config import ARGO_EVENTS_WEBHOOK_URL - +from metaflow.graph import DAGNode, FlowGraph +from metaflow.flowspec import FlowSpec from .argo_events import ArgoEvent @@ -83,7 +85,13 @@ def task_pre_step( metadata.register_metadata(run_id, step_name, task_id, entries) def task_finished( - self, step_name, flow, graph, is_task_ok, retry_count, max_user_code_retries + self, + step_name, + flow: FlowSpec, + graph: FlowGraph, + is_task_ok, + retry_count, + max_user_code_retries, ): if not is_task_ok: # The task finished with an exception - execution won't @@ -100,16 +108,39 @@ def task_finished( # we run pods with a security context. We work around this constraint by # mounting an emptyDir volume. if graph[step_name].type == "foreach": + # A DAGNode is considered a `parallel_step` if it is annotated by the @parallel decorator. + # A DAGNode is considered a `parallel_foreach` if it contains a `num_parallel` kwarg provided to the + # `next` method of that DAGNode. + # At this moment in the code we care if a node is marked as a `parallel_foreach` so that we can pass down the + # value of `num_parallel` to the subsequent steps. + # For @parallel, the implmentation uses 1 jobset object. That one jobset + # object internally creates 'num_parallel' jobs. So, we set foreach_num_splits + # to 1 here for @parallel. The parallelism of jobset is handled in + # kubernetes_job.py. + if graph[step_name].parallel_foreach: + with open("/mnt/out/num_parallel", "w") as f: + json.dump(flow._parallel_ubf_iter.num_parallel, f) + flow._foreach_num_splits = 1 + with open("/mnt/out/task_id_entropy", "w") as file: + import uuid + + file.write(uuid.uuid4().hex[:6]) + with open("/mnt/out/splits", "w") as file: json.dump(list(range(flow._foreach_num_splits)), file) with open("/mnt/out/split_cardinality", "w") as file: json.dump(flow._foreach_num_splits, file) - # Unfortunately, we can't always use pod names as task-ids since the pod names - # are not static across retries. We write the task-id to a file that is read - # by the next task here. - with open("/mnt/out/task_id", "w") as file: - file.write(self.task_id) + # for steps that have a `@parallel` decorator set to them, we will be relying on Jobsets + # to run the task. In this case, we cannot set anything in the + # `/mnt/out` directory, since such form of output mounts are not available to jobset execution as + # argo just treats it like A K8s resource that it throws in the cluster. + if not graph[step_name].parallel_step: + # Unfortunately, we can't always use pod names as task-ids since the pod names + # are not static across retries. We write the task-id to a file that is read + # by the next task here. + with open("/mnt/out/task_id", "w") as file: + file.write(self.task_id) # Emit Argo Events given that the flow has succeeded. Given that we only # emit events when the task succeeds, we can piggy back on this decorator diff --git a/metaflow/plugins/argo/jobset_input_paths.py b/metaflow/plugins/argo/jobset_input_paths.py new file mode 100644 index 00000000000..b4121d80432 --- /dev/null +++ b/metaflow/plugins/argo/jobset_input_paths.py @@ -0,0 +1,16 @@ +import sys +from hashlib import md5 + + +def generate_input_paths(run_id, step_name, task_id_entropy, num_parallel): + # => run_id/step/:foo,bar + control_id = "control-{}-0".format(task_id_entropy) + worker_ids = [ + "worker-{}-{}".format(task_id_entropy, i) for i in range(int(num_parallel) - 1) + ] + ids = [control_id] + worker_ids + return "{}/{}/:{}".format(run_id, step_name, ",".join(ids)) + + +if __name__ == "__main__": + print(generate_input_paths(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])) diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index 092b63b27c4..8be958b8ec0 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -155,6 +155,7 @@ def launch_job(self, **kwargs): ): self._job = self.create_jobset(**kwargs).execute() else: + kwargs.pop("num_parallel", None) kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8]) self._job = self.create_job_object(**kwargs).create().execute() diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index b4efced1878..a6c4c964b1c 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -7,8 +7,6 @@ from metaflow.exception import MetaflowException from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION -from metaflow.metaflow_current import current -from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK from metaflow.tracing import inject_tracing_vars from metaflow.metaflow_config import KUBERNETES_SECRETS @@ -768,7 +766,7 @@ def dump(self): ) -class KubernetesJobSet(object): # todo : [final-refactor] : fix-me +class KubernetesJobSet(object): def __init__( self, client, @@ -906,3 +904,119 @@ def execute(self): group=self._group, version=self._version, ) + + +class KubernetesArgoJobSet(object): + def __init__(self, kubernetes_sdk, name=None, namespace=None, **kwargs): + self._kubernetes_sdk = kubernetes_sdk + self._annotations = {} + self._labels = {} + self._group = KUBERNETES_JOBSET_GROUP + self._version = KUBERNETES_JOBSET_VERSION + self._namespace = namespace + self.name = name + + self._jobset_control_addr = _make_domain_name( + name, + CONTROL_JOB_NAME, + 0, + 0, + namespace, + ) + + self._control_spec = JobSetSpec( + kubernetes_sdk, name=CONTROL_JOB_NAME, namespace=namespace, **kwargs + ) + self._worker_spec = JobSetSpec( + kubernetes_sdk, name="worker", namespace=namespace, **kwargs + ) + + @property + def jobset_control_addr(self): + return self._jobset_control_addr + + @property + def worker(self): + return self._worker_spec + + @property + def control(self): + return self._control_spec + + def environment_variable_from_selector(self, name, label_value): + self.worker.environment_variable_from_selector(name, label_value) + self.control.environment_variable_from_selector(name, label_value) + return self + + def environment_variables_from_selectors(self, env_dict): + for name, label_value in env_dict.items(): + self.worker.environment_variable_from_selector(name, label_value) + self.control.environment_variable_from_selector(name, label_value) + return self + + def environment_variable(self, name, value): + self.worker.environment_variable(name, value) + self.control.environment_variable(name, value) + return self + + def label(self, name, value): + self.worker.label(name, value) + self.control.label(name, value) + self._labels = dict(self._labels, **{name: value}) + return self + + def annotation(self, name, value): + self.worker.annotation(name, value) + self.control.annotation(name, value) + self._annotations = dict(self._annotations, **{name: value}) + return self + + def dump(self): + client = self._kubernetes_sdk + import json + + data = json.dumps( + client.ApiClient().sanitize_for_serialization( + dict( + apiVersion=self._group + "/" + self._version, + kind="JobSet", + metadata=client.api_client.ApiClient().sanitize_for_serialization( + client.V1ObjectMeta( + name=self.name, + labels=self._labels, + annotations=self._annotations, + ) + ), + spec=dict( + replicatedJobs=[self.control.dump(), self.worker.dump()], + suspend=False, + startupPolicy=None, + successPolicy=None, + # The Failure Policy helps setting the number of retries for the jobset. + # but we don't rely on it and instead rely on either the local scheduler + # or the Argo Workflows to handle retries. + failurePolicy=None, + network=None, + ), + status=None, + ) + ) + ) + # The values we populate in the Jobset manifest (for Argo Workflows) piggybacks on the Argo Workflow's templating engine. + # Even though Argo Workflows's templating helps us constructing all the necessary IDs and populating the fields + # required by Metaflow, we run into one glitch. When we construct JSON/YAML serializable objects, + # anything between two braces such as `{{=asInt(inputs.parameters.workerCount)}}` gets quoted. This is a problem + # since we need to pass the value of `inputs.parameters.workerCount` as an integer and not as a string. + # If we pass it as a string, the jobset controller will not accept the Jobset CRD we submitted to kubernetes. + # To get around this, we need to replace the quoted substring with the unquoted substring because YAML /JSON parsers + # won't allow deserialization with the quoting trivially. + + # This is super important because the `inputs.parameters.workerCount` is used to set the number of replicas; + # The value for number of replicas is derived from the value of `num_parallel` (which is set in the user-code). + # Since the value of `num_parallel` can be dynamic and can change from run to run, we need to ensure that the + # value can be passed-down dynamically and is **explicitly set as a integer** in the Jobset Manifest submitted as a + # part of the Argo Workflow + + quoted_substring = '"{{=asInt(inputs.parameters.workerCount)}}"' + unquoted_substring = "{{=asInt(inputs.parameters.workerCount)}}" + return data.replace(quoted_substring, unquoted_substring)