From 53f7ff926ee3fe5b9ed5b9fbaa4cef4370ab216b Mon Sep 17 00:00:00 2001
From: "m.kindritskiy" <m.kindritskiy@smartweb.com.ua>
Date: Thu, 8 Aug 2024 16:58:36 +0300
Subject: [PATCH] rewrite GrapQLError

---
 docs/changelog/changes_08.rst                 |  3 ++
 hiku/endpoint/graphql.py                      |  8 ++---
 hiku/engine.py                                |  5 ++-
 hiku/extensions/base_extension.py             |  2 +-
 hiku/schema.py                                | 32 ++++++++++++-------
 .../extensions/test_query_depth_validator.py  |  2 +-
 tests/test_interface.py                       | 14 ++++----
 tests/test_union.py                           |  2 +-
 8 files changed, 42 insertions(+), 26 deletions(-)

diff --git a/docs/changelog/changes_08.rst b/docs/changelog/changes_08.rst
index a919fe9..5bf81f4 100644
--- a/docs/changelog/changes_08.rst
+++ b/docs/changelog/changes_08.rst
@@ -20,6 +20,8 @@ Changes in 0.8
   - Change `GraphQLResponse` type - it now has both `data` and `errors` fields
   - Rename `on_dispatch` hook to `on_operation`
   - Remove old `on_operation` hook
+  - Remove `execute` method from `BaseGraphQLEndpoint` class
+  - Add `process_result` method to `BaseGraphQLEndpoint` class
 
 Backward-incompatible changes
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -29,3 +31,4 @@ Backward-incompatible changes
   - Drop `hiku.federation.endpoint` - use `hiku.endpoint` instead
   - Drop `hiku.federation.denormalize`
   - Drop `hiku.federation.engine` - use `hiku.engine` instead
+  - Remove `execute` method from `BaseGraphQLEndpoint` class
diff --git a/hiku/endpoint/graphql.py b/hiku/endpoint/graphql.py
index baba31b..390e3f2 100644
--- a/hiku/endpoint/graphql.py
+++ b/hiku/endpoint/graphql.py
@@ -20,7 +20,7 @@ class GraphQLRequest(TypedDict, total=False):
 
 class GraphQLResponse(TypedDict, total=False):
     data: Optional[Dict[str, object]]
-    errors: Optional[List[object]]
+    errors: Optional[List[GraphQLErrorObject]]
     extensions: Optional[Dict[str, object]]
 
 
@@ -47,8 +47,8 @@ def __init__(
     def process_result(self, result: ExecutionResult) -> GraphQLResponse:
         data: GraphQLResponse = {"data": result.data}
 
-        if result.error:
-            data["errors"] = [{"message": e} for e in result.error.errors]
+        if result.errors:
+            data["errors"] = [{"message": e.message} for e in result.errors]
 
         return data
 
@@ -107,7 +107,7 @@ def dispatch(
     ) -> SingleOrBatchedResponse:
         if isinstance(data, list):
             if not self.batching:
-                raise GraphQLError(errors=["Batching is not supported"])
+                raise GraphQLError("Batching is not supported")
 
             return [
                 super(GraphQLEndpoint, self).dispatch(item, context)
diff --git a/hiku/engine.py b/hiku/engine.py
index 86879f5..8498f16 100644
--- a/hiku/engine.py
+++ b/hiku/engine.py
@@ -996,7 +996,10 @@ def __getitem__(self, item: Any) -> Any:
             )
 
 
-_ExecutorType = TypeVar("_ExecutorType", bound=SyncAsyncExecutor)
+# Contrvariant must be used because we want to accept subclasses of Executor
+_ExecutorType = TypeVar(
+    "_ExecutorType", covariant=True, bound=SyncAsyncExecutor
+)
 
 
 class Engine(Generic[_ExecutorType]):
diff --git a/hiku/extensions/base_extension.py b/hiku/extensions/base_extension.py
index e4a0f04..5f18952 100644
--- a/hiku/extensions/base_extension.py
+++ b/hiku/extensions/base_extension.py
@@ -51,7 +51,7 @@ def on_operation(
 
         At this step the:
         - execution_context.query_src (if type str) is set to the query string
-        - execution_context.query (if type Noe) is set to the query Node
+        - execution_context.query (if type Node) is set to the query Node
         - execution_context.variables is set to the query variables
         - execution_context.operation_name is set to the query operation name
         - execution_context.query_graph is set to the query graph
diff --git a/hiku/schema.py b/hiku/schema.py
index 22115e0..f81c304 100644
--- a/hiku/schema.py
+++ b/hiku/schema.py
@@ -36,7 +36,13 @@
 
 
 class GraphQLError(Exception):
-    def __init__(self, *, errors: List[str]):
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+        self.message = message
+
+
+class ValidationError(Exception):
+    def __init__(self, errors: List[str]) -> None:
         super().__init__("{} errors".format(len(errors)))
         self.errors = errors
 
@@ -56,7 +62,7 @@ def _run_validation(
 @dataclass
 class ExecutionResult:
     data: Optional[Dict[str, Any]]
-    error: Optional[GraphQLError]
+    errors: Optional[List[GraphQLError]]
 
 
 class Schema(Generic[_ExecutorType]):
@@ -158,8 +164,12 @@ def execute_sync(
                 ).process(execution_context.query)
 
             return ExecutionResult(data, None)
+        except ValidationError as e:
+            return ExecutionResult(
+                None, [GraphQLError(message) for message in e.errors]
+            )
         except GraphQLError as e:
-            return ExecutionResult(None, e)
+            return ExecutionResult(None, [e])
 
     async def execute(
         self: "Schema[BaseAsyncExecutor]",
@@ -210,8 +220,12 @@ async def execute(
                 ).process(execution_context.query)
 
             return ExecutionResult(data, None)
+        except ValidationError as e:
+            return ExecutionResult(
+                None, [GraphQLError(message) for message in e.errors]
+            )
         except GraphQLError as e:
-            return ExecutionResult(None, e)
+            return ExecutionResult(None, [e])
 
     def _validate(
         self,
@@ -249,11 +263,7 @@ def _init_execution_context(
                         execution_context.request_operation_name,
                     )
                 except TypeError as e:
-                    raise GraphQLError(
-                        errors=[
-                            "Failed to read query: {}".format(e),
-                        ]
-                    )
+                    raise GraphQLError("Failed to read query: {}".format(e))
 
             execution_context.query = execution_context.operation.query
             # save original query before merging to validate it
@@ -266,7 +276,7 @@ def _init_execution_context(
         op = execution_context.operation
         if op.type not in (OperationType.QUERY, OperationType.MUTATION):
             raise GraphQLError(
-                errors=["Unsupported operation type: {!r}".format(op.type)]
+                "Unsupported operation type: {!r}".format(op.type)
             )
 
         with extensions_manager.validation():
@@ -278,4 +288,4 @@ def _init_execution_context(
                 )
 
             if execution_context.errors:
-                raise GraphQLError(errors=execution_context.errors)
+                raise ValidationError(errors=execution_context.errors)
diff --git a/tests/extensions/test_query_depth_validator.py b/tests/extensions/test_query_depth_validator.py
index 3369bfc..ca149d3 100644
--- a/tests/extensions/test_query_depth_validator.py
+++ b/tests/extensions/test_query_depth_validator.py
@@ -51,4 +51,4 @@ def test_query_depth_validator(sync_graph):
     """
 
     result = schema.execute_sync(query)
-    assert result.error.errors == ["Query depth 4 exceeds maximum allowed depth 2"]
+    assert [e.message for e in result.errors] == ["Query depth 4 exceeds maximum allowed depth 2"]
diff --git a/tests/test_interface.py b/tests/test_interface.py
index 0cb367c..827b88b 100644
--- a/tests/test_interface.py
+++ b/tests/test_interface.py
@@ -179,7 +179,7 @@ def test_option_not_provided_for_field():
     """
 
     result = execute(GRAPH, read(query))
-    assert result.error.errors == [
+    assert [e.message for e in result.errors] == [
       'Required option "Video.thumbnailUrl:size" is not specified'
     ]
 
@@ -423,7 +423,7 @@ def test_validate_interface_has_no_implementations():
 
     result = execute(graph, read(query))
 
-    assert result.error.errors == [
+    assert [e.message for e in result.errors] == [
         "Can not query field 'id' on interface 'Media'. "
         "Interface 'Media' is not implemented by any type. "
         "Add at least one type implementing this interface.",
@@ -447,7 +447,7 @@ def test_validate_query_implementation_node_field_without_inline_fragment():
 
     result = execute(GRAPH, read(query))
 
-    assert result.error.errors == [
+    assert [e.message for e in result.errors] == [
         "Can not query field 'album' on type 'Media'. "
         "Did you mean to use an inline fragment on 'Audio'?"
     ]
@@ -465,7 +465,7 @@ def test_validate_query_fragment_no_type_condition():
 
     result = execute(GRAPH, read(query, {'text': 'foo'}))
 
-    assert result.error.errors == [
+    assert [e.message for e in result.errors] == [
       "Can not query field 'album' on type 'Media'. "
       "Did you mean to use an inline fragment on 'Audio'?"
     ]
@@ -484,7 +484,7 @@ def test_validate_query_fragment_on_unknown_type():
 
     result = execute(GRAPH, read(query, {'text': 'foo'}))
 
-    assert result.error.errors == ["Fragment on unknown type 'X'"]
+    assert [e.message for e in result.errors] == ["Fragment on unknown type 'X'"]
 
 
 def test_validate_interface_type_has_no_such_field():
@@ -504,7 +504,7 @@ def test_validate_interface_type_has_no_such_field():
 
     result = execute(GRAPH, read(query, {'text': 'foo'}))
 
-    assert result.error.errors == [
+    assert [e.message for e in result.errors] == [
         'Field "invalid_field" is not implemented in the "Audio" node',
     ]
 
@@ -525,6 +525,6 @@ def test_validate_interface_type_field_has_no_such_option():
 
     result = execute(GRAPH, read(query, {'text': 'foo'}))
 
-    assert result.error.errors == [
+    assert [e.message for e in result.errors] == [
         'Unknown options for "Audio.duration": size',
     ]
diff --git a/tests/test_union.py b/tests/test_union.py
index f01ace2..faa4d07 100644
--- a/tests/test_union.py
+++ b/tests/test_union.py
@@ -168,7 +168,7 @@ def test_option_not_provided_for_field():
     }
     """
     result = execute(read(query))
-    result.error.errors == [
+    result.errors == [
       "Required option \"size\" for Field('thumbnailUrl'"
     ]