Skip to content

Commit

Permalink
Merge pull request #15 from aws-samples/dependency-updates-and-cleanup
Browse files Browse the repository at this point in the history
Dependency updates and cleanup
  • Loading branch information
laithalsaadoon authored Aug 5, 2022
2 parents 696c090 + ff12f1e commit a18fb46
Show file tree
Hide file tree
Showing 17 changed files with 48,330 additions and 13,340 deletions.
469 changes: 460 additions & 9 deletions .gitignore

Large diffs are not rendered by default.

140 changes: 77 additions & 63 deletions backend/lambda/app.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,58 @@
import base64
import json
from io import BytesIO
from os import environ

import boto3
import numpy as np
import requests
from urllib.parse import urlparse
from io import BytesIO

from elasticsearch import Elasticsearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
from opensearchpy import AWSV4SignerAuth, OpenSearch, RequestsHttpConnection
from PIL import Image

# Global variables that are reused
sm_runtime_client = boto3.client('sagemaker-runtime')
s3_client = boto3.client('s3')
sm_runtime_client = boto3.client("sagemaker-runtime")
s3_client = boto3.client("s3")


def get_features(sm_runtime_client, sagemaker_endpoint, img_bytes):
def get_features(img_bytes, sagemaker_endpoint=environ["SM_ENDPOINT"]):
img_bytes = image_preprocessing(img_bytes, return_body=True)
response = sm_runtime_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
ContentType='application/x-image',
Body=img_bytes)
response_body = json.loads((response['Body'].read()))
features = response_body['predictions'][0]

ContentType="application/json",
Body=img_bytes,
)
response_body = json.loads((response["Body"].read()))
features = response_body["predictions"][0]
return features


def get_neighbors(features, es, k_neighbors=3):
idx_name = 'idx_zalando'
res = es.search(
request_timeout=30, index=idx_name,
body={
'size': k_neighbors,
'query': {'knn': {'zalando_img_vector': {'vector': features, 'k': k_neighbors}}}}
)
s3_uris = [res['hits']['hits'][x]['_source']['image'] for x in range(k_neighbors)]
def get_neighbors(features, oss, k_neighbors=3):
idx_name = "idx_zalando"
query = {
"size": k_neighbors,
"query": {
"knn": {"zalando_img_vector": {"vector": features, "k": k_neighbors}}
},
}
res = oss.search(request_timeout=30, index=idx_name, body=query)
s3_uris = [res["hits"]["hits"][x]["_source"]["image"] for x in range(k_neighbors)]

return s3_uris


def generate_presigned_urls(s3_uris):
presigned_urls = [s3_client.generate_presigned_url(
'get_object',
Params={
'Bucket': urlparse(x).netloc,
'Key': urlparse(x).path.lstrip('/')},
ExpiresIn=300
) for x in s3_uris]
def _s3_client_presigned_url(bucket, key):
return s3_client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": bucket, "Key": key},
ExpiresIn=60 * 5,
)

bucket = s3_uris[0].replace("s3://", "").split("/")[0]
presigned_urls = [
_s3_client_presigned_url(bucket, x.replace(f"s3://{bucket}/", ""))
for x in s3_uris
]

return presigned_urls

Expand All @@ -55,60 +61,68 @@ def download_file(url):
r = requests.get(url)
if r.status_code == 200:
file = BytesIO(r.content)
return file
else:
print("file failed to download")
return file
return None


def lambda_handler(event, context):
def create_oss_client():
region = environ["AWS_REGION"]
elasticsearch_endpoint = environ["ES_ENDPOINT"]

# elasticsearch variables
service = 'es'
region = environ['AWS_REGION']
elasticsearch_endpoint = environ['ES_ENDPOINT']
credentials = boto3.Session().get_credentials()
awsauth = AWSV4SignerAuth(credentials, region)

session = boto3.session.Session()
credentials = session.get_credentials()
awsauth = AWS4Auth(
credentials.access_key,
credentials.secret_key,
region,
service,
session_token=credentials.token
)

es = Elasticsearch(
hosts=[{'host': elasticsearch_endpoint, 'port': 443}],
oss = OpenSearch(
hosts=[{"host": elasticsearch_endpoint, "port": 443}],
http_auth=awsauth,
use_ssl=True,
verify_certs=True,
connection_class=RequestsHttpConnection
connection_class=RequestsHttpConnection,
)

# sagemaker variables
sagemaker_endpoint = environ['SM_ENDPOINT']
return oss


def image_preprocessing(img_bytes, return_body=True):
img = Image.open(img_bytes).convert("RGB")
img = img.resize((224, 224))
img = np.asarray(img)
img = np.expand_dims(img, axis=0)
if return_body:
body = json.dumps({"instances": img.tolist()})
return body
else:
return img


def lambda_handler(event, _):
oss_client = create_oss_client()

api_payload = json.loads(event["body"])
k = api_payload["k"]

api_payload = json.loads(event['body'])
k = api_payload['k']
if event['path'] == '/postURL':
image = download_file(api_payload['url'])
if event["path"] == "/postURL":
image = download_file(api_payload["url"])
else:
img_string = api_payload['base64img']
print(img_string)
img_string = api_payload["base64img"]
image = BytesIO(base64.b64decode(img_string))

features = get_features(sm_runtime_client, sagemaker_endpoint, image)
s3_uris_neighbors = get_neighbors(features, es, k_neighbors=k)
features = get_features(image)
s3_uris_neighbors = get_neighbors(features, oss_client, k_neighbors=k)
s3_presigned_urls = generate_presigned_urls(s3_uris_neighbors)

return {
"statusCode": 200,
"headers": {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Methods": "*"
"Access-Control-Allow-Methods": "*",
},
"body": json.dumps({
"images": s3_presigned_urls,
}),
"body": json.dumps(
{
"images": s3_presigned_urls,
}
),
}
5 changes: 3 additions & 2 deletions backend/lambda/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
requests
boto3
elasticsearch < 7.14
requests-aws4auth
opensearch-py
Pillow
numpy
77 changes: 30 additions & 47 deletions backend/template.yaml
Original file line number Diff line number Diff line change
@@ -1,67 +1,60 @@
AWSTemplateFormatVersion: '2010-09-09'
AWSTemplateFormatVersion: "2010-09-09"
Transform: AWS::Serverless-2016-10-31
Description: 'backend
Description: "backend
Sample SAM Template for backend
'
"
Parameters:
BucketName:
Type: String
DomainName:
Type: String
ElasticSearchURL:
OpenSearchURL:
Type: String
SagemakerEndpoint:
Type: String
LambdaCodeFile:
LambdaCodeFile:
Type: String
Globals:
Function:
Timeout: 60
MemorySize: 512
Api:
Cors:
AllowMethods: '''*'''
AllowHeaders: '''*'''
AllowOrigin: '''*'''
AllowMethods: "'*'"
AllowHeaders: "'*'"
AllowOrigin: "'*'"
Resources:
PostFetchSimilarPhotosLambda:
Type: AWS::Serverless::Function
Properties:
CodeUri:
# https://github.com/aws/serverless-application-model/blob/master/HOWTO.md
# Using Intrinsic Functions
Bucket: !Ref BucketName
Key: !Ref LambdaCodeFile
Handler: app.lambda_handler
Runtime: python3.7
Runtime: python3.8
Environment:
Variables:
ES_ENDPOINT:
Ref: ElasticSearchURL
SM_ENDPOINT:
Ref: SagemakerEndpoint
ES_ENDPOINT: !Ref OpenSearchURL
SM_ENDPOINT: !Ref SagemakerEndpoint
Policies:
- Version: '2012-10-17'
Statement:
- Sid: AllowSagemakerInvokeEndpoint
Effect: Allow
Action:
- sagemaker:InvokeEndpoint
Resource:
- Fn::Sub: arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:endpoint/${SagemakerEndpoint}
- Version: '2012-10-17'
Statement:
- Sid: AllowESS
Effect: Allow
Action:
- es:*
Resource:
- Fn::Sub: arn:aws:es:${AWS::Region}:${AWS::AccountId}:domain/${DomainName}/*
- S3ReadPolicy:
BucketName:
Ref: BucketName
- Version: "2012-10-17"
Statement:
- Sid: AllowSagemakerInvokeEndpoint
Effect: Allow
Action:
- sagemaker:InvokeEndpoint
Resource: !Sub arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:endpoint/${SagemakerEndpoint}
- Version: "2012-10-17"
Statement:
- Sid: AllowESS
Effect: Allow
Action:
- es:*
Resource: !Sub arn:aws:es:${AWS::Region}:${AWS::AccountId}:domain/${DomainName}/*
- S3ReadPolicy:
BucketName: !Ref BucketName
Events:
PostURL:
Type: Api
Expand All @@ -76,17 +69,7 @@ Resources:
Outputs:
ImageSimilarityApi:
Description: API Gateway endpoint URL for Prod stage for fetchPhoto function
Value:
Fn::Sub: https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/Prod/
Value: !Sub https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/Prod/
PostFetchSimilarPhotosLambda:
Description: Hello World Lambda Function ARN
Value:
Fn::GetAtt:
- PostFetchSimilarPhotosLambda
- Arn
PostFetchSimilarPhotosLambdaIamRole:
Description: Implicit IAM Role created for Hello World function
Value:
Fn::GetAtt:
- PostFetchSimilarPhotosLambda
- Arn
Description: Lambda Function ARN
Value: !GetAtt PostFetchSimilarPhotosLambda.Arn
Loading

0 comments on commit a18fb46

Please sign in to comment.