diff --git a/hiku/schema.py b/hiku/schema.py index 23bd8cf..d179c52 100644 --- a/hiku/schema.py +++ b/hiku/schema.py @@ -12,6 +12,7 @@ Generic, ) +from hiku.result import Proxy from hiku.cache import CacheSettings from hiku.context import ( ExecutionContext, @@ -58,6 +59,7 @@ def _run_validation( class ExecutionResult: data: Optional[Dict[str, Any]] errors: Optional[List[GraphQLError]] + result: Optional[Proxy] class Schema(Generic[_ExecutorType]): @@ -158,13 +160,13 @@ def execute_sync( execution_context.operation_type_name, ).process(execution_context.query) - return ExecutionResult(data, None) + return ExecutionResult(data, None, result) except ValidationError as e: return ExecutionResult( - None, [GraphQLError(message) for message in e.errors] + None, [GraphQLError(message) for message in e.errors], None ) except GraphQLError as e: - return ExecutionResult(None, [e]) + return ExecutionResult(None, [e], None) async def execute( self: "Schema[BaseAsyncExecutor]", @@ -213,13 +215,13 @@ async def execute( execution_context.operation_type_name, ).process(execution_context.query) - return ExecutionResult(data, None) + return ExecutionResult(data, None, result) except ValidationError as e: return ExecutionResult( - None, [GraphQLError(message) for message in e.errors] + None, [GraphQLError(message) for message in e.errors], None ) except GraphQLError as e: - return ExecutionResult(None, [e]) + return ExecutionResult(None, [e], None) def _validate( self, diff --git a/tests/test_endpoint_graphql.py b/tests/test_endpoint_graphql.py index 47e37c3..b507621 100644 --- a/tests/test_endpoint_graphql.py +++ b/tests/test_endpoint_graphql.py @@ -32,17 +32,13 @@ async def answer(ctx, fields): def test_endpoint(sync_graph): - endpoint = GraphQLEndpoint( - Schema(SyncExecutor(), sync_graph) - ) + endpoint = GraphQLEndpoint(Schema(SyncExecutor(), sync_graph)) result = endpoint.dispatch({"query": "{answer}"}) assert result == {"data": {"answer": "42"}} def test_batch_endpoint(sync_graph): - endpoint = BatchGraphQLEndpoint( - Schema(SyncExecutor(), sync_graph) - ) + endpoint = BatchGraphQLEndpoint(Schema(SyncExecutor(), sync_graph)) assert endpoint.dispatch([]) == [] @@ -63,9 +59,7 @@ def test_batch_endpoint(sync_graph): @pytest.mark.asyncio async def test_async_endpoint(async_graph): - endpoint = AsyncGraphQLEndpoint( - Schema(AsyncIOExecutor(), async_graph) - ) + endpoint = AsyncGraphQLEndpoint(Schema(AsyncIOExecutor(), async_graph)) result = await endpoint.dispatch( {"query": "{answer}"}, context={"default_answer": "52"} )