Skip to content

Commit

Permalink
fix(rust): validate config on client construction (#748)
Browse files Browse the repository at this point in the history
Description of changes:
In the Rust codegen, constraint validation is performed on operation calls instead of within structure builders' build() function. (The reasoning for this choice is documented in the description of #582.) But the validation wasn't applied to the constructors of localService clients, so attempting to construct a client with an invalid config could panic during conversion (in particular, when a @required field was missing).

This PR implements the missing validation in client constructors, and tests that both valid and invalid configs have the expected behavior when passed to the client constructor.
  • Loading branch information
alex-chew authored Dec 13, 2024
1 parent 52427df commit 5b66889
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 32 deletions.
19 changes: 18 additions & 1 deletion TestModels/Constraints/Model/Constraints.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ service SimpleConstraints {
errors: [ SimpleConstraintsException ],
}

structure SimpleConstraintsConfig {}
structure SimpleConstraintsConfig {
@required
RequiredString: String,
}

// This is just a sanity check on the smokeTests support.
// We need to actually convert all the tests in test/WrappedSimpleConstraintsImplTest.dfy
Expand All @@ -24,6 +27,9 @@ structure SimpleConstraintsConfig {}
@smithy.test#smokeTests([
{
id: "GetConstraintsSuccess"
vendorParams: {
RequiredString: "foobar",
}
params: {
OneToTen: 5,
GreaterThanOne: 2,
Expand All @@ -40,6 +46,9 @@ structure SimpleConstraintsConfig {}
},
{
id: "GetConstraintsFailure"
vendorParams: {
RequiredString: "foobar",
}
params: {
// These two always have to be present because of https://github.com/smithy-lang/smithy-dafny/issues/278,
// because otherwise they are interpreted as 0.
Expand All @@ -51,6 +60,14 @@ structure SimpleConstraintsConfig {}
expect: {
failure: {}
}
},
{
id: "GetConstraintsInvalidConfig"
params: {}
expect: {
failure: {}
},
tags: ["INVALID_CONFIG"]
}
])
operation GetConstraints {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,26 @@ mod simple_constraints_test {
use simple_constraints::*;

fn client() -> Client {
let config = SimpleConstraintsConfig::builder().build().expect("config");
let config = SimpleConstraintsConfig::builder()
.required_string("test string")
.build()
.expect("config");
client::Client::from_conf(config).expect("client")
}

#[test]
fn test_config_missing_field() {
let config = SimpleConstraintsConfig::builder()
.build()
.expect("config");
let error = client::Client::from_conf(config).err().expect("err");
assert!(matches!(
error,
simple_constraints::types::error::Error::ValidationError(..)
));
assert!(error.to_string().contains("required_string"));
}

#[tokio::test]
async fn test_empty_input() {
let result = client().get_constraints().send().await;
Expand Down
2 changes: 1 addition & 1 deletion TestModels/Constraints/src/Index.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module {:extern "simple.constraints.internaldafny" } SimpleConstraints refines A
import Operations = SimpleConstraintsImpl

function method DefaultSimpleConstraintsConfig(): SimpleConstraintsConfig {
SimpleConstraintsConfig
SimpleConstraintsConfig(RequiredString := "default")
}

method SimpleConstraints(config: SimpleConstraintsConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ include "../Model/SimpleConstraintsTypesWrapped.dfy"
module {:extern "simple.constraints.internaldafny.wrapped"} WrappedSimpleConstraintsService refines WrappedAbstractSimpleConstraintsService {
import WrappedService = SimpleConstraints
function method WrappedDefaultSimpleConstraintsConfig(): SimpleConstraintsConfig {
SimpleConstraintsConfig
SimpleConstraintsConfig(RequiredString := "default")
}
}
4 changes: 2 additions & 2 deletions TestModels/dafny-dependencies/StandardLibrary/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ transpile_implementation:
# Override SharedMakefile's build_java to not install
# StandardLibrary as a dependency
build_java: transpile_java
gradle -p runtimes/java build
$(GRADLEW) -p runtimes/java build

# Override SharedMakefile's mvn_local_deploy to
# issue warning
mvn_local_deploy:
@echo "${RED}Warning!!${YELLOW} Installing TestModel's STD to Maven Local replaces ESDK's STD!\n$(RESET)" >&2
gradle -p runtimes/java publishToMavenLocal
$(GRADLEW) -p runtimes/java publishToMavenLocal

dafny_benerate: DAFNY_RUST_OUTPUT=\
runtimes/rust/implementation_from_dafny
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import software.amazon.polymorph.smithyjava.generator.Generator;
import software.amazon.polymorph.smithyjava.modeled.ModeledShapeValue;
import software.amazon.polymorph.traits.LocalServiceTrait;
import software.amazon.smithy.model.node.ObjectNode;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
Expand All @@ -29,9 +30,32 @@
* e.g. generating JUnit tests from traits like @smithy.test#smokeTests.
* This is distinct from the Dafny testing code
* and the testing wrapper to support it generated by TestJavaLibrary.
*
* <h1>Smoke test generation</h1>
*
* If the {@code vendorParams} property is present,
* the generated test constructs the client config with its values.
*
* <h2>Assertions</h2>
*
* If {@code "INVALID_CONFIG"} is present in the test case's {@code tags},
* the generated test asserts that an exception is thrown
* during either client config construction or client construction,
* and the test neither constructs the operation input nor performs the operation call.
* <p>
* Otherwise, if the {@code expect.success} property is present,
* the generated test asserts that no exception is thrown
* during the operation call and all setup steps
* (construction of client config, client, and operation input).
* <p>
* Otherwise, if the {@code expect.failure} property is present,
* the generated test asserts that an exception is thrown
* during either operation input construction or the operation call.
*/
public class ModelTestCodegen extends Generator {

static final String INVALID_CONFIG_TAG = "INVALID_CONFIG";

final JavaLibrary subject;

public ModelTestCodegen(JavaLibrary subject) {
Expand Down Expand Up @@ -86,29 +110,42 @@ private MethodSpec smokeTest(
configShapeId
);

// SimpleConstraintsConfig config = SimpleConstraintsConfig.builder().build();
// SimpleConstraintsConfig config = SimpleConstraintsConfig.builder()
// ...
// (multiple .foo(...) calls to populate builder)
// ...
// .build();
// SimpleConstraints client = SimpleConstraints.builder()
// .SimpleConstraintsConfig(config)
// .build();
method.addStatement(
"$T config = $T.builder().build()",
configType,
configType
);
method.addStatement(
"$T client = $T.builder().$N(config).build()",
clientType,
clientType,
configShapeId.getName()
// .SimpleConstraintsConfig(config)
// .build();
final ObjectNode configParams = testCase
.getVendorParams()
.orElseGet(() -> ObjectNode.builder().build());
final CodeBlock configValue = ModeledShapeValue.shapeValue(
subject,
false,
subject.model.expectShape(configShapeId),
configParams
);
final CodeBlock configAndClientConstruction = CodeBlock
.builder()
.addStatement("$T config = $L", configType, configValue)
.addStatement(
"$T client = $T.builder().$N(config).build()",
clientType,
clientType,
configShapeId.getName()
)
.build();

// GetConstraintsInput input = GetConstraintsInput.builder()
// ...
// (multiple .foo(...) calls to populate builder)
// ...
// .build();
// ...
// (multiple .foo(...) calls to populate builder)
// ...
// .build();
// client.GetConstraints(input);
final Shape inputShape = subject.model.expectShape(
operationShape.getInput().get()
operationShape.getInput().orElseThrow()
);
final TypeName inputType = subject.nativeNameResolver.typeForShape(
inputShape.getId()
Expand All @@ -117,17 +154,25 @@ private MethodSpec smokeTest(
subject,
false,
inputShape,
testCase.getParams().get()
testCase.getParams().orElseThrow()
);
final CodeBlock inputAndClientCall = CodeBlock
.builder()
.addStatement("$T input = $L", inputType, inputValue)
.addStatement("client.$L(input)", operationName)
.build();

if (testCase.getExpectation().isSuccess()) {
if (testCase.hasTag(INVALID_CONFIG_TAG)) {
method.addStatement(
"$T.assertThrows(Exception.class, () -> {\n$L\n})",
TESTNG_ASSERT,
configAndClientConstruction.toString()
);
} else if (testCase.getExpectation().isSuccess()) {
method.addCode(configAndClientConstruction);
method.addCode(inputAndClientCall);
} else {
method.addCode(configAndClientConstruction);
// We're not specific about what kind of exception for now.
// If the smokeTests trait gets more specific we can be too.
// The inputAndClientCall.toString() is necessary because otherwise we get nested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,18 @@ private RustFile clientModule() {
)
.collect(Collectors.joining("\n\n"))
);

final StructureShape configShape = ModelUtils.getConfigShape(
model,
service
);
variables.put(
"inputValidations",
new InputValidationGenerator()
.generateValidations(model, configShape)
.collect(Collectors.joining("\n"))
);

final String content = evalTemplateResource(
getClass(),
"runtimes/rust/client.rs",
Expand Down Expand Up @@ -853,6 +865,9 @@ class InputValidationGenerator

private final Map<String, String> commonVariables;

/**
* Generates validation expressions for operation input structures.
*/
InputValidationGenerator(
final Shape bindingShape,
final OperationShape operationShape
Expand All @@ -862,6 +877,21 @@ class InputValidationGenerator
serviceVariables(),
operationVariables(bindingShape, operationShape)
);
this.commonVariables.put(
"inputStructureName",
commonVariables.get("pascalCaseOperationInputName")
);
}

/**
* Generates validation expressions for this service's client config structure.
*/
InputValidationGenerator() {
this.commonVariables = serviceVariables();
this.commonVariables.put(
"inputStructureName",
commonVariables.get("qualifiedRustConfigName")
);
}

@Override
Expand All @@ -871,7 +901,7 @@ protected String validateRequired(final MemberShape memberShape) {
if input.$fieldName:L.is_none() {
return ::std::result::Result::Err(::aws_smithy_types::error::operation::BuildError::missing_field(
"$fieldName:L",
"$fieldName:L was not specified but it is required when building $pascalCaseOperationInputName:L",
"$fieldName:L was not specified but it is required when building $inputStructureName:L",
)).map_err($qualifiedRustServiceErrorType:L::wrap_validation_err);
}
""",
Expand Down Expand Up @@ -1815,8 +1845,19 @@ protected HashMap<String, String> serviceVariables() {
service
);
final String configName = configShape.getId().getName(service);
final String snakeCaseConfigName = toSnakeCase(configName);

variables.put("configName", configName);
variables.put("snakeCaseConfigName", toSnakeCase(configName));
variables.put("snakeCaseConfigName", snakeCaseConfigName);
variables.put(
"qualifiedRustConfigName",
String.join(
"::",
getRustTypesModuleName(),
snakeCaseConfigName,
configName
)
);
variables.put("rustErrorModuleName", rustErrorModuleName());
variables.put(
"qualifiedRustServiceErrorType",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ public Stream<V> generateValidations(
return validateLength(memberShape, lengthTrait);
}
throw new UnsupportedOperationException(
"Unsupported constraint trait %s on shape %s".formatted(trait)
"Unsupported constraint trait %s on shape %s".formatted(
trait,
structureShape.getId()
)
);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ impl Client {
/// Creates a new client from the service [`Config`](crate::Config).
#[track_caller]
pub fn from_conf(
conf: $rustTypesModuleName:L::$snakeCaseConfigName:L::$configName:L,
) -> Result<Self, $rustTypesModuleName:L::error::Error> {
input: $qualifiedRustConfigName:L,
) -> Result<Self, $qualifiedRustServiceErrorType:L> {
$inputValidations:L
let inner =
crate::$dafnyInternalModuleName:L::_default::$sdkId:L(
&$rustConversionsModuleName:L::$snakeCaseConfigName:L::_$snakeCaseConfigName:L::to_dafny(conf),
&$rustConversionsModuleName:L::$snakeCaseConfigName:L::_$snakeCaseConfigName:L::to_dafny(input),
);
if matches!(
inner.as_ref(),
Expand Down

0 comments on commit 5b66889

Please sign in to comment.