Skip to content

Commit

Permalink
update scripts to have better region support (#3910)
Browse files Browse the repository at this point in the history
  • Loading branch information
kddejong authored Jan 13, 2025
1 parent 03e1857 commit e38b9cc
Showing 1 changed file with 53 additions and 57 deletions.
110 changes: 53 additions & 57 deletions scripts/update_specs_from_pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
config = Config(retries={"max_attempts": 10})
client = session.client("pricing", region_name="us-east-1", config=config)

_UNSUPPORTED_REGIONS = set()


def configure_logging():
"""Setup Logging"""
Expand Down Expand Up @@ -94,13 +96,15 @@ def get_dax_pricing():
product = products.get("product", {})
if product:
if product.get("productFamily") in ["DAX"]:
if not results.get(
region_map[product.get("attributes").get("location")]
):
results[
region_map[product.get("attributes").get("location")]
] = set()
results[region_map[product.get("attributes").get("location")]].add(
location = region_map.get(product.get("attributes").get("location"))
if not location:
_UNSUPPORTED_REGIONS.add(
product.get("attributes").get("location")
)
continue
if not results.get(location):
results[location] = set()
results[location].add(
product.get("attributes").get("usagetype").split(":")[1]
)
return results
Expand All @@ -115,18 +119,18 @@ def get_mq_pricing():
product = products.get("product", {})
if product:
if product.get("productFamily") in ["Broker Instances"]:
if not results.get(
region_map[product.get("attributes").get("location")]
):
results[
region_map[product.get("attributes").get("location")]
] = set()
location = region_map.get(product.get("attributes").get("location"))
if not location:
_UNSUPPORTED_REGIONS.add(
product.get("attributes").get("location")
)
continue
if not results.get(location):
results[location] = set()
usage_type = (
product.get("attributes").get("usagetype").split(":")[1]
)
results[region_map[product.get("attributes").get("location")]].add(
remap.get(usage_type, usage_type)
)
results[location].add(remap.get(usage_type, usage_type))
return results


Expand Down Expand Up @@ -170,22 +174,19 @@ def get_rds_pricing():
if product.get("attributes").get("locationType") == "AWS Outposts":
continue
# Get overall instance types
if not results.get(
region_map[product.get("attributes").get("location")]
):
results[
region_map[product.get("attributes").get("location")]
] = set(["db.serverless"])
results[region_map[product.get("attributes").get("location")]].add(
product.get("attributes").get("instanceType")
)
location = region_map.get(product.get("attributes").get("location"))
if not location:
_UNSUPPORTED_REGIONS.add(
product.get("attributes").get("location")
)
continue
if not results.get(location):
results[location] = set(["db.serverless"])
results[location].add(product.get("attributes").get("instanceType"))
# Rds Instance Size spec
product_names = product_map.get(
product.get("attributes").get("engineCode"), []
)
product_region = region_map.get(
product.get("attributes").get("location")
)
license_name = license_map.get(
product.get("attributes").get("licenseModel")
)
Expand All @@ -202,20 +203,18 @@ def get_rds_pricing():

instance_type = product.get("attributes").get("instanceType")
for product_name in product_names:
if not rds_details.get(product_region):
rds_details[product_region] = {}
if not rds_details.get(product_region).get(deployment_option):
rds_details[product_region][deployment_option] = {}
if not rds_details.get(location):
rds_details[location] = {}
if not rds_details.get(location).get(deployment_option):
rds_details[location][deployment_option] = {}
if (
not rds_details.get(product_region)
not rds_details.get(location)
.get(deployment_option)
.get(license_name)
):
rds_details[product_region][deployment_option][
license_name
] = {}
rds_details[location][deployment_option][license_name] = {}
if (
not rds_details.get(product_region)
not rds_details.get(location)
.get(deployment_option)
.get(license_name)
.get(product_name)
Expand All @@ -225,14 +224,14 @@ def get_rds_pricing():
and product_name
in ["aurora-mysql", "aurora-postgresql"]
):
rds_details[product_region][deployment_option][
license_name
][product_name] = set(["db.serverless"])
rds_details[location][deployment_option][license_name][
product_name
] = set(["db.serverless"])
else:
rds_details[product_region][deployment_option][
license_name
][product_name] = set()
rds_details[product_region][deployment_option][license_name][
rds_details[location][deployment_option][license_name][
product_name
] = set()
rds_details[location][deployment_option][license_name][
product_name
].add(instance_type)
specs = {}
Expand Down Expand Up @@ -350,21 +349,15 @@ def get_results(service, product_families, default=None):
product.get("productFamily") in product_families
and product.get("attributes").get("locationType") == "AWS Region"
):
if product.get("attributes").get("location") not in region_map:
LOGGER.warning(
'Region "%s" not found',
product.get("attributes").get("location"),
location = region_map.get(product.get("attributes").get("location"))
if not location:
_UNSUPPORTED_REGIONS.add(
product.get("attributes").get("location")
)
continue
if not results.get(
region_map[product.get("attributes").get("location")]
):
results[
region_map[product.get("attributes").get("location")]
] = default
results[region_map[product.get("attributes").get("location")]].add(
product.get("attributes").get("instanceType")
)
if not results.get(location):
results[location] = default
results[location].add(product.get("attributes").get("instanceType"))
return results


Expand Down Expand Up @@ -441,6 +434,9 @@ def main():
get_results("AmazonAppStream", ["Streaming Instance"]),
)

for region in _UNSUPPORTED_REGIONS:
LOGGER.warning(f"Region {region!r} is not supported")


if __name__ == "__main__":
try:
Expand Down

0 comments on commit e38b9cc

Please sign in to comment.