Skip to content

Commit

Permalink
feat: final set of upgrades for tools (#604)
Browse files Browse the repository at this point in the history
# Motivation

<!-- Why is this change necessary? -->

# Content

<!-- Please include a summary of the change -->

# Testing

<!-- How was the change tested? -->

# Please check the following before marking your PR as ready for review

- [ ] I have added tests for my changes
- [ ] I have updated the documentation or added new documentation as
needed

---------

Co-authored-by: kopekC <[email protected]>
  • Loading branch information
kopekC and kopekC authored Feb 21, 2025
1 parent 9590dce commit 9132f13
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
7 changes: 6 additions & 1 deletion src/codegen/extensions/tools/github/view_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class ViewPRObservation(Observation):
file_commit_sha: dict[str, str] = Field(
description="Commit SHAs for each file in the PR",
)
modified_symbols: list[str] = Field(
description="Names of modified symbols in the PR",
)

str_template: ClassVar[str] = "PR #{pr_id}"

Expand All @@ -33,13 +36,14 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation:
pr_id: Number of the PR to get the contents for
"""
try:
patch, file_commit_sha = codebase.get_modified_symbols_in_pr(pr_id)
patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id)

return ViewPRObservation(
status="success",
pr_id=pr_id,
patch=patch,
file_commit_sha=file_commit_sha,
modified_symbols=moddified_symbols,
)

except Exception as e:
Expand All @@ -49,4 +53,5 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation:
pr_id=pr_id,
patch="",
file_commit_sha={},
modified_symbols=[],
)
7 changes: 4 additions & 3 deletions src/codegen/git/utils/pr_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from codegen.git.repo_operator.repo_operator import RepoOperator

if TYPE_CHECKING:
from codegen.sdk.core.codebase import Codebase, Editable, File, Symbol
from codegen.sdk.core.codebase import Codebase, Editable, File


def get_merge_base(git_repo_client: Repository, pull: PullRequest | PullRequestContext) -> str:
Expand Down Expand Up @@ -150,7 +150,7 @@ def is_modified(self, editable: "Editable") -> bool:
return False

@property
def modified_symbols(self) -> list["Symbol"]:
def modified_symbols(self) -> list[str]:
# Import SourceFile locally to avoid circular dependencies
from codegen.sdk.core.file import SourceFile

Expand All @@ -163,7 +163,8 @@ def modified_symbols(self) -> list["Symbol"]:
continue
for symbol in file.symbols:
if self.is_modified(symbol):
all_modified.append(symbol)
all_modified.append(symbol.name)

return all_modified

def get_pr_diff(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,13 +1311,13 @@ def from_repo(
logger.exception(f"Failed to initialize codebase: {e}")
raise

def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str]]:
def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str]]:
"""Get all modified symbols in a pull request"""
pr = self._op.get_pull_request(pr_id)
cg_pr = CodegenPR(self._op, self, pr)
patch = cg_pr.get_pr_diff()
commit_sha = cg_pr.get_file_commit_shas()
return patch, commit_sha
return patch, commit_sha, cg_pr.modified_symbols

def create_pr_comment(self, pr_number: int, body: str) -> None:
"""Create a comment on a pull request"""
Expand Down

0 comments on commit 9132f13

Please sign in to comment.