Inference update #30

Open · wants to merge 21 commits into main
1 change: 0 additions & 1 deletion app.py
@@ -4,7 +4,6 @@

import aws_cdk as cdk

from infra.pipeline_product_stack import BatchPipelineStack, DeployPipelineStack
from infra.service_catalog_stack import ServiceCatalogStack

# Configure the logger
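For anyone tracking the CDK v1-to-v2 migration in this PR: `import aws_cdk as cdk` is the v2 style, where the single `aws-cdk-lib` distribution replaces `aws_cdk.core` and the per-service v1 packages. A minimal, self-contained sketch of a v2 entry point (the stack below is hypothetical, not from this repo):

```python
import aws_cdk as cdk
from constructs import Construct


class ExampleStack(cdk.Stack):
    """Hypothetical stack, for illustration only."""

    def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
        super().__init__(scope, construct_id, **kwargs)
        # Any aws-cdk-lib construct works here; CfnOutput keeps the example tiny.
        cdk.CfnOutput(self, "ExampleOutput", value="ok")


app = cdk.App()
ExampleStack(app, "example-stack")
app.synth()
```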
31 changes: 31 additions & 0 deletions batch_pipeline/.gitignore
@@ -0,0 +1,31 @@
.DS_Store
*.swp
package-lock.json
__pycache__
.pytest_cache
.env
.venv
*.egg-info
.vscode/

# CDK asset staging directory
.cdk.staging
cdk.out
dist/

# Layers
layers/python
layers/*.zip

# Notebooks
.ipynb_checkpoints/

# Sample events or assets
events
assets

# project development env
node_modules/*
node_modules
.vscode/*
notes.txt
91 changes: 18 additions & 73 deletions batch_pipeline/app.py
@@ -4,14 +4,11 @@
import logging
import os

# Import the pipeline
from pipelines.pipeline import get_pipeline, upload_pipeline
import aws_cdk as cdk

from aws_cdk import core
from infra.batch_config import BatchConfig
from infra.sagemaker_pipeline_stack import SageMakerPipelineStack
from infra.model_registry import ModelRegistry

from infra.sagemaker_pipeline_stack import SageMakerPipelineStack

# Configure the logger
logger = logging.getLogger(__name__)
@@ -22,13 +19,11 @@


def create_pipeline(
app: core.App,
app: cdk.App,
project_name: str,
project_id: str,
region: str,
sagemaker_pipeline_role_arn: str,
artifact_bucket: str,
evaluate_drift_function_arn: str,
stage_name: str,
):
# Get the stage specific deployment config for sagemaker
@@ -54,74 +49,28 @@ def create_pipeline(
)[0]
batch_config.model_package_arn = p["ModelPackageArn"]

# Set the default input data uri
data_uri = f"s3://{artifact_bucket}/{project_id}/batch/{stage_name}"

# set the output transform uri
transform_uri = f"s3://{artifact_bucket}/{project_id}/transform/{stage_name}"

# Get the pipeline execution to get the baseline uri
pipeline_execution_arn = registry.get_pipeline_execution_arn(
batch_config.model_package_arn
)
logger.info(f"Got pipeline exection arn: {pipeline_execution_arn}")
model_uri = registry.get_model_artifact(pipeline_execution_arn)
logger.info(f"Got model uri: {model_uri}")

# Set the sagemaker pipeline name and descrption with model version
# Set the sagemaker pipeline name and description with model version
sagemaker_pipeline_name = f"{project_name}-batch-{stage_name}"
sagemaker_pipeline_description = f"Batch Pipeline for {stage_name} model version: {batch_config.model_package_version}"

# If we have drift configuration then get the baseline uri
baseline_uri = None
if batch_config.drift_config is not None:
baseline_uri = registry.get_data_check_baseline_uri(batch_config.model_package_arn)
logger.info(f"Got baseline uri: {baseline_uri}")

# Create batch pipeline
pipeline = get_pipeline(
region=region,
role=sagemaker_pipeline_role_arn,
pipeline_name=sagemaker_pipeline_name,
default_bucket=artifact_bucket,
base_job_prefix=project_id,
evaluate_drift_function_arn=evaluate_drift_function_arn,
data_uri=data_uri,
model_uri=model_uri,
transform_uri=transform_uri,
baseline_uri=baseline_uri,
)

# Create the pipeline definition
logger.info("Creating/updating a SageMaker Pipeline for batch transform")
pipeline_definition_body = pipeline.definition()
parsed = json.loads(pipeline_definition_body)
logger.info(json.dumps(parsed, indent=2, sort_keys=True))

# Upload the pipeline to S3 bucket/key and return JSON with key/value for Cfn Stack parameters.
# see: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-pipeline.html
logger.info(f"Uploading {stage_name} pipeline to {artifact_bucket}")
pipeline_definition_key = upload_pipeline(
pipeline,
default_bucket=artifact_bucket,
base_job_prefix=f"{project_id}/batch-{stage_name}",
sagemaker_pipeline_description = (
f"Batch Pipeline for {stage_name} model version:"
f"{batch_config.model_package_version}"
)

tags = [
core.CfnTag(key="sagemaker:deployment-stage", value=stage_name),
core.CfnTag(key="sagemaker:project-id", value=project_id),
core.CfnTag(key="sagemaker:project-name", value=project_name),
cdk.CfnTag(key="sagemaker:deployment-stage", value=stage_name),
cdk.CfnTag(key="sagemaker:project-id", value=project_id),
cdk.CfnTag(key="sagemaker:project-name", value=project_name),
]

SageMakerPipelineStack(
app,
f"drift-batch-{stage_name}",
pipeline_name=sagemaker_pipeline_name,
pipeline_description=sagemaker_pipeline_description,
pipeline_definition_bucket=artifact_bucket,
pipeline_definition_key=pipeline_definition_key,
sagemaker_role_arn=sagemaker_pipeline_role_arn,
default_bucket=artifact_bucket,
tags=tags,
batch_config=batch_config,
drift_config=batch_config.drift_config,
)
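One detail worth flagging in the reflowed description above: implicit f-string concatenation inserts nothing between fragments, so splitting right after `model version:` drops the space the single-line original had. A standalone illustration with hypothetical values:

```python
stage_name = "staging"          # hypothetical values
model_package_version = "3"

description = (
    f"Batch Pipeline for {stage_name} model version:"
    f"{model_package_version}"
)
print(description)  # Batch Pipeline for staging model version:3  (no space)

# Keeping the trailing space in the first fragment preserves the original wording:
description = (
    f"Batch Pipeline for {stage_name} model version: "
    f"{model_package_version}"
)
print(description)  # Batch Pipeline for staging model version: 3
```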

@@ -132,30 +81,26 @@ def main(
region: str,
sagemaker_pipeline_role_arn: str,
artifact_bucket: str,
evaluate_drift_function_arn: str,
# evaluate_drift_function_arn: str,
):
# Create App and stacks
app = core.App()
app = cdk.App()

create_pipeline(
app=app,
project_name=project_name,
project_id=project_id,
region=region,
sagemaker_pipeline_role_arn=sagemaker_pipeline_role_arn,
artifact_bucket=artifact_bucket,
evaluate_drift_function_arn=evaluate_drift_function_arn,
stage_name="staging",
)

create_pipeline(
app=app,
project_name=project_name,
project_id=project_id,
region=region,
sagemaker_pipeline_role_arn=sagemaker_pipeline_role_arn,
artifact_bucket=artifact_bucket,
evaluate_drift_function_arn=evaluate_drift_function_arn,
stage_name="prod",
)

@@ -174,10 +119,10 @@ def main(
"--sagemaker-pipeline-role-arn",
default=os.environ.get("SAGEMAKER_PIPELINE_ROLE_ARN"),
)
parser.add_argument(
"--evaluate-drift-function-arn",
default=os.environ.get("EVALUATE_DRIFT_FUNCTION_ARN"),
)
# parser.add_argument(
# "--evaluate-drift-function-arn",
# default=os.environ.get("EVALUATE_DRIFT_FUNCTION_ARN"),
# )
parser.add_argument(
"--artifact-bucket",
default=os.environ.get("ARTIFACT_BUCKET"),
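With `--evaluate-drift-function-arn` commented out rather than deleted, the `EVALUATE_DRIFT_FUNCTION_ARN` environment variable is no longer read anywhere in this entry point. The env-var-default pattern the parser relies on, shown standalone (values are placeholders):

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sagemaker-pipeline-role-arn",
    default=os.environ.get("SAGEMAKER_PIPELINE_ROLE_ARN"),
)
parser.add_argument(
    "--artifact-bucket",
    default=os.environ.get("ARTIFACT_BUCKET"),
)

# An explicit flag wins; otherwise the value falls back to the environment,
# and to None when the variable is unset.
args = parser.parse_args(["--artifact-bucket", "example-bucket"])
print(args.artifact_bucket)              # example-bucket
print(args.sagemaker_pipeline_role_arn)  # env value, or None
```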
2 changes: 1 addition & 1 deletion batch_pipeline/buildspec.yml
@@ -5,7 +5,7 @@ phases:
nodejs: "16"
python: "3.9"
commands:
- npm install aws-cdk@2.46.0
- npm install aws-cdk@2.72.1
- npm update
- python -m pip install --upgrade pip
- python -m pip install -r requirements.txt
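The CLI pin moving from 2.46.0 to 2.72.1 matters because the v2 `cdk` CLI must be at least as new as the `aws-cdk-lib` version it synthesizes, or `cdk synth` fails with a version-mismatch error; presumably the matching library bump lands in requirements.txt, so the two pins should keep moving together.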
16 changes: 4 additions & 12 deletions batch_pipeline/cdk.json
@@ -1,13 +1,5 @@
{
"app": "python3 app.py",
"context": {
"@aws-cdk/core:enableStackNameDuplicates": "true",
"aws-cdk:enableDiffNoFail": "true",
"@aws-cdk/core:stackRelativeExports": "true",
"@aws-cdk/aws-ecr-assets:dockerIgnoreSupport": true,
"@aws-cdk/aws-secretsmanager:parseOwnedSecretName": true,
"@aws-cdk/aws-kms:defaultKeyPolicies": true,
"@aws-cdk/aws-s3:grantWriteWithoutAcl": true
}
}

"app": "python3 app.py",
"context": {},
"versionReporting": false
}
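The emptied context map is expected for a v2 app: the v1 feature flags previously listed here are default behavior in `aws-cdk-lib` 2.x, and `"versionReporting": false` stops the CLI from injecting the `AWS::CDK::Metadata` resource into synthesized templates.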
2 changes: 1 addition & 1 deletion batch_pipeline/infra/batch_config.py
@@ -35,7 +35,7 @@ def __init__(
self.model_package_version = model_package_version
self.model_package_arn = model_package_arn
self.model_monitor_enabled = model_monitor_enabled
if type(drift_config) is dict:
if isinstance(drift_config, dict):
self.drift_config = DriftConfig(**drift_config)
else:
self.drift_config = None
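`isinstance` is the safer check here: `type(drift_config) is dict` rejects dict subclasses, while `isinstance` accepts them. A standalone comparison (the keys are hypothetical):

```python
from collections import OrderedDict

# OrderedDict stands in for any dict subclass that might reach this code path.
drift_config = OrderedDict(metric_name="example_metric", threshold=0.5)

print(type(drift_config) is dict)      # False: a subclass is not dict itself
print(isinstance(drift_config, dict))  # True: subclasses pass isinstance
```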
69 changes: 7 additions & 62 deletions batch_pipeline/infra/model_registry.py
@@ -33,7 +33,7 @@ def create_model_package_group(
ModelPackageGroupDescription=description,
)
model_package_group_arn = response["ModelPackageGroupArn"]
# Add tags seperately
# Add tags separately
self.sm_client.add_tags(
ResourceArn=model_package_group_arn,
Tags=[
@@ -84,7 +84,7 @@ def get_latest_approved_packages(
"SortBy": "CreationTime",
"MaxResults": max_results,
}
# Add optional creationg time after
# Add optional creation time after
if creation_time_after is not None:
args = {**args, "CreationTimeAfter": creation_time_after}
response = self.sm_client.list_model_packages(**args)
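The surrounding method builds its kwargs incrementally so optional filters are only sent when set. A self-contained sketch of the same call against the boto3 SageMaker client (the group name is a placeholder and live AWS credentials are assumed):

```python
import boto3

sm_client = boto3.client("sagemaker")

args = {
    "ModelPackageGroupName": "example-model-group",  # placeholder
    "ModelApprovalStatus": "Approved",
    "SortBy": "CreationTime",
    "SortOrder": "Descending",
    "MaxResults": 10,
}

# Merge in the optional filter only when provided, mirroring the diff above.
creation_time_after = None
if creation_time_after is not None:
    args = {**args, "CreationTimeAfter": creation_time_after}

response = sm_client.list_model_packages(**args)
for package in response["ModelPackageSummaryList"]:
    print(package["ModelPackageVersion"], package["ModelPackageArn"])
```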
@@ -165,7 +165,10 @@ def get_versioned_approved_packages(

# Return error if no packages found
if len(model_packages) == 0:
error_message = f"No approved packages found for: {model_package_group_name} and versions: {model_package_versions}"
error_message = (
"No approved packages found for: "
f"{model_package_group_name} and versions: {model_package_versions}"
)
logger.error(error_message)
raise Exception(error_message)

@@ -182,7 +185,7 @@ def get_versioned_approved_packages(
def select_versioned_packages(
self, model_packages: list, model_package_versions: list
):
"""Filters the model packages based on a list of model package verisons.
"""Filters the model packages based on a list of model package versions.

Args:
model_packages: The list of packages.
@@ -199,61 +202,3 @@ def select_versioned_packages(
p for p in model_packages if p["ModelPackageVersion"] == version
]
return filtered_packages
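The visible tail of `select_versioned_packages` suggests it appends matches once per requested version, so the output follows the order of the requested version list rather than the input list. A mocked run under that assumption:

```python
# Mocked summaries, shaped like list_model_packages output.
model_packages = [
    {"ModelPackageArn": "arn:aws:sagemaker:us-east-1:111122223333:model-package/example/1",
     "ModelPackageVersion": 1},
    {"ModelPackageArn": "arn:aws:sagemaker:us-east-1:111122223333:model-package/example/2",
     "ModelPackageVersion": 2},
]

filtered_packages = []
for version in [2, 1]:  # requested versions, deliberately out of input order
    filtered_packages += [
        p for p in model_packages if p["ModelPackageVersion"] == version
    ]

print([p["ModelPackageVersion"] for p in filtered_packages])  # [2, 1]
```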

def get_pipeline_execution_arn(self, model_package_arn: str):
"""Geturns the execution arn for the latest approved model package

Args:
model_package_arn: The arn of the model package

Returns:
The arn of the sagemaker pipeline that created the model package.
"""

artifact_arn = self.sm_client.list_artifacts(SourceUri=model_package_arn)[
"ArtifactSummaries"
][0]["ArtifactArn"]
return self.sm_client.describe_artifact(ArtifactArn=artifact_arn)[
"MetadataProperties"
]["GeneratedBy"]

def get_model_artifact(
self,
pipeline_execution_arn: str,
step_name: str = "TrainModel",
):
"""Returns the training job model artifact uri for a given step name.

Args:
pipeline_execution_arn: The pipeline execution arn
step_name: The optional training job step name

Returns:
The model artifact from the training job
"""

steps = self.sm_client.list_pipeline_execution_steps(
PipelineExecutionArn=pipeline_execution_arn
)["PipelineExecutionSteps"]
training_job_arn = [
s["Metadata"]["TrainingJob"]["Arn"]
for s in steps
if s["StepName"] == step_name
][0]
training_job_name = training_job_arn.split("/")[-1]
outputs = self.sm_client.describe_training_job(
TrainingJobName=training_job_name
)
return outputs["ModelArtifacts"]["S3ModelArtifacts"]

def get_data_check_baseline_uri(self, model_package_arn: str):
try:
model_details = self.sm_client.describe_model_package(ModelPackageName=model_package_arn)
print(model_details)
baseline_uri = model_details['DriftCheckBaselines']['ModelDataQuality']['Constraints']['S3Uri']
baseline_uri = baseline_uri.replace('/constraints.json','') # returning the folder containing constraints and statistics
return baseline_uri
except ClientError as e:
error_message = e.response["Error"]["Message"]
logger.error(error_message)
raise Exception(error_message)