From 6f9684d2513872c46b2dc7eb764e6f2973715372 Mon Sep 17 00:00:00 2001 From: Philipp Temminghoff Date: Sat, 22 Feb 2025 07:45:14 +0100 Subject: [PATCH] chore: prep for create_pull_request functionality --- src/githarbor/core/base.py | 22 +++++++++++++ src/githarbor/core/proxy.py | 66 +++++++++++++++++++++++++++++++++++++ src/githarbor/functional.py | 32 ++++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/src/githarbor/core/base.py b/src/githarbor/core/base.py index 7ee4630..c91b5b2 100644 --- a/src/githarbor/core/base.py +++ b/src/githarbor/core/base.py @@ -201,6 +201,17 @@ def list_tags(self) -> list[Tag]: msg = f"{self.__class__.__name__} does not implement list_tags" raise FeatureNotSupportedError(msg) + def create_pull_request( + self, + title: str, + body: str, + head_branch: str, + base_branch: str, + draft: bool = False, + ) -> PullRequest: + msg = f"{self.__class__.__name__} does not implement create_pull_request" + raise FeatureNotSupportedError(msg) + async def get_repo_user_async(self) -> User: """Get repository owner information asynchronously.""" msg = f"{self.__class__.__name__} does not implement get_repo_user_async" @@ -348,6 +359,17 @@ async def list_tags_async(self) -> list[Tag]: msg = f"{self.__class__.__name__} does not implement list_tags_async" raise FeatureNotSupportedError(msg) + async def create_pull_request_async( + self, + title: str, + body: str, + head_branch: str, + base_branch: str, + draft: bool = False, + ) -> PullRequest: + msg = f"{self.__class__.__name__} does not implement create_pull_request_async" + raise FeatureNotSupportedError(msg) + class BaseOwner: """Base class for repository owners.""" diff --git a/src/githarbor/core/proxy.py b/src/githarbor/core/proxy.py index 030cffc..8992424 100644 --- a/src/githarbor/core/proxy.py +++ b/src/githarbor/core/proxy.py @@ -537,6 +537,44 @@ def get_tag(self, name: str) -> Tag: return asyncio.run(self._repository.get_tag_async(name)) return self._repository.get_tag(name) + def create_pull_request( + self, + title: str, + body: str, + head_branch: str, + base_branch: str, + draft: bool = False, + ) -> PullRequest: + """Create a new pull request. + + Args: + title: Pull request title + body: Pull request description + head_branch: Source branch containing the changes + base_branch: Target branch for the changes + draft: Whether to create a draft pull request + + Returns: + Newly created pull request + """ + if self._repository.is_async: + return asyncio.run( + self._repository.create_pull_request_async( + title=title, + body=body, + head_branch=head_branch, + base_branch=base_branch, + draft=draft, + ) + ) + return self._repository.create_pull_request( + title=title, + body=body, + head_branch=head_branch, + base_branch=base_branch, + draft=draft, + ) + def list_tags(self) -> list[Tag]: """List all tags. @@ -797,6 +835,32 @@ async def get_tag_async(self, name: str) -> Tag: return await self._repository.get_tag_async(name) # type: ignore return await asyncio.to_thread(self._repository.get_tag, name) + async def create_pull_request_async( + self, + title: str, + body: str, + head_branch: str, + base_branch: str, + draft: bool = False, + ) -> PullRequest: + """See create_pull_request.""" + if self._repository.is_async: + return await self._repository.create_pull_request_async( + title=title, + body=body, + head_branch=head_branch, + base_branch=base_branch, + draft=draft, + ) + return await asyncio.to_thread( + self._repository.create_pull_request, + title=title, + body=body, + head_branch=head_branch, + base_branch=base_branch, + draft=draft, + ) + async def list_tags_async(self) -> list[Tag]: """See list_tags.""" if self._repository.is_async: @@ -808,6 +872,7 @@ def get_sync_methods(self) -> list[Callable]: return [ self.get_repo_user, self.get_branch, + self.create_pull_request, self.get_pull_request, self.list_pull_requests, self.get_issue, @@ -834,6 +899,7 @@ def get_async_methods(self) -> list[Callable]: return [ self.get_repo_user_async, self.get_branch_async, + self.create_pull_request_async, self.get_pull_request_async, self.list_pull_requests_async, self.get_issue_async, diff --git a/src/githarbor/functional.py b/src/githarbor/functional.py index 4f74a45..f2d2834 100644 --- a/src/githarbor/functional.py +++ b/src/githarbor/functional.py @@ -390,6 +390,37 @@ async def delete_repository_async(url: str, name: str) -> None: await owner.delete_repository_async(name) +async def create_pull_request_async( + url: str, + title: str, + body: str, + head_branch: str, + base_branch: str, + draft: bool = False, +) -> PullRequest: + """Create a new pull request. + + Args: + url: Repository URL + title: Pull request title + body: Pull request description + head_branch: Source branch containing the changes + base_branch: Target branch for the changes + draft: Whether to create a draft pull request + + Returns: + Newly created pull request + """ + repo = RepoRegistry.get(url) + return await repo.create_pull_request_async( + title=title, + body=body, + head_branch=head_branch, + base_branch=base_branch, + draft=draft, + ) + + get_repo_user = make_sync(get_repo_user_async) get_branch = make_sync(get_branch_async) get_pull_request = make_sync(get_pull_request_async) @@ -411,6 +442,7 @@ async def delete_repository_async(url: str, name: str) -> None: get_release = make_sync(get_release_async) get_tag = make_sync(get_tag_async) list_tags = make_sync(list_tags_async) +create_pull_request = make_sync(create_pull_request_async) list_repositories = make_sync(list_repositories_async) create_repository = make_sync(create_repository_async) get_user = make_sync(get_user_async)