Skip to content

Commit

Permalink
Add function to get cfn path
Browse files Browse the repository at this point in the history
  • Loading branch information
kddejong committed Jan 24, 2025
1 parent c36671e commit e04e921
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 16 deletions.
22 changes: 14 additions & 8 deletions scripts/update_schemas_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,15 @@ def main():
for path in _descend(
obj,
[
"SecurityGroupIds",
"SecurityGroups",
"VpcSecurityGroupIds",
"Ec2SecurityGroupIds",
"CustomSecurityGroupIds",
"DBSecurityGroups",
"Ec2SecurityGroupIds",
"GroupSet",
"InputSecurityGroups",
"SecurityGroupIdList",
"GroupSet",
"SecurityGroupIds",
"SecurityGroups",
"VpcSecurityGroupIds",
],
):
if path[-2] == "properties":
Expand All @@ -275,12 +276,13 @@ def main():
for path in _descend(
obj,
[
"DefaultSecurityGroup",
"ClusterSecurityGroupId",
"SourceSecurityGroupId",
"DefaultSecurityGroup",
"DestinationSecurityGroupId",
"EC2SecurityGroupId",
"SecurityGroup",
"SecurityGroupId",
"SecurityGroupIngress" "SourceSecurityGroupId",
"VpcSecurityGroupId",
],
):
Expand All @@ -294,7 +296,11 @@ def main():
for path in _descend(
obj,
[
"SourceSecurityGroupName",
"CacheSecurityGroupName",
"ClusterSecurityGroupName",
"DBSecurityGroupName",
"EC2SecurityGroupName",
"SourceSecurityGroupName" "SourceSecurityGroupName",
],
):
if path[-2] == "properties":
Expand Down
30 changes: 30 additions & 0 deletions src/cfnlint/rules/functions/RefFormat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
SPDX-License-Identifier: MIT-0
"""

from __future__ import annotations

from copy import deepcopy
from typing import Any

from cfnlint.jsonschema import ValidationResult, Validator
Expand All @@ -25,6 +28,31 @@ def __init__(self):
super().__init__(["*"])
self.parent_rules = ["E1020"]

def _filter_schema(
self, validator: Validator, type_name: str, id: str, schema: dict[str, Any]
) -> dict[str, Any]:
if type_name != "AWS::EC2::SecurityGroup":
return schema

items = list(
validator.cfn.get_cfn_path(
["Resources", id, "Properties", "VpcId"], validator.context
)
)
if items:
# VpcId is specified and will have a value which means the returned value is
# "AWS::EC2::SecurityGroup.Id"
schema = deepcopy(schema)
schema.pop("anyOf")
schema["format"] = "AWS::EC2::SecurityGroup.Id"
return schema

# VpcId being None means it wasn't specified and the value is a Name
schema = deepcopy(schema)
schema.pop("anyOf")
schema["format"] = "AWS::EC2::SecurityGroup.Name"
return schema

def validate(
self, validator: Validator, _, instance: Any, schema: Any
) -> ValidationResult:
Expand All @@ -44,6 +72,8 @@ def validate(

ref_schema = validator.context.resources[instance].ref(region)

ref_schema = self._filter_schema(validator, t, instance, ref_schema)

err = compare_schemas(schema, ref_schema)
if err:
if err.instance:
Expand Down
80 changes: 79 additions & 1 deletion src/cfnlint/template/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import cfnlint.conditions
import cfnlint.helpers
from cfnlint._typing import CheckValueFn, Path
from cfnlint.context import create_context_for_template
from cfnlint.context import Context, create_context_for_template
from cfnlint.context.conditions.exceptions import Unsatisfiable
from cfnlint.decode.node import dict_node, list_node
from cfnlint.graph import Graph
from cfnlint.match import Match
Expand Down Expand Up @@ -431,6 +432,83 @@ def search_deep_keys(
results.append(["Globals"] + pre_result)
return results

def get_cfn_path(
self, path: list[str], context: Context
) -> Iterator[tuple[Any, Context]]:
"""
Get the value at the specified path in the CloudFormation template.
Args:
path (list[str]): The path to the value in the template.
context (Context): The context object containing the template and other data.
Returns:
Any: The value at the specified path in the template.
"""

def _filter_condition(
template: Any, context: Context
) -> Iterator[tuple[Any, Context]]:
k, v = cfnlint.helpers.is_function(template)
if k is None:
yield template, context
return

if k == "Fn::If":
if isinstance(v, list) and len(v) == 3:
condition = v[0]
if not isinstance(condition, str):
return

for i in [1, 2]:
b = True if i == 1 else False
try:
item_context = context.evolve(
conditions=context.conditions.evolve({condition: b})
)
yield from _filter_condition(v[i], item_context)
except Unsatisfiable:
continue
return
if k == "Ref":
if v == "AWS::NoValue":
return
yield template, context

def _get_cfn_path(
path: list[str], template: Any, context: Context
) -> Iterator[tuple[Any, Context]]:
if len(path) == 0:
yield from _filter_condition(template, context)
return
item = path[0]
if isinstance(template, dict):
if item in template:
for item_template, item_context in _filter_condition(
template[item], context
):
yield from _get_cfn_path(path[1:], item_template, item_context)
return
elif isinstance(template, list):
if isinstance(template, list):
if item == "*":
for index, _ in enumerate(template):
yield from _get_cfn_path(path[1:], template[index], context)
return

# handle resource and output conditions
if len(path) >= 3 and path[0] in ["Resources", "Outputs"]:
condition = self.template.get(path[0], {}).get(path[1], {}).get("Condition")
if condition:
try:
context = context.evolve(
conditions=context.conditions.evolve({condition: True})
)
except Unsatisfiable:
return

yield from _get_cfn_path(path, self.template, context)

def get_condition_values(self, template, path: Path | None) -> list[dict[str, Any]]:
"""
Evaluates conditions in the provided CloudFormation template and returns the values.
Expand Down
127 changes: 127 additions & 0 deletions test/unit/module/template/test_template_get_cfn_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""

import pytest

from cfnlint.template import Template


@pytest.fixture
def cfn() -> Template:

return Template(
None,
{
"Parameters": {
"Environment": {
"Type": "String",
"AllowedValues": ["dev", "test", "stage", "prod"],
}
},
"Conditions": {
"IsUsEast1": {"Fn::Equals": [{"Ref": "AWS::Region"}, "us-east-1"]},
"IsProduction": {"Fn::Equals": [{"Ref": "Environment"}, "prod"]},
},
"Resources": {
"Bucket": {"Type": "AWS::S3::Bucket", "Condition": "IsUsEast1"},
"Vpc": {
"Type": "AWS::EC2::VPC",
"Properties": {"CidrBlock": "10.0.0.0/16"},
},
"SecurityGroup1": {
"Type": "AWS::EC2::SecurityGroup",
"Condition": "IsProduction",
"Properties": {
"CidrBlock": "10.0.0.0/24",
"VpcId": {
"Fn::If": ["IsUsEast1", "vpc-123", {"Ref": "AWS::NoValue"}]
},
},
},
"SecurityGroup2": {
"Type": "AWS::EC2::SecurityGroup",
"Condition": "IsProduction",
"Properties": {
"CidrBlock": "10.0.1.0/24",
"VpcId": {
"Fn::If": ["IsProduction", {"Ref": "Vpc"}, "vpc-abc"]
},
},
},
},
},
)


@pytest.mark.parametrize(
"name,path,starting_conditions, expected",
[
(
"Valid get path",
["Resources", "Vpc", "Properties", "CidrBlock"],
{},
[("10.0.0.0/16", {})],
),
(
"Invalid path returns nothing",
["Resources", "Vpc", "Properties", "DNE"],
{},
[],
),
(
"Short path that doesn't exist",
["Resources", "DNE"],
{},
[],
),
(
"Valid path with resource condition",
["Resources", "SecurityGroup1", "Properties", "CidrBlock"],
{},
[("10.0.0.0/24", {"IsProduction": True})],
),
(
"Valid path with resource condition",
["Resources", "SecurityGroup1", "Properties", "CidrBlock"],
{"IsProduction": False},
[],
),
(
"Valid path with resource condition with multiple conditions",
["Resources", "SecurityGroup1", "Properties", "VpcId"],
{},
[("vpc-123", {"IsProduction": True, "IsUsEast1": True})],
),
(
(
"Valid path with resource condition with multiple "
"conditions that conflict with each other"
),
["Resources", "SecurityGroup2", "Properties", "VpcId"],
{},
[({"Ref": "Vpc"}, {"IsProduction": True})],
),
],
)
def test_paths(name, path, starting_conditions, expected, cfn):

context = cfn.context.evolve(
conditions=cfn.context.conditions.evolve(starting_conditions)
)

results = list(cfn.get_cfn_path(path, context))

assert len(results) == len(
expected
), f"{name!r} test failed. Got results {results!r}"

for i, value in enumerate(results):
assert (
value[0] == expected[i][0]
), f"{name!r} test failed for {i}. Got value {value[0]!r}"
assert value[1].conditions.status == expected[i][1], (
f"{name!r} test failed for {i}. Got "
f"conditions {value[1].conditions.status[0]!r}"
)
4 changes: 2 additions & 2 deletions test/unit/rules/formats/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def format(self, validator, instance):
},
[
ValidationError(
"'10.10.10.10' is not a 'test'",
"'10.10.10.10' is not a 'test' with pattern ''",
rule=_Fail(),
)
],
Expand All @@ -116,7 +116,7 @@ def format(self, validator, instance):
},
[
ValidationError(
"'10.10.10.10' is not a 'test'",
"'10.10.10.10' is not a 'test' with pattern ''",
rule=_Fail(),
)
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,6 @@ def rule():
def test_validate(name, instance, expected, rule, validator):
errs = list(rule.validate(validator, "", instance, {}))

for err in errs:
print(err.validator)
print(err.path)
print(err.schema_path)

assert (
errs == expected
), f"Expected test {name!r} to have {expected!r} but got {errs!r}"

0 comments on commit e04e921

Please sign in to comment.