Skip to content

Commit

Permalink
feat: CG-10779: Global session management + decoupled auth (#445)
Browse files Browse the repository at this point in the history
  • Loading branch information
caroljung-cg authored Feb 12, 2025
1 parent 091848b commit eb1647b
Show file tree
Hide file tree
Showing 26 changed files with 337 additions and 332 deletions.
8 changes: 6 additions & 2 deletions .codegen/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ github_token = ""
openai_api_key = ""

[repository]
organization_name = "codegen-sh"
repo_name = "codegen-sdk"
repo_path = ""
repo_name = ""
full_name = ""
user_name = ""
user_email = ""
language = ""

[feature_flags.codebase]
debug = false
Expand Down
39 changes: 22 additions & 17 deletions src/codegen/cli/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
LookupOutput,
PRLookupInput,
PRLookupResponse,
PRSchema,
RunCodemodInput,
RunCodemodOutput,
RunOnPRInput,
Expand All @@ -57,9 +56,9 @@ class RestAPI:

_session: ClassVar[requests.Session] = requests.Session()

auth_token: str | None = None
auth_token: str

def __init__(self, auth_token: str | None = None):
def __init__(self, auth_token: str):
self.auth_token = auth_token

def _get_headers(self) -> dict[str, str]:
Expand Down Expand Up @@ -133,11 +132,10 @@ def run(
template_context: Context variables to pass to the codemod
"""
session = CodegenSession()

session = CodegenSession.from_active_session()
base_input = {
"codemod_name": function.name,
"repo_full_name": session.repo_name,
"repo_full_name": session.config.repository.full_name,
"codemod_run_type": run_type,
}

Expand All @@ -158,13 +156,13 @@ def run(
RunCodemodOutput,
)

def get_docs(self) -> dict:
def get_docs(self) -> DocsResponse:
"""Search documentation."""
session = CodegenSession()
session = CodegenSession.from_active_session()
return self._make_request(
"GET",
DOCS_ENDPOINT,
DocsInput(docs_input=DocsInput.BaseDocsInput(repo_full_name=session.repo_name)),
DocsInput(docs_input=DocsInput.BaseDocsInput(repo_full_name=session.config.repository.full_name)),
DocsResponse,
)

Expand All @@ -179,11 +177,12 @@ def ask_expert(self, query: str) -> AskExpertResponse:

def create(self, name: str, query: str) -> CreateResponse:
"""Get AI-generated starter code for a codemod."""
session = CodegenSession()
session = CodegenSession.from_active_session()
language = ProgrammingLanguage(session.config.repository.language)
return self._make_request(
"GET",
CREATE_ENDPOINT,
CreateInput(input=CreateInput.BaseCreateInput(name=name, query=query, language=session.language)),
CreateInput(input=CreateInput.BaseCreateInput(name=name, query=query, language=language)),
CreateResponse,
)

Expand All @@ -197,18 +196,24 @@ def identify(self) -> IdentifyResponse | None:
)

def deploy(
self, codemod_name: str, codemod_source: str, lint_mode: bool = False, lint_user_whitelist: list[str] | None = None, message: str | None = None, arguments_schema: dict | None = None
self,
codemod_name: str,
codemod_source: str,
lint_mode: bool = False,
lint_user_whitelist: list[str] | None = None,
message: str | None = None,
arguments_schema: dict | None = None,
) -> DeployResponse:
"""Deploy a codemod to the Modal backend."""
session = CodegenSession()
session = CodegenSession.from_active_session()
return self._make_request(
"POST",
DEPLOY_ENDPOINT,
DeployInput(
input=DeployInput.BaseDeployInput(
codemod_name=codemod_name,
codemod_source=codemod_source,
repo_full_name=session.repo_name,
repo_full_name=session.config.repository.full_name,
lint_mode=lint_mode,
lint_user_whitelist=lint_user_whitelist or [],
message=message,
Expand All @@ -220,11 +225,11 @@ def deploy(

def lookup(self, codemod_name: str) -> LookupOutput:
"""Look up a codemod by name."""
session = CodegenSession()
session = CodegenSession.from_active_session()
return self._make_request(
"GET",
LOOKUP_ENDPOINT,
LookupInput(input=LookupInput.BaseLookupInput(codemod_name=codemod_name, repo_full_name=session.repo_name)),
LookupInput(input=LookupInput.BaseLookupInput(codemod_name=codemod_name, repo_full_name=session.config.repository.full_name)),
LookupOutput,
)

Expand All @@ -244,7 +249,7 @@ def run_on_pr(self, codemod_name: str, repo_full_name: str, github_pr_number: in
RunOnPRResponse,
)

def lookup_pr(self, repo_full_name: str, github_pr_number: int) -> PRSchema:
def lookup_pr(self, repo_full_name: str, github_pr_number: int) -> PRLookupResponse:
"""Look up a PR by repository and PR number."""
return self._make_request(
"GET",
Expand Down
82 changes: 82 additions & 0 deletions src/codegen/cli/auth/auth_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from dataclasses import dataclass
from pathlib import Path

from codegen.cli.api.client import RestAPI
from codegen.cli.auth.session import CodegenSession
from codegen.cli.auth.token_manager import get_current_token
from codegen.cli.errors import AuthError, NoTokenError


@dataclass
class User:
full_name: str
email: str
github_username: str


@dataclass
class Identity:
token: str
expires_at: str
status: str
user: "User"


class CodegenAuthenticatedSession(CodegenSession):
"""Represents an authenticated codegen session with user and repository context"""

# =====[ Instance attributes ]=====
_token: str | None = None

# =====[ Lazy instance attributes ]=====
_identity: Identity | None = None

def __init__(self, token: str | None = None, repo_path: Path | None = None):
# TODO: fix jank.
# super().__init__(repo_path)
self._token = token

@property
def token(self) -> str | None:
"""Get the current authentication token"""
if self._token:
return self._token
return get_current_token()

@property
def identity(self) -> Identity | None:
"""Get the identity of the user, if a token has been provided"""
if self._identity:
return self._identity
if not self.token:
msg = "No authentication token found"
raise NoTokenError(msg)

identity = RestAPI(self.token).identify()
if not identity:
return None

self._identity = Identity(
token=self.token,
expires_at=identity.auth_context.expires_at,
status=identity.auth_context.status,
user=User(
full_name=identity.user.full_name,
email=identity.user.email,
github_username=identity.user.github_username,
),
)
return self._identity

def is_authenticated(self) -> bool:
"""Check if the session is fully authenticated, including token expiration"""
return bool(self.identity and self.identity.status == "active")

def assert_authenticated(self) -> None:
"""Raise an AuthError if the session is not fully authenticated"""
if not self.identity:
msg = "No identity found for session"
raise AuthError(msg)
if self.identity.status != "active":
msg = "Current session is not active. API Token may be invalid or may have expired."
raise AuthError(msg)
10 changes: 8 additions & 2 deletions src/codegen/cli/auth/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@
import click
import rich

from codegen.cli.auth.auth_session import CodegenAuthenticatedSession
from codegen.cli.auth.login import login_routine
from codegen.cli.auth.session import CodegenSession
from codegen.cli.errors import AuthError, InvalidTokenError, NoTokenError
from codegen.cli.rich.pretty_print import pretty_print_error


def requires_auth(f: Callable) -> Callable:
"""Decorator that ensures a user is authenticated and injects a CodegenSession."""

@functools.wraps(f)
def wrapper(*args, **kwargs):
session = CodegenSession()
session = CodegenAuthenticatedSession.from_active_session()

# Check for valid session
if not session.is_valid():
pretty_print_error(f"The session at path {session.repo_path} is missing or corrupt.\nPlease run 'codegen init' to re-initialize the project.")
raise click.Abort()

try:
if not session.is_authenticated():
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/cli/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import rich_click as click

from codegen.cli.api.webapp_routes import USER_SECRETS_ROUTE
from codegen.cli.auth.session import CodegenSession
from codegen.cli.auth.auth_session import CodegenAuthenticatedSession
from codegen.cli.auth.token_manager import TokenManager
from codegen.cli.env.global_env import global_env
from codegen.cli.errors import AuthError


def login_routine(token: str | None = None) -> CodegenSession:
def login_routine(token: str | None = None) -> CodegenAuthenticatedSession:
"""Guide user through login flow and return authenticated session.
Args:
Expand Down Expand Up @@ -39,7 +39,7 @@ def login_routine(token: str | None = None) -> CodegenSession:

# Validate and store token
token_manager = TokenManager()
session = CodegenSession(_token)
session = CodegenAuthenticatedSession(token=_token)

try:
session.assert_authenticated()
Expand Down
Loading

0 comments on commit eb1647b

Please sign in to comment.