Skip to content

Commit

Permalink
Enabling +20 users/groups on batch endpoint payload (#3759)
Browse files Browse the repository at this point in the history
  • Loading branch information
migldasilva authored Oct 24, 2023
1 parent 68e8312 commit 6d589c4
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ FEATURES:
ENHANCEMENTS:

BUG FIXES:
* Enabling support for more than 20 users/groups in Workspace API ([#3759](https://github.com/microsoft/AzureTRE/pull/3759 ))

COMPONENTS:

Expand Down
2 changes: 1 addition & 1 deletion api_app/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.15.17"
__version__ = "0.15.18"
25 changes: 21 additions & 4 deletions api_app/services/aad_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import List, Optional
import jwt
import requests
import rsa

from fastapi import Request, HTTPException, status
from msal import ConfidentialClientApplication
Expand All @@ -19,6 +18,10 @@
from api.dependencies.database import get_db_client_from_request
from db.repositories.workspaces import WorkspaceRepository

from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

MICROSOFT_GRAPH_URL = config.MICROSOFT_GRAPH_URL.strip("/")


Expand Down Expand Up @@ -179,9 +182,12 @@ def _get_token_key(self, key_id: str) -> str:
for key in keys['keys']:
n = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['n'])), "big")
e = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['e'])), "big")
pub_key = rsa.PublicKey(n, e)
pub_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend())
# Cache the PEM formatted public key.
AzureADAuthorization._jwt_keys[key['kid']] = pub_key.save_pkcs1()
AzureADAuthorization._jwt_keys[key['kid']] = pub_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1
)

return AzureADAuthorization._jwt_keys[key_id]

Expand Down Expand Up @@ -245,7 +251,18 @@ def _get_user_emails(self, roles_graph_data, msgraph_token):
batch_request_body = self._get_batch_users_by_role_assignments_body(roles_graph_data)
headers = self._get_auth_header(msgraph_token)
headers["Content-type"] = "application/json"
users_graph_data = requests.post(batch_endpoint, json=batch_request_body, headers=headers).json()
max_number_request = 20
requests_from_batch = batch_request_body["requests"]
# We split the original batch request body in sub-lits with at most max_number_request elements
batch_request_body_list = [requests_from_batch[i:i + max_number_request] for i in range(0, len(requests_from_batch), max_number_request)]
users_graph_data = {"responses": []}

# For each sub-list it's required to call the batch endpoint for retrieveing user/group information
for request_body_element in batch_request_body_list:
batch_request_body_tmp = {"requests": request_body_element}
users_graph_data_tmp = requests.post(batch_endpoint, json=batch_request_body_tmp, headers=headers).json()
users_graph_data["responses"] = users_graph_data["responses"] + users_graph_data_tmp["responses"]

return users_graph_data

def _get_user_emails_from_response(self, users_graph_data):
Expand Down
62 changes: 61 additions & 1 deletion api_app/tests_ma/test_services/test_aad_access_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from mock import patch
from mock import call, patch

from models.domain.authentication import User, RoleAssignment
from models.domain.workspace import Workspace, WorkspaceRole
Expand Down Expand Up @@ -554,6 +554,66 @@ def test_get_workspace_role_assignment_details_with_groups_and_users_assigned_re
assert "[email protected]" in role_assignment_details["WorkspaceOwner"]


@patch("services.aad_authentication.AzureADAuthorization._get_auth_header")
@patch("services.aad_authentication.AzureADAuthorization._get_batch_users_by_role_assignments_body")
@patch("requests.post")
def test_get_user_emails_with_batch_of_more_than_20_requests(mock_graph_post, mock_get_batch_users_by_role_assignments_body, mock_headers):
# Arrange
access_service = AzureADAuthorization()
roles_graph_data = [{"id": "role1"}, {"id": "role2"}]
msgraph_token = "token"
batch_endpoint = access_service._get_batch_endpoint()

# mock the response of _get_auth_header
headers = {"Authorization": f"Bearer {msgraph_token}"}
mock_headers.return_value = headers
headers["Content-type"] = "application/json"

# mock the response of the get batch request for 30 users
batch_request_body_first_20 = {
"requests": [
{"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20)
]
}

batch_request_body_last_10 = {
"requests": [
{"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20, 30)
]
}

batch_request_body = {
"requests": [
{"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(30)
]
}

mock_get_batch_users_by_role_assignments_body.return_value = batch_request_body

# Mock the response of the post request
mock_graph_post_response = {"responses": [{"id": "user1"}, {"id": "user2"}]}
mock_graph_post.return_value.json.return_value = mock_graph_post_response

# Act
users_graph_data = access_service._get_user_emails(roles_graph_data, msgraph_token)

# Assert
assert len(users_graph_data["responses"]) == 4
calls = [
call(
f"{batch_endpoint}",
json=batch_request_body_first_20,
headers=headers
),
call(
f"{batch_endpoint}",
json=batch_request_body_last_10,
headers=headers
)
]
mock_graph_post.assert_has_calls(calls, any_order=True)


def get_mock_batch_response(user_principals, group_principals):
response_body = {"responses": []}
for user_principal in user_principals:
Expand Down

0 comments on commit 6d589c4

Please sign in to comment.